Learn practical skills, build real-world projects, and advance your career
Updated 4 years ago
PyTorch Starter: Pre-Trained ResNet-50
This kernel largely follows the PyTorch Transfer Learning tutorial, adding a custom dataset class and using the pretrained ResNet-50 model from torchvision.
%matplotlib inline
import matplotlib.pyplot as plt
import time
from shutil import copyfile
from os.path import isfile, join, abspath, exists, isdir, expanduser
from os import listdir, makedirs, getcwd, remove
from PIL import Image
from mpl_toolkits.axes_grid1 import ImageGrid
import pandas as pd
import numpy as np
import torch
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as func
import torchvision
from torchvision import transforms, datasets, models
Define a Custom Dataset
class SeedlingDataset(Dataset):
    """Map-style dataset of images listed row-by-row in a pandas DataFrame.

    For each row of ``labels``, column 0 holds the image file name (joined
    onto ``root_dir``) and column 2 holds the label, which is cast to ``int``.
    Column 1 is not read by this class.

    Args:
        labels: DataFrame with one row per image, laid out as above.
        root_dir: Directory that the file names in column 0 are relative to.
        subset: Accepted for backward compatibility; currently unused.
        transform: Optional callable applied to the RGB PIL image (for
            example a ``torchvision.transforms.Compose`` pipeline).
    """

    def __init__(self, labels, root_dir, subset=False, transform=None):
        self.labels = labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in ``labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is decoded as RGB and passed through ``transform`` when one
        is set. The label is read from column 2 and returned as an ``int``.
        """
        img_name = self.labels.iloc[idx, 0]
        fullname = join(self.root_dir, img_name)
        # Open inside a context manager so the underlying file handle is
        # released deterministically; convert() loads the pixels and returns
        # an independent in-memory image that stays valid after the close.
        with Image.open(fullname) as img:
            image = img.convert('RGB')
        label = self.labels.iloc[idx, 2]
        # Explicit None check: a transform object should not be skipped just
        # because it happens to evaluate as falsy.
        if self.transform is not None:
            image = self.transform(image)
        return image, int(label)