Learn practical skills, build real-world projects, and advance your career
import torch
import numpy as np
import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
dataset = MNIST(root='data/',
               download = True,
               transform=ToTensor())
0it [00:00, ?it/s]
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
100%|█████████▉| 9904128/9912422 [00:24<00:00, 439229.67it/s]
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz
0it [00:00, ?it/s]
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/28881 [00:00<?, ?it/s] 57%|█████▋ | 16384/28881 [00:00<00:00, 121650.65it/s] 32768it [00:00, 57165.24it/s] 0it [00:00, ?it/s]
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00<?, ?it/s] 1%| | 16384/1648877 [00:00<00:15, 103299.19it/s] 1%|▏ | 24576/1648877 [00:00<00:18, 89370.30it/s] 2%|▏ | 40960/1648877 [00:00<00:16, 99439.65it/s] 5%|▌ | 90112/1648877 [00:00<00:12, 124048.15it/s] 8%|▊ | 139264/1648877 [00:01<00:10, 144674.85it/s] 11%|█▏ | 188416/1648877 [00:01<00:08, 165096.08it/s] 14%|█▍ | 237568/1648877 [00:01<00:06, 204047.29it/s] 16%|█▋ | 270336/1648877 [00:01<00:06, 200392.31it/s] 18%|█▊ | 303104/1648877 [00:01<00:06, 201366.16it/s] 21%|██▏ | 352256/1648877 [00:01<00:05, 243556.21it/s] 23%|██▎ | 385024/1648877 [00:02<00:05, 229636.63it/s] 26%|██▌ | 425984/1648877 [00:02<00:05, 231926.79it/s] 29%|██▉ | 475136/1648877 [00:02<00:04, 273657.18it/s] 31%|███ | 507904/1648877 [00:02<00:04, 255495.42it/s] 33%|███▎ | 548864/1648877 [00:02<00:03, 283648.13it/s] 35%|███▌ | 581632/1648877 [00:02<00:04, 260024.62it/s] 38%|███▊ | 622592/1648877 [00:02<00:03, 289903.46it/s] 40%|███▉ | 655360/1648877 [00:03<00:03, 257949.91it/s] 43%|████▎ | 704512/1648877 [00:03<00:03, 293010.20it/s] 45%|████▌ | 745472/1648877 [00:03<00:03, 269828.12it/s] 48%|████▊ | 794624/1648877 [00:03<00:02, 303915.32it/s] 51%|█████ | 835584/1648877 [00:03<00:02, 281603.77it/s] 54%|█████▍ | 892928/1648877 [00:03<00:02, 320350.52it/s] 57%|█████▋ | 933888/1648877 [00:03<00:02, 284115.74it/s] 60%|█████▉ | 983040/1648877 [00:03<00:02, 324838.31it/s] 62%|██████▏ | 1024000/1648877 [00:04<00:02, 291585.31it/s] 65%|██████▌ | 1073152/1648877 [00:04<00:01, 320215.92it/s] 68%|██████▊ | 1114112/1648877 [00:04<00:01, 289772.33it/s] 71%|███████ | 1163264/1648877 [00:04<00:01, 319138.71it/s] 73%|███████▎ | 1204224/1648877 [00:04<00:01, 286127.64it/s] 76%|███████▌ | 1253376/1648877 [00:04<00:01, 327057.75it/s] 78%|███████▊ | 1294336/1648877 [00:05<00:01, 296435.58it/s] 82%|████████▏ | 1351680/1648877 [00:05<00:00, 336907.23it/s] 85%|████████▍ | 1400832/1648877 [00:05<00:00, 368961.15it/s] 87%|████████▋ | 1441792/1648877 [00:05<00:00, 302516.28it/s] 90%|█████████ | 1490944/1648877 [00:05<00:00, 336593.51it/s] 93%|█████████▎| 1531904/1648877 [00:05<00:00, 299923.42it/s] 96%|█████████▌| 1581056/1648877 [00:05<00:00, 339294.97it/s] 99%|█████████▉| 1638400/1648877 [00:05<00:00, 376406.82it/s] 0it [00:00, ?it/s]
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00<?, ?it/s] 8192it [00:00, 38181.01it/s]
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz Processing... Done!
9920512it [00:39, 439229.67it/s] 1654784it [00:24, 376406.82it/s]
def split_indices(n, val_pct):
    n_val = int(val_pct*n)
    idxs = np.random.permutation(n)
    return idxs[n_val:], idxs[:n_val]
train_indices, val_indices = split_indices(len(dataset), val_pct=0.2)
print(len(train_indices), len(val_indices))
48000 12000