# Learn practical skills, build real-world projects, and advance your career
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, random_split, DataLoader
from PIL import Image
import torchvision.models as models
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from sklearn.metrics import f1_score
import torch.nn.functional as F
import torch.nn as nn
from torchvision.utils import make_grid
from torchvision.transforms import ToTensor
%matplotlib inline
# Hand-shape class folders to load. Each folder letter is assigned the integer
# label equal to its position in this list (a -> 0, b -> 1, f -> 2); the
# `labels` dict gives the display names ('zero', 'five', 'three') for those ids.
cls = ['a','b','f']
# Capture-session sub-folders of the dataset root to scan.
sessions = ['A','B','C','D','E']
data = []
for session in sessions:
    for label_id, letter in enumerate(cls):
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sorting makes df's row order (and any seeded split of it) reproducible.
        for fname in sorted(os.listdir(f'fingerspelling5/dataset5/{session}/{letter}')):
            # Only files whose name contains 'color' are kept.
            if 'color' in fname:
                # Path is stored relative to the dataset root, '/'-separated.
                data.append([f'{session}/{letter}/{fname}', label_id])
df = pd.DataFrame(data, columns = ['path', 'label'])
df  # notebook display of the collected samples
# Display names for the integer class ids (0, 1, 2) produced for the dataset;
# looked up by decode_target when text_labels=True.
labels = dict(enumerate(('zero', 'five', 'three')))
def encode_label(label, num_classes=3):
    """Build a multi-hot target vector from a class label.

    Args:
        label: a single class id (e.g. ``1``) or a whitespace-separated
            string of class ids (e.g. ``"0 2"``). Any value whose ``str()``
            is such a string is accepted.
        num_classes: length of the returned vector; defaults to 3.

    Returns:
        A float tensor of shape ``(num_classes,)`` holding 1.0 at every
        listed class id and 0.0 elsewhere. An empty/blank string yields an
        all-zero vector.

    Raises:
        ValueError: if a token is not an integer.
        IndexError: if a class id is >= ``num_classes``.
    """
    target = torch.zeros(num_classes)
    # split() with no argument tolerates repeated, leading and trailing
    # whitespace; split(' ') would yield empty tokens and int('') would raise.
    for token in str(label).split():
        target[int(token)] = 1.
    return target

def decode_target(target, text_labels=False, threshold=0.5):
    """Turn a vector of per-class scores back into a space-separated label string.

    Args:
        target: iterable of per-class scores (e.g. a 1-D tensor or a list).
        text_labels: when True, each selected class is rendered as
            ``labels[i] + "(i)"`` (e.g. ``"zero(0)"``) instead of just ``"i"``.
        threshold: a class is selected when its score is ``>= threshold``.

    Returns:
        The selected class entries joined by single spaces, in index order;
        an empty string when nothing reaches the threshold.
    """
    picked = []
    for idx, score in enumerate(target):
        # Written as `not >=` (rather than `<`) so NaN scores are skipped.
        if not (score >= threshold):
            continue
        if text_labels:
            picked.append(labels[idx] + "(" + str(idx) + ")")
        else:
            picked.append(str(idx))
    return ' '.join(picked)
class SignLangDataset(Dataset):
    """Map-style dataset of images listed in a DataFrame.

    Each row of ``df`` must provide a ``'path'`` (relative to ``root_dir``)
    and a ``'label'`` accepted by ``encode_label``.

    Args:
        df: DataFrame with ``'path'`` and ``'label'`` columns. Rows are
            addressed by position, so the index need not be 0..len-1.
        root_dir: directory prefix joined to each ``'path'`` with ``"/"``.
        transform: optional callable applied to the resized PIL image
            (for example ``ToTensor()``).
        size: ``(width, height)`` every image is resized to before
            ``transform``; defaults to ``(130, 130)``.
    """

    def __init__(self, df, root_dir, transform=None, size=(130, 130)):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        self.size = tuple(size)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, encode_label(label))`` for the row at position ``idx``."""
        # iloc (positional) rather than loc (index label): samplers and
        # random_split pass positions 0..len-1, which only match index labels
        # when df has a default RangeIndex (not the case after filtering/splitting).
        row = self.df.iloc[idx]
        img_id, img_label = row['path'], row['label']
        img_fname = self.root_dir + "/" + img_id
        # The context manager closes the underlying file; resize() returns a
        # new in-memory image, so `img` stays usable after the file is closed.
        with Image.open(img_fname) as raw:
            img = raw.resize(self.size)
        if self.transform:
            img = self.transform(img)
        return img, encode_label(img_label)