Learn practical skills, build real-world projects, and advance your career
Created 4 years ago
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, random_split, DataLoader
from PIL import Image
import torchvision.models as models
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from sklearn.metrics import f1_score
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor
%matplotlib inline
# Folder names of the fingerspelled letters used as classes. A letter's
# position in this list is the integer label stored for its images.
# (Earlier notes translated these classes to the digits [0, 5, 3].)
cls = ['a', 'b', 'f']
# Recording-session subfolders of the dataset root.
sessions = ['A', 'B', 'C', 'D', 'E']

# Collect one [relative_path, label] row per matching image.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so the
# names are sorted to make the row order (and thus the DataFrame index)
# reproducible across machines and runs.
data = []
for s in sessions:
    for label, letter in enumerate(cls):
        class_dir = os.path.join('fingerspelling5/dataset5', s, letter)
        for file in sorted(os.listdir(class_dir)):
            # Only files whose name contains 'color' are kept.
            if 'color' in file:
                # Stored path is relative to the dataset root and always uses
                # '/' so the column looks the same on every platform.
                data.append([s + '/' + letter + '/' + file, label])
df = pd.DataFrame(data, columns=['path', 'label'])
df
# Human-readable name for each integer class index. Index i is the position
# of the class folder in `cls`, so folders 'a', 'b', 'f' map to the names
# 'zero', 'five' and 'three' respectively.
labels = dict(enumerate(['zero', 'five', 'three']))
def encode_label(label, num_classes=3):
    """Turn a label into a multi-hot target tensor.

    Args:
        label: A single class index (e.g. ``2``) or a string of
            whitespace-separated class indices (e.g. ``"0 2"``).
        num_classes: Length of the returned vector. Defaults to 3, the
            number of classes used by this dataset.

    Returns:
        A float tensor of shape ``(num_classes,)`` holding 1.0 at every
        listed index and 0.0 elsewhere.
    """
    target = torch.zeros(num_classes)
    # split() with no argument tolerates repeated, leading and trailing
    # whitespace; split(' ') would yield '' tokens and crash int().
    for token in str(label).split():
        target[int(token)] = 1.
    return target
def decode_target(target, text_labels=False, threshold=0.5):
    """Convert a per-class score vector back into a label string.

    Args:
        target: Iterable of per-class scores (e.g. a 1-D tensor); position
            ``i`` is the score for class ``i``.
        text_labels: When True, render each selected class as
            ``"<name>(<index>)"`` using the module-level ``labels`` mapping;
            otherwise render just the index.
        threshold: Minimum score (inclusive) for a class to be selected.

    Returns:
        The selected classes joined by single spaces, in index order; an
        empty string when no score reaches ``threshold``.
    """
    def render(index):
        if not text_labels:
            return str(index)
        return labels[index] + "(" + str(index) + ")"

    picked = [render(index) for index, score in enumerate(target)
              if score >= threshold]
    return ' '.join(picked)
class SignLangDataset(Dataset):
    """Map-style dataset of images listed in a DataFrame.

    Each row of ``df`` supplies a ``'path'`` (relative to ``root_dir``) and a
    ``'label'``. Items are returned as ``(image, target)``: the image is
    resized to ``size`` and passed through ``transform`` (if given), and the
    target is the multi-hot tensor produced by ``encode_label``.

    Args:
        df: DataFrame with ``'path'`` and ``'label'`` columns. Rows are
            addressed by position, so any index (filtered, shuffled, split)
            works.
        root_dir: Directory that the ``'path'`` values are relative to.
        transform: Optional callable applied to the resized PIL image.
        size: ``(width, height)`` passed to ``PIL.Image.resize``. Defaults to
            ``(130, 130)``.
    """

    def __init__(self, df, root_dir, transform=None, size=(130, 130)):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        self.size = size

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Samplers and random_split hand the dataset positions 0..len-1, so
        # select by position (iloc). loc looks up index *labels* and breaks
        # (KeyError or wrong row) once df's index is no longer 0..len-1.
        row = self.df.iloc[idx]
        img_id, img_label = row['path'], row['label']
        img_fname = os.path.join(self.root_dir, img_id)
        # The with-block guarantees the file handle is released; resize()
        # returns a new, fully loaded image that outlives the block.
        with Image.open(img_fname) as raw:
            img = raw.resize(self.size)
        if self.transform:
            img = self.transform(img)
        return img, encode_label(img_label)