Learn practical skills, build real-world projects, and advance your career

Introduction to Generative Modeling

Training Generative Adversarial Networks (GANs) in PyTorch

Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks or GANs, however, use neural networks for a very different purpose: Generative modeling

Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset. - Source

To get a sense of the power of generative models, just visit thispersondoesnotexist.com. Every time you reload the page, a new image of a person's face is generated on the fly. The results are pretty fascinating:

alt

There are two neural networks: a Generator and a Discriminator. The generator generates a "fake" sample given a random vector/matrix, and the discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (generated by the generator). Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at doing their jobs.

GANs however, can be notoriously difficult to train, and are extremely sensitive to hyperparameters, activation functions and regularization. In this tutorial, we'll train a GAN to generate images of anime characters' faces.

We'll use the Anime Face Dataset, which consists of over 63,000 cropped anime faces. Note that generative modeling is an unsupervised learning task, so the images do not have any labels. Most of the code in this tutorial is based on this notebook.

import opendatasets as od

#dataset_url = 'https://www.kaggle.com/splcher/animefacedataset'
#od.download(dataset_url)
import os

DATA_DIR = './data/animefacedataset'
print(os.listdir(DATA_DIR))
['images']
print(os.listdir(DATA_DIR+'/images')[:10])
['0_2000.jpg', '10000_2004.jpg', '10001_2004.jpg', '10002_2004.jpg', '10003_2004.jpg', '10004_2004.jpg', '10005_2004.jpg', '10006_2004.jpg', '10007_2004.jpg', '10008_2004.jpg']
"""
Let's load this dataset using the ImageFolder class from torchvision. We will also resize and crop the images to 64x64 px, and normalize the pixel values with a mean & standard deviation of 0.5 for each channel. This will ensure that pixel values are in the range (-1, 1), which is more convenient for training the discriminator. We will also create a data loader to load the data in batches.
"""
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T