Learn practical skills, build real-world projects, and advance your career
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline
# Load the CIFAR-10 training and test splits; ToTensor() converts each PIL
# image into a float tensor. download=True fetches the archive into data/ on
# the first run (the default, download=False, raises "Dataset not found or
# corrupted" when data/ is empty); on later runs it only verifies the files.
dataset = CIFAR10(root='data/', train=True, download=True, transform=ToTensor())
test_dataset = CIFAR10(root='data/', train=False, download=True, transform=ToTensor())
dataset_size = len(dataset)
test_dataset_size = len(test_dataset)
print(dataset_size, test_dataset_size)
50000 10000
# Human-readable class names exposed by the dataset, plus how many there are.
# (torchvision convention: integer label i corresponds to classes[i].)
classes = dataset.classes
num_classes = len(classes)
print(classes)
print(f"num classes: {num_classes}")
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num classes: 10
# Sanity-check the first training sample: the tensor comes out channels-first
# (the printed size is torch.Size([3, 32, 32])).
img, label = dataset[0]
print(img.size())
# matplotlib's imshow expects channels-last (H, W, C), so move the channel
# axis to the end before plotting.
plt.imshow(img.permute(1, 2, 0))
torch.Size([3, 32, 32])
<matplotlib.image.AxesImage at 0x7fbb34527be0>
Notebook Image