Jovian
⭐️
Sign In
Learn data science and machine learning by building real-world projects on Jovian
In [105]:
# !conda install numpy pytorch torchvision cpuonly -c pytorch -y
# !pip install matplotlib --upgrade --quiet
In [106]:
# !pip install jovian --upgrade --quiet
In [107]:
import os
import numpy as np
import math
import glob
import random
from time import perf_counter 

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torchvision.transforms import ToTensor, ToPILImage
from torchvision.models import vgg19

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data import Dataset

from PIL import Image
In [108]:
project_name='Super_Resolution_Using_GAN'
In [109]:
import jovian
jovian.commit(project=project_name, environment=None)
[jovian] Attempting to save notebook.. [jovian] Detected Kaggle notebook... [jovian] Uploading notebook to https://jovian.ml/ms-krajesh/Super_Resolution_Using_GAN
In [110]:
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        
        vgg19_model = vgg19(pretrained=True)
        
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18]) #only using initial 18 layers
        #print(self.feature_extractor)

    def forward(self, img):
        return self.feature_extractor(img)   
In [111]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(in_features, 0.8)
        self.prelu = nn.PReLU()

    def forward(self, x):
        xin = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        return xin + x
In [112]:
class GeneratorNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super(GeneratorNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4)
        self.prelu = nn.PReLU()
        
        res_blocks = []
        for _ in range(n_residual_blocks):
            res_blocks.append(ResidualBlock(64))
            
        self.res_blocks = nn.Sequential(*res_blocks)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64, 0.8)

        upsampling = []
        for out_features in range(2):
            upsampling += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        self.conv3 = nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4)
        self.tanh = nn.Tanh()

    def forward(self, x):
        out1 = self.prelu(self.conv1(x))
        
        out = self.res_blocks(out1)
        
        out2 = self.bn2(self.conv2(out))
        
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.tanh(self.conv3(out))
        return out
In [113]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
In [114]:
class opt:
    epoch = 0
    n_epochs = 5
    dataset_name = "img_align_celeba" #https://www.kaggle.com/jessicali9530/celeba-dataset/activity
    batch_size = 4

    lr = 0.0002
    b1 = 0.5
    b2 = 0.999
    n_cpu = 8
    
    hr_height = 256
    hr_width = 256
    channels = 3
    
    sample_interval = 100
    checkpoint_interval = 2
    nrOfImages = 50
In [115]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):

        hr_height, hr_width = hr_shape
        
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_height // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_height), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        #print('root: ', root)
        self.files = sorted(glob.glob(root + "/*.*"))
        #print('self.files: ', self.files)
        self.files = self.files[0:opt.nrOfImages]

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)
In [116]:
os.makedirs("train_images", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

cuda = torch.cuda.is_available()
hr_shape = (opt.hr_height, opt.hr_width)

generator = GeneratorNet()
discriminator = Discriminator(input_shape = (opt.channels, *hr_shape))
feature_extractor = FeatureExtractor()

feature_extractor.eval()

criterion_GAN = torch.nn.MSELoss()
criterion_content = torch.nn.L1Loss()

if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    feature_extractor = feature_extractor.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_content = criterion_content.cuda()

if opt.epoch != 0:
    generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % opt.epoch-1))
    discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % opt.epoch-1))
    
optimizer_G = torch.optim.Adam(generator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr = opt.lr, betas = (opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

dataloader = DataLoader(
    ImageDataset("../input/celeba-dataset/img_align_celeba/%s" % opt.dataset_name, hr_shape = hr_shape),
    batch_size = opt.batch_size,
    shuffle = True,
    num_workers = opt.n_cpu, )
In [117]:
for epoch in range(opt.epoch, opt.n_epochs):
    for i, imgs in enumerate(dataloader):

        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))

        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        optimizer_G.zero_grad()

        gen_hr = generator(imgs_lr)

        loss_GAN = criterion_GAN(discriminator(gen_hr), valid)

        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features.detach())

        loss_G = loss_content + 1e-3 * loss_GAN

        loss_G.backward()
        optimizer_G.step()

  
        optimizer_D.zero_grad()

        loss_real = criterion_GAN(discriminator(imgs_hr), valid)
        loss_fake = criterion_GAN(discriminator(gen_hr.detach()), fake)

        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        print("\n[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader), loss_D.item(), loss_G.item()))

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_lr, gen_hr), -1)
            save_image(img_grid, "train_images/%d.png" % batches_done, normalize=False)

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch) 
        torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)

torch.save(generator.state_dict(), "saved_models/generator.pth")
torch.save(discriminator.state_dict(), "saved_models/discriminator.pth")        
    
