Jovian
Sign In

Transfer Learning for Image Classification in PyTorch

How a CNN learns (source):

cnn-learning

Layer visualization (source):

cnn-learning

Downloading the Dataset

We'll use the Oxford-IIIT Pets dataset from https://course.fast.ai/datasets . It is 37 category (breeds) pet dataset with roughly 200 images for each class. The images have a large variations in scale, pose and lighting.

In [2]:
!pip install jovian --upgrade --quiet
In [3]:
from torchvision.datasets.utils import download_url
In [4]:
download_url('https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz', '.')
Using downloaded and verified file: ./oxford-iiit-pet.tgz
In [5]:
import tarfile

with tarfile.open('./oxford-iiit-pet.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')
In [6]:
from torch.utils.data import Dataset
In [7]:
import os

DATA_DIR = './data/oxford-iiit-pet/images'

files = os.listdir(DATA_DIR)
files[:5]
Out[7]:
['pug_189.jpg',
 'newfoundland_110.jpg',
 'Maine_Coon_52.jpg',
 'leonberger_70.jpg',
 'japanese_chin_23.jpg']
In [8]:
def parse_breed(fname):
    parts = fname.split('_')
    return ' '.join(parts[:-1])
In [9]:
parse_breed(files[4])
Out[9]:
'japanese chin'
In [10]:
from PIL import Image

def open_image(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

Creating a Custom PyTorch Dataset

In [11]:
import os

class PetsDataset(Dataset):
    def __init__(self, root, transform):
        super().__init__()
        self.root = root
        self.files = [fname for fname in os.listdir(root) if fname.endswith('.jpg')]
        self.classes = list(set(parse_breed(fname) for fname in files))
        self.transform = transform
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        fname = self.files[i]
        fpath = os.path.join(self.root, fname)
        img = self.transform(open_image(fpath))
        class_idx = self.classes.index(parse_breed(fname))
        return img, class_idx
In [12]:
import torchvision.transforms as T

img_size = 224
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
dataset = PetsDataset(DATA_DIR, T.Compose([T.Resize(img_size), 
                                           T.Pad(8, padding_mode='reflect'),
                                           T.RandomCrop(img_size), 
                                           T.ToTensor(), 
                                           T.Normalize(*imagenet_stats)]))
In [13]:
len(dataset)
Out[13]:
7390
In [32]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline

def denormalize(images, means, stds):
    if len(images.shape) == 3:
        images = images.unsqueeze(0)
    means = torch.tensor(means).reshape(1, 3, 1, 1)
    stds = torch.tensor(stds).reshape(1, 3, 1, 1)
    return images * stds + means

def show_image(img_tensor, label):
    print('Label:', dataset.classes[label], '(' + str(label) + ')')
    img_tensor = denormalize(img_tensor, *imagenet_stats)[0].permute((1, 2, 0))
    plt.imshow(img_tensor)
In [33]:
show_image(*dataset[2])
Label: Maine Coon (32)

Creating Training and Validation Sets

In [34]:
from torch.utils.data import random_split
In [35]:
val_pct = 0.1
val_size = int(val_pct * len(dataset))

train_ds, valid_ds = random_split(dataset, [len(dataset) - val_size, val_size])
In [36]:
from torch.utils.data import DataLoader
batch_size = 256

train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size*2, num_workers=4, pin_memory=True)
In [39]:
from torchvision.utils import make_grid

def show_batch(dl):
    for images, labels in dl:
        fig, ax = plt.subplots(figsize=(16, 16))
        ax.set_xticks([]); ax.set_yticks([])
        images = denormalize(images[:64], *imagenet_stats)
        ax.imshow(make_grid(images, nrow=8).permute(1, 2, 0))
        break

In [40]:
show_batch(train_dl)
Output hidden; open in https://colab.research.google.com to view.

Modifying a Pretrained Model (ResNet34)

Transfer learning (source): transfer-learning

In [41]:
import torch.nn as nn
import torch.nn.functional as F

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch
        out = self(images)                  # Generate predictions
        loss = F.cross_entropy(out, labels)  # Calculate loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}],{} train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, "last_lr: {:.5f},".format(result['lrs'][-1]) if 'lrs' in result else '', 
            result['train_loss'], result['val_loss'], result['val_acc']))

In [42]:
from torchvision import models

class PetsModel(ImageClassificationBase):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Use a pretrained model
        self.network = models.resnet34(pretrained=pretrained)
        # Replace last layer
        self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)

    def forward(self, xb):
        return self.network(xb)

