Learn practical skills, build real-world projects, and advance your career
from multiprocessing.spawn import freeze_support
import torchvision
import os
import torchvision.models.alexnet
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import datasets
import torch.nn as nn
import torchvision.models.vgg
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.manifold import TSNE

# base to base code - https://github.com/fyu/drn
# base paper - https://arxiv.org/pdf/1712.02560.pdf
# base code - https://github.com/mil-tokyo/MCD_DA/tree/master/classification
# our implementation - https://github.com/Nyn-ynu/MCD
dataset_root = "dataset"
batch_size = 256
num_workers = 4
device = "cuda"
lr = 1e-4
N = 10
num_epoch = 100
# num_epoch = 2
models_save = "models_trained"
criterion = nn.CrossEntropyLoss()
class LeNetClassifier(nn.Module):
    def __init__(self, prob=0.5):
          super(LeNetClassifier, self).__init__()
          self.fc1 = nn.Linear(48*4*4, 100)
          self.bn1_fc = nn.BatchNorm1d(100)
          self.fc2 = nn.Linear(100, 100)
          self.bn2_fc = nn.BatchNorm1d(100)
          self.fc3 = nn.Linear(100, 10)
          self.bn_fc3 = nn.BatchNorm1d(10)
          self.prob = prob

    def set_lambda(self, lambd):
        self.lambd = lambd
    def forward(self, x):
        x = F.dropout(x, training=self.training, p=self.prob)
        x = F.relu(self.bn1_fc(self.fc1(x)))
        x = F.dropout(x, training=self.training, p=self.prob)
        x = F.relu(self.bn2_fc(self.fc2(x)))
        x = F.dropout(x, training=self.training, p=self.prob)
        x = self.fc3(x)
        return x
class ClassifierVgg(nn.Module):
    def __init__(self, num_classes=31):
        super(ClassifierVgg, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.classifier(x)
        return x
class LeNetEncoder(nn.Module):
    """LeNet encoder model for ADDA."""

    def __init__(self):
      super(LeNetEncoder, self).__init__()
      self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=1)
      self.bn1 = nn.BatchNorm2d(32)
      self.conv2 = nn.Conv2d(32, 48, kernel_size=5, stride=1)
      self.bn2 = nn.BatchNorm2d(48)

    def forward(self, x):
        x = torch.mean(x,1).view(x.size()[0],1,x.size()[2],x.size()[3])
        x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), stride=2, kernel_size=2, dilation=(1, 1))
        x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), stride=2, kernel_size=2, dilation=(1, 1))
        #print(x.size())
        x = x.view(x.size(0), 48*4*4)
        return x