Jovian
⭐️
Sign In

Training CIFAR10 with a WideResNet-22

We're using PyTorch and a WideResNet-22 architecture to train a model on CIFAR10 from scratch, reaching ~92.6% validation accuracy in 20 epochs (see the training output below).

In [1]:
import os
import torch
import torchvision
import tarfile
from torchvision.datasets.utils import download_url
In [2]:
# Download the CIFAR10 archive (fast.ai mirror) into the current directory.
dataset_url = "http://files.fast.ai/data/cifar10.tgz"
download_url(dataset_url, '.')
0it [00:00, ?it/s]
Downloading http://files.fast.ai/data/cifar10.tgz to ./cifar10.tgz
168173568it [00:01, 88002056.66it/s]
In [3]:
# Extract the images from the downloaded archive into ./data.
# NOTE(review): tarfile.extractall is vulnerable to path traversal
# (CVE-2007-4559) on untrusted archives, so validate every member
# path before extracting.
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    dest = os.path.realpath('./data')
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest, member.name))
        if not target.startswith(dest + os.sep) and target != dest:
            raise ValueError(f"Unsafe path in archive: {member.name}")
    tar.extractall(path='./data')

In [4]:
# Inspect the extracted dataset layout and the class names (one
# subdirectory of train/ per class).
data_dir = './data/cifar10'

print(os.listdir(data_dir))
classes = os.listdir(data_dir + "/train")
print(classes)
['labels.txt', 'train', 'test'] ['ship', 'bird', 'cat', 'horse', 'deer', 'automobile', 'dog', 'airplane', 'frog', 'truck']
In [5]:
# Dataset root; the train/test image folders live directly beneath it.
PATH = "data/cifar10/"
trn_dir = PATH + 'train'
val_dir = PATH + 'test'
In [14]:
import torchvision.transforms as tt
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
In [8]:
# Widely used per-channel mean/std statistics for CIFAR10, used for
# normalization below.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
# Base transforms applied to every image (tensor conversion + normalize);
# the augmented pipeline prepends random 32x32 crops (4px padding) and
# random horizontal flips for the training set.
tfms = [tt.ToTensor(), tt.Normalize(*stats)]
augmentations = [tt.RandomCrop(32, padding=4), tt.RandomHorizontalFlip()]
aug_tfms = tt.Compose(augmentations + tfms)
In [11]:
# PyTorch datasets
# PyTorch datasets: augmentation for training, plain normalization for validation.
trn_ds = ImageFolder(trn_dir, transform=aug_tfms)
val_ds = ImageFolder(val_dir, transform=tt.Compose(tfms))
In [16]:
# Mini-batch size shared by both data loaders (and the DataBunch below).
batch_size=128
In [19]:
# Wrap the datasets in loaders; only the training set is shuffled.
# pin_memory speeds up host-to-GPU copies; 4 workers load batches in parallel.
trn_dl = DataLoader(trn_ds, batch_size, shuffle=True,
                    num_workers=4, pin_memory=True)
val_dl = DataLoader(val_ds, batch_size, shuffle=False,
                    num_workers=4, pin_memory=True)
In [20]:
from fastai.basic_data import DataBunch
In [25]:
# Bundle the datasets into a fastai DataBunch, which wraps the loaders as
# DeviceDataLoaders on the default device (CUDA here, per the outputs below).
data = DataBunch.create(trn_ds, val_ds, path='./data/cifar10', bs=batch_size)
In [27]:
# Inspect the wrapped training loader (a DeviceDataLoader on CUDA).
data.train_dl
Out[27]:
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f9ddc2ff5f8>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f9de3b95d08>)
In [28]:
# Inspect the wrapped validation loader.
data.valid_dl
Out[28]:
DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f9ddc2ff828>, device=device(type='cuda'), tfms=[], collate_fn=<function data_collate at 0x7f9de3b95d08>)
In [29]:
# Confirm which device the DataBunch moves batches to.
data.device
Out[29]:
device(type='cuda')
In [32]:
# Class labels as discovered by ImageFolder (sorted subdirectory names).
trn_ds.classes
Out[32]:
['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

Model

In [33]:
import torch.nn as nn
import torch.nn.functional as F

def conv_2d(ni, nf, stride=1, ks=3):
    """Convolution with 'same' padding (ks//2) and no bias.

    The bias is disabled because every conv in this network is paired
    with a BatchNorm layer, which supplies its own learnable shift.
    """
    return nn.Conv2d(ni, nf, kernel_size=ks,
                     stride=stride, padding=ks // 2, bias=False)

def bn_relu_conv(ni, nf):
    """Pre-activation unit: BatchNorm -> ReLU -> 3x3 conv (ni -> nf channels)."""
    layers = [nn.BatchNorm2d(ni),
              nn.ReLU(inplace=True),
              conv_2d(ni, nf)]
    return nn.Sequential(*layers)

class BasicBlock(nn.Module):
    """Pre-activation residual block.

    BN + ReLU is applied first; the residual branch (two convs) is scaled
    by 0.2 before being added in-place to the (possibly projected) shortcut.
    """
    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, stride)
        self.conv2 = bn_relu_conv(nf, nf)
        # Identity shortcut. NOTE(review): a plain lambda is not a registered
        # submodule and is not picklable; an nn.Identity would be safer.
        self.shortcut = lambda x: x
        if ni != nf:
            # 1x1 conv projection when the channel count changes.
            self.shortcut = conv_2d(ni, nf, stride, 1)
    
    def forward(self, x):
        # Pre-activation; the inplace ReLU mutates the BN output tensor.
        x = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x)
        x = self.conv1(x)
        # Scale the residual branch by 0.2 before the skip connection.
        x = self.conv2(x) * 0.2
        # In-place add of the shortcut onto the scaled residual.
        return x.add_(r)
