Learn practical skills, build real-world projects, and advance your career
Created 2 years ago
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline
# Force a white figure background so plots stay readable under dark notebook themes.
matplotlib.rcParams.update({'figure.facecolor': '#ffffff'})
# Load the MNIST split torchvision returns by default (the printed repr below shows
# "Split: Train"), downloading into data/ if needed and converting each image with
# ToTensor().
dataset = MNIST(root='data/', transform=ToTensor(), download=True)
# Bare expression so the notebook echoes the dataset summary.
dataset
Dataset MNIST
Number of datapoints: 60000
Root location: data/
Split: Train
StandardTransform
Transform: ToTensor()
# Inspect the first example in the dataset: print its label and render the image.
image, label = dataset[0]
print('Label: ', label)
# image[0] takes the first slice along dim 0 (presumably the single grayscale
# channel produced by ToTensor -- confirm). The trailing ';' suppresses the
# notebook's echo of the AxesImage return value, matching the other preview cells.
plt.imshow(image[0],cmap='gray');
Label: 5
<matplotlib.image.AxesImage at 0x1a66727f648>
# Preview dataset example #90: show its target label, then draw it in grayscale.
image, label = dataset[90]
print('Label: ', label)
# `image[0, ...]` is the same first-along-dim-0 slice as `image[0]`; ';' hides the return value.
plt.imshow(image[0, ...], cmap='gray');
Label: 6
# Preview dataset example #6: show its target label, then draw it in grayscale.
image, label = dataset[6]
print('Label: ', label)
# `image[0, ...]` is the same first-along-dim-0 slice as `image[0]`; ';' hides the return value.
plt.imshow(image[0, ...], cmap='gray');
Label: 1