GPU Utilities and Training Loop

In [43]:
def get_default_device():
    """Pick GPU if available, else CPU"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')


def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)


class DeviceDataLoader():
    """Wrap a dataloader to move data to a device"""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [44]:
import torch
from tqdm.notebook import tqdm

@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        for batch in tqdm(train_loader):
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        model.epoch_end(epoch, result)
        history.append(result)
    return history

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    torch.cuda.empty_cache()
    history = []

    # Set up custom optimizer with weight decay
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # Set up one-cycle learning rate scheduler
    sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                steps_per_epoch=len(train_loader))

    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        lrs = []
        for batch in tqdm(train_loader):
            loss = model.training_step(batch)
            train_losses.append(loss)
            loss.backward()

            # Gradient clipping
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Record & update learning rate
            lrs.append(get_lr(optimizer))
            sched.step()

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        model.epoch_end(epoch, result)
        history.append(result)
    return history
In [45]:
device = get_default_device()
device
Out[45]:
device(type='cuda')
In [46]:
train_dl = DeviceDataLoader(train_dl, device)
valid_dl = DeviceDataLoader(valid_dl, device)

Finetuning the Pretrained Model

In [53]:
model = PetsModel(len(dataset.classes))
to_device(model, device);
In [54]:
history = [evaluate(model, valid_dl)]
history
Out[54]:
[{'val_acc': 0.02591117098927498, 'val_loss': 3.8896164894104004}]
In [55]:
epochs = 6
max_lr = 0.01
grad_clip = 0.1
weight_decay = 1e-4
opt_func = torch.optim.Adam
In [56]:
%%time
history += fit_one_cycle(epochs, max_lr, model, train_dl, valid_dl, 
                         grad_clip=grad_clip, 
                         weight_decay=weight_decay, 
                         opt_func=opt_func)
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [0],last_lr: 0.00589, train_loss: 1.3628, val_loss: 165.6215, val_acc: 0.0240
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [1],last_lr: 0.00994, train_loss: 1.9505, val_loss: 4.8602, val_acc: 0.0447
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [2],last_lr: 0.00812, train_loss: 1.3101, val_loss: 1.8901, val_acc: 0.4405
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [3],last_lr: 0.00463, train_loss: 0.8174, val_loss: 1.1133, val_acc: 0.6309
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [4],last_lr: 0.00133, train_loss: 0.4924, val_loss: 0.6684, val_acc: 0.7638
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [5],last_lr: 0.00000, train_loss: 0.3033, val_loss: 0.5651, val_acc: 0.8167 CPU times: user 49.4 s, sys: 37.7 s, total: 1min 27s Wall time: 2min 59s

Training a model from scratch

Let's repeat the training without using weights from the pretrained ResNet34 model.

In [57]:
model2 = PetsModel(len(dataset.classes), pretrained=False)
to_device(model2, device);
In [58]:
history2 = [evaluate(model2, valid_dl)]
history2
Out[58]:
[{'val_acc': 0.02223292924463749, 'val_loss': 64.20709228515625}]
In [59]:
%%time
history2 += fit_one_cycle(epochs, max_lr, model2, train_dl, valid_dl, 
                         grad_clip=grad_clip, 
                         weight_decay=weight_decay, 
                         opt_func=opt_func)
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [0],last_lr: 0.00589, train_loss: 3.5621, val_loss: 570.1448, val_acc: 0.0227
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [1],last_lr: 0.00994, train_loss: 3.4390, val_loss: 4.6319, val_acc: 0.0438
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [2],last_lr: 0.00812, train_loss: 3.2430, val_loss: 3.2921, val_acc: 0.1269
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [3],last_lr: 0.00463, train_loss: 2.9670, val_loss: 3.0076, val_acc: 0.1647
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [4],last_lr: 0.00133, train_loss: 2.7464, val_loss: 2.9556, val_acc: 0.1672
HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))
Epoch [5],last_lr: 0.00000, train_loss: 2.5684, val_loss: 2.6616, val_acc: 0.2449 CPU times: user 48.9 s, sys: 37.4 s, total: 1min 26s Wall time: 2min 57s

While the pretrained model reached an accuracy of 80% in less than 3 minutes, the model without pretrained weights could only reach an accuracy of 24%.

In [ ]:
!pip install jovian --upgrade --quiet
In [ ]:
import jovian
In [ ]:
jovian.commit(project='transfer-learning-pytorch')