[Epoch 0/5] [Batch 0/13] [D loss: 0.726819] [G loss: 1.217518] [Epoch 0/5] [Batch 1/13] [D loss: 2.272282] [G loss: 1.308014] [Epoch 0/5] [Batch 2/13] [D loss: 0.834556] [G loss: 1.221834] [Epoch 0/5] [Batch 3/13] [D loss: 0.837838] [G loss: 1.209350] [Epoch 0/5] [Batch 4/13] [D loss: 0.487286] [G loss: 1.177606] [Epoch 0/5] [Batch 5/13] [D loss: 0.409396] [G loss: 1.101249] [Epoch 0/5] [Batch 6/13] [D loss: 0.378335] [G loss: 1.225453] [Epoch 0/5] [Batch 7/13] [D loss: 0.326242] [G loss: 1.122046] [Epoch 0/5] [Batch 8/13] [D loss: 0.253667] [G loss: 1.059780] [Epoch 0/5] [Batch 9/13] [D loss: 0.228466] [G loss: 0.925188] [Epoch 0/5] [Batch 10/13] [D loss: 0.226892] [G loss: 1.025046] [Epoch 0/5] [Batch 11/13] [D loss: 0.205899] [G loss: 1.060136] [Epoch 0/5] [Batch 12/13] [D loss: 0.158382] [G loss: 0.876392] [Epoch 1/5] [Batch 0/13] [D loss: 0.175717] [G loss: 1.071230] [Epoch 1/5] [Batch 1/13] [D loss: 0.199367] [G loss: 1.121383] [Epoch 1/5] [Batch 2/13] [D loss: 0.164427] [G loss: 1.137226] [Epoch 1/5] [Batch 3/13] [D loss: 0.167777] [G loss: 1.039816] [Epoch 1/5] [Batch 4/13] [D loss: 0.168828] [G loss: 1.000399] [Epoch 1/5] [Batch 5/13] [D loss: 0.106881] [G loss: 0.969819] [Epoch 1/5] [Batch 6/13] [D loss: 0.114161] [G loss: 1.217691] [Epoch 1/5] [Batch 7/13] [D loss: 0.085715] [G loss: 1.024368] [Epoch 1/5] [Batch 8/13] [D loss: 0.099214] [G loss: 0.935498] [Epoch 1/5] [Batch 9/13] [D loss: 0.224029] [G loss: 0.973002] [Epoch 1/5] [Batch 10/13] [D loss: 0.131349] [G loss: 0.978368] [Epoch 1/5] [Batch 11/13] [D loss: 0.084181] [G loss: 0.995407] [Epoch 1/5] [Batch 12/13] [D loss: 0.105951] [G loss: 0.968867] [Epoch 2/5] [Batch 0/13] [D loss: 0.092173] [G loss: 1.127437] [Epoch 2/5] [Batch 1/13] [D loss: 0.056599] [G loss: 0.885595] [Epoch 2/5] [Batch 2/13] [D loss: 0.055667] [G loss: 0.963479] [Epoch 2/5] [Batch 3/13] [D loss: 0.053647] [G loss: 0.973574] [Epoch 2/5] [Batch 4/13] [D loss: 0.043773] [G loss: 0.952029] [Epoch 2/5] [Batch 5/13] [D loss: 0.045229] [G loss: 1.024572] [Epoch 2/5] [Batch 6/13] [D loss: 0.050590] [G loss: 1.093738] [Epoch 2/5] [Batch 7/13] [D loss: 0.044590] [G loss: 1.084773] [Epoch 2/5] [Batch 8/13] [D loss: 0.032812] [G loss: 0.950934] [Epoch 2/5] [Batch 9/13] [D loss: 0.040071] [G loss: 1.012064] [Epoch 2/5] [Batch 10/13] [D loss: 0.027793] [G loss: 0.937086] [Epoch 2/5] [Batch 11/13] [D loss: 0.026305] [G loss: 1.040066] [Epoch 2/5] [Batch 12/13] [D loss: 0.027679] [G loss: 0.945597] [Epoch 3/5] [Batch 0/13] [D loss: 0.023307] [G loss: 0.985065] [Epoch 3/5] [Batch 1/13] [D loss: 0.030319] [G loss: 0.975277] [Epoch 3/5] [Batch 2/13] [D loss: 0.032863] [G loss: 1.045494] [Epoch 3/5] [Batch 3/13] [D loss: 0.025067] [G loss: 0.924258] [Epoch 3/5] [Batch 4/13] [D loss: 0.021873] [G loss: 0.938208] [Epoch 3/5] [Batch 5/13] [D loss: 0.024914] [G loss: 0.990609] [Epoch 3/5] [Batch 6/13] [D loss: 0.032830] [G loss: 1.063364] [Epoch 3/5] [Batch 7/13] [D loss: 0.016919] [G loss: 0.896510] [Epoch 3/5] [Batch 8/13] [D loss: 0.017711] [G loss: 1.047194] [Epoch 3/5] [Batch 9/13] [D loss: 0.017543] [G loss: 0.949439] [Epoch 3/5] [Batch 10/13] [D loss: 0.022030] [G loss: 0.857707] [Epoch 3/5] [Batch 11/13] [D loss: 0.025552] [G loss: 0.965952] [Epoch 3/5] [Batch 12/13] [D loss: 0.040143] [G loss: 1.130170] [Epoch 4/5] [Batch 0/13] [D loss: 0.023899] [G loss: 0.941401] [Epoch 4/5] [Batch 1/13] [D loss: 0.017586] [G loss: 0.866710] [Epoch 4/5] [Batch 2/13] [D loss: 0.025075] [G loss: 0.956455] [Epoch 4/5] [Batch 3/13] [D loss: 0.034156] [G loss: 1.029549] [Epoch 4/5] [Batch 4/13] [D loss: 0.036057] [G loss: 1.029844] [Epoch 4/5] [Batch 5/13] [D loss: 0.033219] [G loss: 0.927409] [Epoch 4/5] [Batch 6/13] [D loss: 0.051863] [G loss: 0.996887] [Epoch 4/5] [Batch 7/13] [D loss: 0.051878] [G loss: 0.863585] [Epoch 4/5] [Batch 8/13] [D loss: 0.044914] [G loss: 1.086044] [Epoch 4/5] [Batch 9/13] [D loss: 0.026235] [G loss: 0.951670] [Epoch 4/5] [Batch 10/13] [D loss: 0.023232] [G loss: 1.064327] [Epoch 4/5] [Batch 11/13] [D loss: 0.019589] [G loss: 0.877764] [Epoch 4/5] [Batch 12/13] [D loss: 0.036377] [G loss: 0.948228]
In [120]:
jovian.reset()
In [121]:
jovian.log_hyperparams(start_epoch=opt.epoch,
                       number_of_epochs=opt.n_epochs,
                       lrs=opt.lr,
                       beta1=opt.b1,
                       beta2=opt.b2)