In [34]:
def make_group(N, ni, nf, stride):
    """Build a list of N BasicBlocks: the first maps ni -> nf with the
    given stride; the remaining N-1 keep nf channels at stride 1."""
    blocks = [BasicBlock(ni, nf, stride)]
    blocks.extend(BasicBlock(nf, nf) for _ in range(N - 1))
    return blocks

class Flatten(nn.Module):
    """Flatten every dimension except the batch dimension.

    The no-op ``__init__`` in the original was redundant — ``nn.Module``'s
    constructor runs by default when no ``__init__`` is defined.
    """
    def forward(self, x):
        # (B, ...) -> (B, prod(...)); view requires a contiguous tensor.
        return x.view(x.size(0), -1)

class WideResNet(nn.Module):
    """Wide ResNet: a conv stem, n_groups groups of N pre-activation
    BasicBlocks, then a BN-ReLU-pool-flatten-linear classifier head.

    Args:
        n_groups: number of block groups; each group after the first
            downsamples spatially by 2 (stride-2 first block).
        N: number of BasicBlocks per group.
        n_classes: output size of the final linear layer.
        k: widening factor applied to every group's channel count.
        n_start: channels produced by the stem convolution.
    """
    def __init__(self, n_groups, N, n_classes, k=1, n_start=16):
        super().__init__()      
        # Increase channels to n_start using conv layer
        layers = [conv_2d(3, n_start)]
        n_channels = [n_start]
        
        # Add groups of BasicBlock(increase channels & downsample)
        for i in range(n_groups):
            n_channels.append(n_start*(2**i)*k)
            stride = 2 if i>0 else 1
            layers += make_group(N, n_channels[i], 
                                 n_channels[i+1], stride)
        
        # Pool, flatten & add linear layer for classification.
        # BUGFIX: use the last group's channel count (-1) instead of the
        # hard-coded index 3, so the class works for any n_groups
        # (identical behavior for the n_groups=3 configuration used here).
        layers += [nn.BatchNorm2d(n_channels[-1]), 
                   nn.ReLU(inplace=True), 
                   nn.AdaptiveAvgPool2d(1), 
                   Flatten(), 
                   nn.Linear(n_channels[-1], n_classes)]
        
        self.features = nn.Sequential(*layers)
        
    def forward(self, x): return self.features(x)
    
def wrn_22():
    """Build the WideResNet-22 used in this notebook: 3 groups of 3
    blocks, widening factor k=6, 10 output classes (CIFAR10)."""
    return WideResNet(n_groups=3, N=3, n_classes=10, k=6)
In [39]:
from fastai.train import Learner
In [44]:
from fastai.metrics import accuracy
In [55]:
import torch.nn.functional as F
In [56]:
# NOTE(review): leftover interactive help lookup — remove from the final notebook.
?Learner
In [57]:
# Instantiate the WideResNet-22 with freshly initialized weights.
model = wrn_22()
In [58]:
# fastai Learner ties together the data, model, loss function and metrics.
learner = Learner(data, model, loss_func=F.cross_entropy, metrics=[accuracy], path='./data')
In [59]:
# Presumably fastai's gradient-clipping threshold — confirm against fastai docs.
learner.clip = 0.1
In [60]:
# NOTE(review): leftover interactive help lookup — remove from the final notebook.
?learner.fit_one_cycle
In [61]:
# LR range test: briefly trains over increasing learning rates; plot below.
learner.lr_find()
epoch train_loss valid_loss accuracy time 0 7.645119 #na# 00:14 LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [62]:
# Plot loss vs. learning rate from the LR finder run above.
learner.recorder.plot()
/opt/conda/lib/python3.6/site-packages/fastai/sixel.py:16: UserWarning: You could see this plot with `libsixel`. See https://github.com/saitoha/libsixel warn("You could see this plot with `libsixel`. See https://github.com/saitoha/libsixel")
Notebook Image
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [63]:
%%time
# One-cycle policy: 20 epochs, max learning rate 1e-3, weight decay 1e-4.
learner.fit_one_cycle(20, 1e-3, wd=1e-4)
epoch train_loss valid_loss accuracy time 0 1.445478 1.411547 0.479000 01:00 1 1.077693 1.299420 0.553300 01:00 2 0.875239 1.250685 0.603000 01:00 3 0.723028 0.847676 0.723100 01:01 4 0.614039 0.839728 0.724200 01:01 5 0.520733 0.612923 0.802500 01:01 6 0.443135 0.603840 0.804800 01:01 7 0.380558 0.431947 0.855300 01:01 8 0.317870 0.387011 0.869200 01:01 9 0.300501 0.373976 0.879900 01:01 10 0.256610 0.368976 0.883200 01:00 11 0.208630 0.330221 0.894700 01:01 12 0.170709 0.333554 0.896600 01:01 13 0.141391 0.298024 0.907900 01:01 14 0.111643 0.293715 0.912600 01:01 15 0.080102 0.276092 0.919900 01:01 16 0.066331 0.269173 0.921000 01:01 17 0.052194 0.272048 0.923100 01:01 18 0.043042 0.264378 0.925100 01:01 19 0.043691 0.265273 0.925800 01:02 CPU times: user 10min 23s, sys: 8min 44s, total: 19min 7s Wall time: 20min 24s
In [64]:
!pip install jovian --upgrade
Collecting jovian Downloading https://files.pythonhosted.org/packages/b8/a1/dd7b5bcf0a3f043c894151f5e523c1f88ab72ed5672694f325e2f39e85f1/jovian-0.1.97-py2.py3-none-any.whl (50kB) |████████████████████████████████| 51kB 1.6MB/s eta 0:00:011 Requirement already satisfied, skipping upgrade: requests in /opt/conda/lib/python3.6/site-packages (from jovian) (2.22.0) Collecting uuid Downloading https://files.pythonhosted.org/packages/ce/63/f42f5aa951ebf2c8dac81f77a8edcc1c218640a2a35a03b9ff2d4aa64c3d/uuid-1.30.tar.gz Requirement already satisfied, skipping upgrade: click in /opt/conda/lib/python3.6/site-packages (from jovian) (7.0) Requirement already satisfied, skipping upgrade: pyyaml in /opt/conda/lib/python3.6/site-packages (from jovian) (5.2) Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/lib/python3.6/site-packages (from requests->jovian) (1.25.7) Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /opt/conda/lib/python3.6/site-packages (from requests->jovian) (2019.11.28) Requirement already satisfied, skipping upgrade: idna<2.9,>=2.5 in /opt/conda/lib/python3.6/site-packages (from requests->jovian) (2.8) Requirement already satisfied, skipping upgrade: chardet<3.1.0,>=3.0.2 in /opt/conda/lib/python3.6/site-packages (from requests->jovian) (3.0.4) Building wheels for collected packages: uuid Building wheel for uuid (setup.py) ... done Created wheel for uuid: filename=uuid-1.30-cp36-none-any.whl size=6501 sha256=4e64796532c96fb4d594a21edb59b22d57db4cb53a400126a9a1a991828c8a52 Stored in directory: /tmp/.cache/pip/wheels/2a/80/9b/015026567c29fdffe31d91edbe7ba1b17728db79194fca1f21 Successfully built uuid Installing collected packages: uuid, jovian Successfully installed jovian-0.1.97 uuid-1.30
In [ ]:
# BUGFIX: `import jovian` originally appeared only in a later cell, so this
# cell raised NameError under Restart & Run All; import it here too.
import jovian

jovian.reset()  # presumably clears previously recorded run state — confirm
In [67]:
import jovian
# Record this run's hyperparameters with the notebook's Jovian entry.
jovian.log_hyperparams({'arch':'wrn22', 'lr':1e-3, 'epochs':20, 'one_cycle':True, 'wd':1e-4, })
[jovian] Please enter your API key ( from https://jovian.ml/ ): API KEY: ········ [jovian] Hyperparameters logged.
In [68]:
# NOTE(review): these values are hand-copied from the final epoch of the
# training output above — keep them in sync if the run changes.
jovian.log_metrics({'train_loss': 0.043691, 'val_loss': 0.265273, 'val_acc': 0.925800, 'time': '20:24'})
[jovian] Metrics logged.
In [ ]:
# Upload a snapshot of this notebook (with logged params/metrics) to Jovian.
jovian.commit()
[jovian] Saving notebook..
In [ ]: