
Working with Images & Logistic Regression in PyTorch

Part 3 of "Deep Learning with PyTorch: Zero to GANs"

This tutorial series is a hands-on, beginner-friendly introduction to deep learning using PyTorch, an open-source neural networks library. These tutorials take a practical, coding-focused approach. The best way to learn the material is to execute the code and experiment with it yourself. Check out the full series here:

  1. PyTorch Basics: Tensors & Gradients
  2. Gradient Descent & Linear Regression
  3. Working with Images & Logistic Regression
  4. Training Deep Neural Networks on a GPU
  5. Image Classification using Convolutional Neural Networks
  6. Data Augmentation, Regularization and ResNets
  7. Generating Images using Generative Adversarial Networks

This tutorial covers the following topics:

  • Working with images in PyTorch (using the MNIST dataset)
  • Splitting a dataset into training, validation, and test sets
  • Creating PyTorch models with custom logic by extending the nn.Module class
  • Interpreting model outputs as probabilities using Softmax and picking predicted labels
  • Picking a useful evaluation metric (accuracy) and loss function (cross-entropy) for classification problems
  • Setting up a training loop that also evaluates the model using the validation set
  • Testing the model manually on randomly picked examples
  • Saving and loading model checkpoints to avoid retraining from scratch

Working with Images

In this tutorial, we'll use our existing knowledge of PyTorch and linear regression to solve a very different kind of problem: image classification. We'll use the famous MNIST Handwritten Digits Database as our training dataset. It consists of 28px by 28px grayscale images of handwritten digits (0 to 9) and labels for each image indicating which digit it represents. Here are some sample images from the dataset:

[Figure: sample images from the MNIST dataset]
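To make the data format concrete, here is a minimal sketch of how one such example could be represented as a PyTorch tensor. It uses random values rather than real MNIST pixels, and the names `fake_image`, `fake_label`, and `flat` are hypothetical. The leading channel dimension of size 1 is an assumption about how grayscale images are commonly stored; the layout of the real dataset may differ depending on how it is loaded.

```python
import torch

# Stand-in for a single MNIST example: random values, NOT real data.
torch.manual_seed(0)

# Assumed layout: (channels, height, width) with one grayscale channel.
fake_image = torch.rand(1, 28, 28)   # pixel intensities in [0, 1)
fake_label = 7                        # hypothetical label, a digit from 0 to 9

# One common way to feed an image to a linear model is to flatten it
# into a vector of 28 * 28 = 784 values.
flat = fake_image.reshape(-1)

print(fake_image.shape)  # torch.Size([1, 28, 28])
print(flat.shape)        # torch.Size([784])
```

Each image therefore contributes 784 input values, and the label is a single integer class between 0 and 9.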

We begin by installing and importing torch and torchvision. torchvision contains some utilities for working with image data. It also provides helper classes to download and import popular datasets like MNIST automatically.

# Imports
import torch
import torchvision
from torchvision.datasets import MNIST