
Working with Images & Logistic Regression in PyTorch

Part 3 of "Deep Learning with PyTorch: Zero to GANs"

This tutorial covers the following topics:

  • Working with images in PyTorch (using the MNIST dataset)
  • Splitting a dataset into training, validation, and test sets
  • Creating PyTorch models with custom logic by extending the nn.Module class
  • Interpreting model outputs as probabilities using Softmax and picking predicted labels
  • Picking a useful evaluation metric (accuracy) and loss function (cross-entropy) for classification problems
  • Setting up a training loop that also evaluates the model using the validation set
  • Testing the model manually on randomly picked examples
  • Saving and loading model checkpoints to avoid retraining from scratch

Working with Images

In this tutorial, we'll use our existing knowledge of PyTorch and linear regression to solve a very different kind of problem: image classification. We'll use the famous MNIST Handwritten Digits Database as our training dataset. It consists of 28px by 28px grayscale images of handwritten digits (0 to 9) and labels for each image indicating which digit it represents. Here are some sample images from the dataset:

[Image: sample handwritten digits from the MNIST dataset]
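To make the data format described above concrete, here is a minimal sketch of how a single example could be represented as a tensor. The pixel values are random stand-ins rather than real MNIST data, and the names `image`, `label`, and `image_tensor` are chosen only for illustration.

```python
import torch

# Synthetic stand-in for one MNIST example (NOT real data):
# a 28x28 grayscale image with integer pixel intensities in 0..255.
torch.manual_seed(0)
image = torch.randint(0, 256, (28, 28), dtype=torch.uint8)
label = 7  # hypothetical label; real labels are integers from 0 to 9

# Scale to floats in [0, 1] and add a leading channel dimension, giving
# the (channels, height, width) layout commonly used for image tensors.
image_tensor = (image.float() / 255.0).unsqueeze(0)

print(image_tensor.shape)    # one channel, 28 rows, 28 columns
print(image_tensor.numel())  # 784 pixel values per image
```

Each image therefore carries 28 × 28 = 784 numbers, which is the input size a model for this dataset has to handle.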

We begin by installing and importing torch and torchvision. The torchvision package contains utilities for working with image data. It also provides helper classes to download and import popular datasets like MNIST automatically.

# Uncomment and run the appropriate command for your operating system, if required

# Linux / Binder
# !pip install numpy matplotlib torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# Windows
# !pip install numpy matplotlib torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# macOS
# !pip install numpy matplotlib torch torchvision torchaudio
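After running the appropriate install command, one way to confirm that the packages are available is to query their installed versions. The sketch below uses only the standard library, so it runs even if a package is missing. The helper name `installed_version` is hypothetical, and the package list simply mirrors the install commands above.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Packages named in the install commands above
packages = ("numpy", "matplotlib", "torch", "torchvision")
versions = {name: installed_version(name) for name in packages}

for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")
```

If any package is reported as not installed, rerun the install command for your operating system before continuing.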