Learn practical skills, build real-world projects, and advance your career

Training Generative Adversarial Networks (GANs) in PyTorch

- by Vivek Patel

Before we start...

What is Generative adversarial networks (GANs)?

Generative adversarial networks (GANs) are an exciting recent innovation in machine learning. GANs are generative models: they create new data instances that resemble your training data. For example, GANs can create images that look like photographs of human faces, even though the faces don't belong to any real person.

These images were created by a GAN:

alt
more info at https://thispersondoesnotexist.com/

This notebook is a minimum demo for GAN. Since PaperSpace allows maximum run time limit of 6 hrs in free tier, we will only train a lightweight model in this notebook.

Background: What is a Generative Model?


What does "generative" mean in the name "Generative Adversarial Network"? "Generative" describes a class of statistical models that contrasts with discriminative models.

Informally:

  • Generative models can generate new data instances.
  • Discriminative models discriminate between different kinds of data instances.
in other words There are two neural networks: a Generator and a Discriminator.
  • The generator generates a latent vector(random vector/random tensor/random matrix)
  • the discriminator attempts to detect whether a given sample is "real" (picked from the training data) or "fake" (generated by the generator).
alt Training happens in tandem: we train the discriminator for a few epochs, then train the generator for a few epochs, and repeat. This way both the generator and the discriminator get better at doing their jobs.

1. About dataset

Orignal dataset :-Wiki-Art : Visual Art Encyclopedia

Subset dataset :- art-portraits

Orignal dataset size :- 38GB

Subset dataset size :- 1.46GB

Folder Structure

. /root ├── Training GAN on gpu.ipynb ├── art-portraits │ └── Portraits └── image-1 └── image-2 └── image-n

objective

modeling objective Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset.
Let's install required libraries
# Uncomment and run the appropriate command for your operating system, if required
# No installation is reqiured on Google Colab / Kaggle notebooks

# Linux / Binder / Windows (No GPU)
# !pip install numpy matplotlib torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# Linux / Windows (GPU)
# pip install numpy matplotlib torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 
# MacOS (NO GPU)
# !pip install numpy matplotlib torch torchvision torchaudio
!pip install opendatasets
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Collecting opendatasets Downloading opendatasets-0.1.20-py3-none-any.whl (14 kB) Requirement already satisfied: click in /opt/conda/lib/python3.8/site-packages (from opendatasets) (8.0.1) Requirement already satisfied: tqdm in /opt/conda/lib/python3.8/site-packages (from opendatasets) (4.62.3) Collecting kaggle Downloading kaggle-1.5.12.tar.gz (58 kB) |████████████████████████████████| 58 kB 36.4 MB/s eta 0:00:01 Requirement already satisfied: six>=1.10 in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (1.16.0) Requirement already satisfied: certifi in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2021.5.30) Requirement already satisfied: python-dateutil in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2.8.2) Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (2.26.0) Requirement already satisfied: python-slugify in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (5.0.2) Requirement already satisfied: urllib3 in /opt/conda/lib/python3.8/site-packages (from kaggle->opendatasets) (1.26.7) Requirement already satisfied: text-unidecode>=1.3 in /opt/conda/lib/python3.8/site-packages (from python-slugify->kaggle->opendatasets) (1.3) Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->kaggle->opendatasets) (3.1) Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->kaggle->opendatasets) (2.0.0) Building wheels for collected packages: kaggle Building wheel for kaggle (setup.py) ... done Created wheel for kaggle: filename=kaggle-1.5.12-py3-none-any.whl size=73051 sha256=51ab921d4f9a107d3aedccf6e6f8145d74cd73ce0b24ade2e60f43d66c6dd837 Stored in directory: /tmp/pip-ephem-wheel-cache-w7sy1jpp/wheels/29/da/11/144cc25aebdaeb4931b231e25fd34b394e6a5725cbb2f50106 Successfully built kaggle Installing collected packages: kaggle, opendatasets Successfully installed kaggle-1.5.12 opendatasets-0.1.20 WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv