
Generating Mortys with Generative Adversarial Networks in PyTorch

I had been thinking about my final course project: what should I do? Classification? Regression? But these tasks are very common, and I had already practised them. I wanted to try something unusual, specific, and geeky.

Generative Adversarial Networks (GANs) have always fascinated me with their mysterious nature. So, without a doubt, I finally decided to have some fun and study the basic principles of this type of neural network.

What should I generate? Numbers? Text? Cats? Dogs? Birds? Fruits? Clouds? No, this is too mainstream.
Many objects, each one unique yet at the same time alike... many... Eureka! The choice of what to generate came to my mind almost immediately!
I would generate Mortys!
If you are a fan of this perfect cartoon and an ML/DL enthusiast, let's go on an adventure into GANs: in and out!


Generative Adversarial Networks in PyTorch

Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks, or GANs, however, use neural networks for a very different purpose: generative modeling.

Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset. - Source

Probably the most popular example of GANs is generating and manipulating human faces. To get a sense of the power of generative models, just visit thispersondoesnotexist.com. Every time you reload the page, a new image of a person's face is generated on the fly. The results are pretty fascinating!

While there are many approaches used for generative modeling, a Generative Adversarial Network takes the following approach:

[Figure: the generator and discriminator of a GAN. Source]

There are two neural networks: a Generator and a Discriminator. The generator produces a "fake" sample from a random vector/matrix, and the discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (produced by the generator). Training happens in tandem: we train the discriminator for a while, then the generator, and repeat; in many implementations the two networks simply alternate one update each per batch. This way both the generator and the discriminator get better at their jobs.
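The alternating scheme can be sketched in PyTorch as follows. This is a minimal illustration, not the networks used later in this project: it assumes tiny fully connected models on flattened 8×8 RGB samples, binary cross-entropy losses, and Adam with the learning rate and betas often used for GANs. All names and sizes (`LATENT_SIZE`, `IMAGE_DIM`, `train_step`, etc.) are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_SIZE = 16        # assumption: length of the random input vector
IMAGE_DIM = 3 * 8 * 8   # assumption: tiny flattened RGB samples keep the sketch fast

# Generator: random vector -> fake sample with values in [-1, 1] (tanh output)
generator = nn.Sequential(
    nn.Linear(LATENT_SIZE, 128), nn.ReLU(),
    nn.Linear(128, IMAGE_DIM), nn.Tanh(),
)

# Discriminator: sample -> probability that the sample is real
discriminator = nn.Sequential(
    nn.Linear(IMAGE_DIM, 128), nn.LeakyReLU(0.2),
    nn.Linear(128, 1), nn.Sigmoid(),
)

# Assumed hyperparameters (common GAN defaults, not taken from this post)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))


def train_discriminator(real_images):
    opt_d.zero_grad()
    batch_size = real_images.size(0)
    # Real samples should be classified as 1
    real_preds = discriminator(real_images)
    real_loss = F.binary_cross_entropy(real_preds, torch.ones(batch_size, 1))
    # Fake samples should be classified as 0; detach() keeps this step from updating the generator
    fake_images = generator(torch.randn(batch_size, LATENT_SIZE)).detach()
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, torch.zeros(batch_size, 1))
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item()


def train_generator(batch_size):
    opt_g.zero_grad()
    # The generator improves when the discriminator labels its fakes as real (1)
    preds = discriminator(generator(torch.randn(batch_size, LATENT_SIZE)))
    loss = F.binary_cross_entropy(preds, torch.ones(batch_size, 1))
    loss.backward()
    # Only generator weights are stepped; the discriminator's stray gradients
    # are cleared by opt_d.zero_grad() at its next update
    opt_g.step()
    return loss.item()


def train_step(real_images):
    # One discriminator update, then one generator update on the same batch size
    loss_d = train_discriminator(real_images)
    loss_g = train_generator(real_images.size(0))
    return loss_d, loss_g
```

Calling `train_step(batch)` once per batch of real samples alternates the two updates; running it over every batch for several epochs is the whole training loop in this sketch.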

GANs, however, can be notoriously difficult to train, and are extremely sensitive to hyperparameters, activation functions and regularization. In this tutorial, we'll train a GAN to generate images of Mortys and other Rick and Morty characters.

Dataset

I tried to find a ready-made dataset, but ran into a problem: I couldn't find any suitable dataset of Rick and Morty images! Downloading images manually from Google or Bing didn't give great results either, as the data was noisy, too varied, and required additional preprocessing and time.
That might be fine for classification, but I had a different goal.
As source images, I used sprites from Pocket Mortys, also known as Rick and Morty: Pocket Mortys.

Pocket Mortys is a free-to-play Rick and Morty-themed role-playing video game developed by Big Pixel Studios and published by Adult Swim Games. The game was released worldwide on 13 January 2016 for iOS and Android devices. The game is based on Rick and Morty and the mechanics serve as a parody of the Pokémon franchise.


So I wrote a scraper for the Pocket Rick and Morty site using the BeautifulSoup library. There you can find all the information about any character in Pocket Mortys:
attack, type, speed, characteristics, description and, of course, a bunch of images!
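The post doesn't show the scraper, so here is a minimal sketch of the idea. It assumes only that character images appear as ordinary `<img>` tags; the real site's URLs and HTML layout are not given here, so the function names and the generic "grab every image" strategy are hypothetical. Pages are fetched with the standard library's `urllib.request` and parsed with BeautifulSoup.

```python
import os
import urllib.request
from urllib.parse import urljoin, urlparse

from bs4 import BeautifulSoup


def fetch_html(url):
    """Download a page and return its HTML as text."""
    with urllib.request.urlopen(url) as response:
        return response.read().decode("utf-8", errors="replace")


def extract_image_urls(html, base_url):
    """Collect absolute URLs of every <img> on the page, without duplicates."""
    soup = BeautifulSoup(html, "html.parser")
    urls = []
    for img in soup.find_all("img"):
        src = img.get("src")
        if src:  # skip <img> tags that have no src attribute
            urls.append(urljoin(base_url, src))  # resolve relative paths against the page URL
    return list(dict.fromkeys(urls))  # de-duplicate while keeping page order


def download_image(url, dest_dir):
    """Save one image to dest_dir, named after the last segment of its URL path."""
    os.makedirs(dest_dir, exist_ok=True)
    path = os.path.join(dest_dir, os.path.basename(urlparse(url).path))
    with urllib.request.urlopen(url) as response, open(path, "wb") as f:
        f.write(response.read())
    return path
```

For each character page you would call `fetch_html`, pass the result to `extract_image_urls`, and then `download_image` for every URL. In practice you would likely filter by a CSS class or URL pattern specific to the site rather than keep every image.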

A bit of programming and voilà: 300+ unique and bizarre Mortys and other characters in your pocket. Cronenberg Morty, Pickle Morty, Car Morty, Beth, Rick, Jerry, Summer... and more!
Let's go deeper!

After all the preparation, the final dataset consists of more than 2000 avatars.

Note that generative modelling is an unsupervised learning task, so the images do not have any labels.

project_name = 'rickmortygan'
# Uncomment and install libraries if imports fail
# !conda install numpy pandas pytorch torchvision cpuonly -c pytorch -y
# !pip install matplotlib --upgrade --quiet

Loading the Data

I stored all the data on my Google Drive, so some additional code is needed to mount it; the DATA_DIR variable below points to the relative path of the dataset.
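Here is a minimal sketch of that setup. It assumes the notebook runs in Google Colab, where `google.colab.drive.mount` makes the Drive available under `/content/drive` (with files under `MyDrive`); the relative path `MyDrive/rickmortygan/data` and the helper `get_data_dir` are hypothetical placeholders, not the author's actual layout. Outside Colab it falls back to a local folder with the same relative path.

```python
import os

DRIVE_ROOT = "/content/drive"              # Colab's usual mount point
RELATIVE_PATH = "MyDrive/rickmortygan/data"  # hypothetical location of the dataset


def get_data_dir(relative_path=RELATIVE_PATH, drive_root=DRIVE_ROOT):
    """Mount Google Drive when running in Colab and return the dataset directory."""
    try:
        from google.colab import drive  # only importable inside Google Colab
        drive.mount(drive_root)         # asks for authorization on first run
    except ImportError:
        # Not in Colab: look for a local copy at the same relative path
        drive_root = "."
    return os.path.join(drive_root, relative_path)


DATA_DIR = get_data_dir()
print(DATA_DIR)
```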