Learn practical skills, build real-world projects, and advance your career
from apex import amp
import albumentations
import numpy as np
import pretrainedmodels
import time
import torch
import torch.nn as nn
from tqdm import tqdm
import jovian
import sklearn.metrics

from torch.nn import functional as F
from utils.dataset import BengaliDataset
from utils.pytorchtools import EarlyStopping
DEVICE = "cuda"
TRAINING_FOLDS_CSV = []
IMG_H = 137
IMG_W = 236
EPOCHS = 12
TRAIN_BS = 128
TEST_BS = 128
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)
AUGS = [albumentations.ShiftScaleRotate(shift_limit = 0.1,
                                        scale_limit=0.1,
                                        rotate_limit=45, p=0.9),
       albumentations.Cutout(max_h_size=IMG_H//4, max_w_size=IMG_W//4, num_holes=2, p=.75)]
class SeResnet50(nn.Module):
    def __init__(self, pretrained):
        super(SeResnet50, self).__init__()
        if pretrained:
            self.model = pretrainedmodels.__dict__["se_resnext50_32x4d"](pretrained="imagenet")
        else:
            self.model = pretrainedmodels.__dict__["se_resnext50_32x4d"](pretrained=None)
            
        self.l0 = nn.Linear(2048, 168)
        self.l1 = nn.Linear(2048, 11)
        self.l2 = nn.Linear(2048, 7)
        
        
    def forward(self, x):
        bs, _, _, _ = x.shape
        x = self.model.features(x)
        x = F.adaptive_avg_pool2d(x, 1).reshape(bs, -1)
        l0 = self.l0(x)
        l1 = self.l1(x)
        l2 = self.l2(x)
        return l0, l1, l2
# def rand_bbox(size, lam):
#     W = size[2]
#     H = size[3]
#     cut_rat = np.sqrt(1. - lam)
#     cut_w = np.int(W * cut_rat)
#     cut_h = np.int(H * cut_rat)

#     # uniform
#     cx = np.random.randint(W)
#     cy = np.random.randint(H)

#     bbx1 = np.clip(cx - cut_w // 2, 0, W)
#     bby1 = np.clip(cy - cut_h // 2, 0, H)
#     bbx2 = np.clip(cx + cut_w // 2, 0, W)
#     bby2 = np.clip(cy + cut_h // 2, 0, H)

#     return bbx1, bby1, bbx2, bby2

# def cutmix(data, targets1, targets2, targets3, alpha):
#     indices = torch.randperm(data.size(0))
#     shuffled_data = data[indices]
#     shuffled_targets1 = targets1[indices]
#     shuffled_targets2 = targets2[indices]
#     shuffled_targets3 = targets3[indices]

#     lam = np.random.beta(alpha, alpha)
#     bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
#     data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
#     # adjust lambda to exactly match pixel ratio
#     lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

#     targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]
#     return data, targets

# def mixup(data, targets1, targets2, targets3, alpha):
#     indices = torch.randperm(data.size(0))
#     shuffled_data = data[indices]
#     shuffled_targets1 = targets1[indices]
#     shuffled_targets2 = targets2[indices]
#     shuffled_targets3 = targets3[indices]

#     lam = np.random.beta(alpha, alpha)
#     data = data * lam + shuffled_data * (1 - lam)
#     targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, targets3, shuffled_targets3, lam]

#     return data, targets


# def cutmix_criterion(preds1,preds2,preds3, targets):
#     targets1, targets2,targets3, targets4,targets5, targets6, lam = targets[0], targets[1], targets[2], targets[3], targets[4], targets[5], targets[6]
#     criterion = nn.CrossEntropyLoss(reduction='mean')
#     return lam * criterion(preds1, targets1) + (1 - lam) * criterion(preds1, targets2) + lam * criterion(preds2, targets3) + (1 - lam) * criterion(preds2, targets4) + lam * criterion(preds3, targets5) + (1 - lam) * criterion(preds3, targets6)

# def mixup_criterion(preds1,preds2,preds3, targets):
#     targets1, targets2,targets3, targets4,targets5, targets6, lam = targets[0], targets[1], targets[2], targets[3], targets[4], targets[5], targets[6]
#     criterion = nn.CrossEntropyLoss(reduction='mean')
#     return lam * criterion(preds1, targets1) + (1 - lam) * criterion(preds1, targets2) + lam * criterion(preds2, targets3) + (1 - lam) * criterion(preds2, targets4) + lam * criterion(preds3, targets5) + (1 - lam) * criterion(preds3, targets6)