Learn data science and machine learning by building real-world projects on Jovian

Training Generative Adversarial Networks (GANs) in PyTorch

- by Vivek Patel

Before we start...

What are Generative Adversarial Networks (GANs)?

Generative adversarial networks (GANs) are an exciting recent innovation in machine learning. GANs are generative models: they create new data instances that resemble your training data. For example, GANs can create images that look like photographs of human faces, even though the faces don't belong to any real person.

These images were created by a GAN:

more info at https://thispersondoesnotexist.com/

This notebook is a minimal demo of a GAN. Since Paperspace's free tier allows a maximum runtime of 6 hours, we will only train a lightweight model in this notebook.

Background: What is a Generative Model?

What does "generative" mean in the name "Generative Adversarial Network"? "Generative" describes a class of statistical models that contrasts with discriminative models. Informally:

  • Generative models can generate new data instances.
  • Discriminative models discriminate between different kinds of data instances.
In other words, there are two neural networks: a generator and a discriminator.
  • The generator takes a latent vector (a random vector/tensor) as input and produces a fake sample.
  • The discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (produced by the generator).
Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at their jobs.
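The alternating procedure above can be sketched in a few lines of PyTorch. This is a hypothetical, heavily simplified example (toy linear networks on 1-D data, one discriminator step and one generator step), not the training code used later in this notebook:

```python
import torch
import torch.nn as nn

# Toy 1-D "images" so the sketch runs instantly; real GANs use conv nets.
latent_size, data_size, batch = 8, 16, 32
G = nn.Sequential(nn.Linear(latent_size, data_size), nn.Tanh())
D = nn.Sequential(nn.Linear(data_size, 1), nn.Sigmoid())
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
bce = nn.BCELoss()

real = torch.randn(batch, data_size).tanh()  # stand-in for a real batch

# 1) Train the discriminator: push D(real) toward 1 and D(fake) toward 0
z = torch.randn(batch, latent_size)
fake = G(z)
loss_d = bce(D(real), torch.ones(batch, 1)) + \
         bce(D(fake.detach()), torch.zeros(batch, 1))  # detach: don't update G here
opt_d.zero_grad(); loss_d.backward(); opt_d.step()

# 2) Train the generator: try to make D label fresh fakes as real (target 1)
z = torch.randn(batch, latent_size)
loss_g = bce(D(G(z)), torch.ones(batch, 1))
opt_g.zero_grad(); loss_g.backward(); opt_g.step()
```

Note the `detach()` in the discriminator step: it stops gradients from flowing into the generator while the discriminator is being updated.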

1. About the dataset

Original dataset :- Wiki-Art : Visual Art Encyclopedia

Subset dataset :- art-portraits

Original dataset size :- 38 GB

Subset dataset size :- 1.46 GB

Folder Structure

/root
├── Training GAN on gpu.ipynb
└── art-portraits
    └── Portraits
        ├── image-1
        ├── image-2
        └── image-n

Objective

Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data, in such a way that the model can be used to generate new examples that plausibly could have been drawn from the original dataset.
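Formally, the standard GAN formulation (from Goodfellow et al.) casts this as a two-player minimax game between the generator $G$ and the discriminator $D$:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z}\!\left[\log\left(1 - D(G(z))\right)\right]
```

The discriminator maximizes this value (correctly classifying real and fake samples), while the generator minimizes it (fooling the discriminator).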
Let's install required libraries
# Uncomment and run the appropriate command for your operating system, if required
# No installation is required on Google Colab / Kaggle notebooks

# Linux / Binder / Windows (No GPU)
# !pip install numpy matplotlib torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# Linux / Windows (GPU)
# !pip install numpy matplotlib torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 
# MacOS (NO GPU)
# !pip install numpy matplotlib torch torchvision torchaudio
!pip install opendatasets
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting opendatasets
  Downloading opendatasets-0.1.20-py3-none-any.whl (14 kB)
Requirement already satisfied: click in /opt/conda/lib/python3.8/site-packages (from opendatasets) (8.0.1)
Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from opendatasets) (4.62.3)
Collecting kaggle
  Downloading kaggle-1.5.12.tar.gz (58 kB)
Requirement already satisfied: six>=1.10 in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (1.16.0)
Requirement already satisfied: certifi in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2021.5.30)
Requirement already satisfied: python-dateutil in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2.8.2)
Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2.26.0)
Requirement already satisfied: python-slugify in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (5.0.2)
Requirement already satisfied: urllib3 in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (1.26.7)
Requirement already satisfied: text-unidecode>=1.3 in /opt/conda/lib/python3.8/site-packages (from python-slugify->kaggle->opendatasets) (1.3)
Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->kaggle->opendatasets) (3.1)
Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->kaggle->opendatasets) (2.0.0)
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... done
  Created wheel for kaggle: filename=kaggle-1.5.12-py3-none-any.whl size=73051 sha256=51ab921d4f9a107d3aedccf6e6f8145d74cd73ce0b24ade2e60f43d66c6dd837
  Stored in directory: /tmp/pip-ephem-wheel-cache-w7sy1jpp/wheels/29/da/11/144cc25aebdaeb4931b231e25fd34b394e6a5725cbb2f50106
Successfully built kaggle
Installing collected packages: kaggle, opendatasets
Successfully installed kaggle-1.5.12 opendatasets-0.1.20
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Let's download the dataset with opendatasets
import opendatasets as od
dataset_url = 'https://www.kaggle.com/karnikakapoor/art-portraits'
od.download(dataset_url)
Skipping, found downloaded files in "./art-portraits" (use force=True to force download)
Let's set our data directory
import os

DATA_DIR = 'art-portraits'
print(os.listdir(DATA_DIR))
['Portraits']
Exploring the Data
print(os.listdir(DATA_DIR+'/Portraits')[:10])
['bc1cf30d88ee3291af791f77eac70582c.jpg', 'fe2f7028ade1f27ceddbb15d41a28717c.jpg', 'c74063dc222e6a9a0cbd59ab5bfec9cdc.jpg', 'ecbc995cc283ecf84a2eae0348a786e6c.jpg', 'bb06e5b6cc1e8bb99aa0e92766053c56c.jpg', '5f207999aa8fa9c36ef5e806acf1ffedc.jpg', 'b09e6ab959a1f9500ce8d2e1912c24f9c.jpg', '1b90bcd2a313c074920760ba5968c742c.jpg', '43a37fc480f85ddde62a41962e45e082c.jpg', '7bd7db51e351630d26c5943623f127bdc.jpg']
import cv2
import numpy as np
import matplotlib.pyplot as plt
def show_jpg(jpg_name):
    img = cv2.imread(jpg_name)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
    plt.imshow(img)
    plt.show()
show_jpg(DATA_DIR+'/Portraits/'+'bc1cf30d88ee3291af791f77eac70582c.jpg')
Notebook Image
show_jpg('art-portraits/Portraits/00ca56f16c0bae52185ea31f95f0484cc.jpg')
Notebook Image
The dataset has a single folder called Portraits which contains all 4,000+ images in JPG format.
Let's load this dataset using the ImageFolder class from torchvision. We will also resize and crop the images to 64x64 px, and normalize the pixel values with a mean & standard deviation of 0.5 for each channel. This will ensure that pixel values are in the range (-1, 1), which is more convenient for training the discriminator. We will also create a data loader to load the data in batches.
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
image_size = 64
batch_size = 128
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats)]))

train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)
Let's create helper functions to denormalize the image tensors and display some sample images from a training batch.
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline
def denorm(img_tensors):
    return img_tensors * stats[1][0] + stats[0][0]
def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(denorm(images.detach()[:nmax]), nrow=8).permute(1, 2, 0))

def show_batch(dl, nmax=64):
    for images, _ in dl:
        show_images(images, nmax)
        break
show_batch(train_dl)
Notebook Image
Using a GPU

To seamlessly use a GPU if one is available, we define a couple of helper functions (get_default_device & to_device) and a helper class DeviceDataLoader to move our model & data to the GPU.
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')
    
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)
Based on where you're running this notebook, your default device could be a CPU (torch.device('cpu')) or a GPU (torch.device('cuda')).
device = get_default_device()
device
device(type='cuda')
We can now wrap our training data loader in DeviceDataLoader to automatically transfer batches of data to the GPU (if available).
train_dl = DeviceDataLoader(train_dl, device)

Discriminator Network


The discriminator is a classifier that determines whether an input sample is real or fake. These input samples can be real samples coming from the training data, or fake samples coming from the generator. We'll use a convolutional neural network (CNN) which outputs a single number for every image. We'll use a stride of 2 to progressively reduce the size of the output feature map, and the Leaky ReLU activation throughout the discriminator.
import torch.nn as nn
discriminator = nn.Sequential(
    # in: 3 x 64 x 64

    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 64 x 32 x 32

    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 128 x 16 x 16

    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 256 x 8 x 8

    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 512 x 4 x 4

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    nn.Flatten(),
    nn.Sigmoid())
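As a sanity check on the shape annotations above, a convolution with kernel size 4, stride 2, and padding 1 halves each spatial dimension, following out = floor((in + 2*padding - kernel) / stride) + 1. A minimal standalone check of the first layer:

```python
import torch
import torch.nn as nn

# First discriminator layer: 3 x 64 x 64 -> 64 x 32 x 32
# out = floor((64 + 2*1 - 4) / 2) + 1 = 32
conv = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)
x = torch.randn(1, 3, 64, 64)
print(conv(x).shape)
```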
Let's move the discriminator model to the chosen device.
discriminator = to_device(discriminator, device)

Generator Network

The input to the generator is typically a vector or a matrix of random numbers (referred to as a latent tensor) which is used as a seed for generating an image. The generator will convert a latent tensor of shape (128, 1, 1) into an image tensor of shape 3 x 64 x 64. To achieve this, we'll use the ConvTranspose2d layer from PyTorch, which performs a transposed convolution (also referred to as a deconvolution). Learn more
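A transposed convolution upsamples its input: with kernel size 4, stride 2, and padding 1 it doubles each spatial dimension, following out = (in - 1) * stride - 2 * padding + kernel. A minimal standalone check of one generator layer:

```python
import torch
import torch.nn as nn

# One generator layer: 512 x 4 x 4 -> 256 x 8 x 8
# out = (4 - 1) * 2 - 2*1 + 4 = 8
up = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False)
x = torch.randn(1, 512, 4, 4)
print(up(x).shape)
```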

latent_size = 128
generator = nn.Sequential(
    # in: latent_size x 1 x 1

    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 256 x 8 x 8

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # out: 128 x 16 x 16

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # out: 64 x 32 x 32

    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh()
    # out: 3 x 64 x 64
)
xb = torch.randn(batch_size, latent_size, 1, 1) # random latent tensors
fake_images = generator(xb)
print(fake_images.shape)
show_images(fake_images)
torch.Size([128, 3, 64, 64])