[jovian] Hyperparams logged.
In [122]:
jovian.log_metrics(generator_loss=loss_G.item(), discriminator_loss=loss_D.item())
[jovian] Metrics logged.
In [123]:
#inference
os.makedirs("images_inference", exist_ok=True)

network = GeneratorNet()
network = network.eval()

if torch.cuda.is_available():
    network.cuda()
    network.load_state_dict(torch.load('saved_models/generator.pth'))
else:
    network.load_state_dict(torch.load('saved_models/generator.pth', map_location=lambda storage, loc: storage))

im_number = '200080'
imgs_lr = Image.open('../input/celeba-dataset/img_align_celeba/img_align_celeba/' + im_number + '.jpg')

imgs_lr = Variable(ToTensor()(imgs_lr)).unsqueeze(0)

if torch.cuda.is_available():
    imgs_lr = imgs_lr.cuda()
    
with torch.no_grad():
    start = perf_counter()
    gen_hr = network(imgs_lr)
    elapsed = (perf_counter() - start)

    print('time cost: ' + str(elapsed) + 'sec')
        
    print('Shape imgs_lr:', imgs_lr.shape)
    print('Shape gen_hr:', gen_hr.shape)
    imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
    
    #imgs_lr = ToPILImage()(imgs_lr[0].data.cpu())
    #gen_hr = ToPILImage()(gen_hr[0].data.cpu())
    print('Shape imgs_lr post interpolation:', imgs_lr.shape)

    gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
    imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
    img_grid = torch.cat((imgs_lr, gen_hr), -1)
    save_image(img_grid, "images_inference/"+ str(im_number) + ".png", normalize=False)
time cost: 5.231988329993328sec Shape imgs_lr: torch.Size([1, 3, 218, 178]) Shape gen_hr: torch.Size([1, 3, 872, 712]) Shape imgs_lr post interpolation: torch.Size([1, 3, 872, 712])
In [124]:
jovian.commit(project=project_name, outputs=["saved_models/generator.pth", "saved_models/discriminator.pth"], environment=None)
[jovian] Attempting to save notebook.. [jovian] Detected Kaggle notebook... [jovian] Uploading notebook to https://jovian.ml/ms-krajesh/Super_Resolution_Using_GAN