Learn practical skills, build real-world projects, and advance your career
import  os
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image
# Root of the training split; the cells below read its 'horse' and 'zebra' subfolders.
data_dir = './cyclegan/data/train/'
# Preview the first ten horse filenames. Reuse data_dir rather than repeating the path.
# os.listdir order is arbitrary, so sort it for a stable preview.
print(sorted(os.listdir(os.path.join(data_dir, 'horse')))[:10])
# As the last expression of a notebook cell, this renders one sample horse image.
Image.open(os.path.join(data_dir, 'horse', 'n02381460_1001.jpg'))
['n02381460_1001.jpg', 'n02381460_1002.jpg', 'n02381460_1003.jpg', 'n02381460_1006.jpg', 'n02381460_1008.jpg', 'n02381460_1009.jpg', 'n02381460_1011.jpg', 'n02381460_1014.jpg', 'n02381460_1019.jpg', 'n02381460_102.jpg']
Notebook Image
class CycleganDataset(Dataset):
    """Unpaired zebra/horse image dataset for CycleGAN-style training.

    The two domains may contain different numbers of images. The dataset
    length is the size of the larger domain. Indices wrap around each domain
    independently, so the smaller domain's images are reused rather than
    running out.

    Args:
        zebra_dir: Directory holding the zebra-domain image files.
        horse_dir: Directory holding the horse-domain image files.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)``. It must return a mapping
            with keys ``"image"`` and ``"image0"`` (e.g. an albumentations
            ``Compose`` with ``additional_targets={"image0": "image"}``).
    """

    def __init__(self, zebra_dir, horse_dir, transform=None):
        self.zebra_dir = zebra_dir
        self.horse_dir = horse_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always maps to the same file across runs and machines.
        self.zebra_images = sorted(os.listdir(zebra_dir))
        self.horse_images = sorted(os.listdir(horse_dir))
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        # The domains can differ in size (e.g. 1000 vs 1500): iterate over the larger one.
        self.length_dataset = max(self.zebra_len, self.horse_len)

    def __len__(self):
        """Return the number of images in the larger of the two domains."""
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read the image at *path* as an RGB NumPy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return ``(zebra_img, horse_img)`` for *index*, transformed if a transform is set."""
        # __len__ reports the larger domain, so wrap the index into each
        # domain's own range; indexing the smaller list directly would raise IndexError.
        zebra_name = self.zebra_images[index % self.zebra_len]
        horse_name = self.horse_images[index % self.horse_len]

        zebra_img = self._load_rgb(os.path.join(self.zebra_dir, zebra_name))
        horse_img = self._load_rgb(os.path.join(self.horse_dir, horse_name))

        if self.transform:
            # "image0" is routed through the same pipeline as "image" by the transform.
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations["image"]
            horse_img = augmentations["image0"]

        return zebra_img, horse_img
import albumentations as A
from albumentations.pytorch import ToTensorV2

horse_dir = os.path.join(data_dir, 'horse')
zebra_dir = os.path.join(data_dir, 'zebra')

# Preprocessing pipeline:
# - resize to 256x256 and apply a random horizontal flip;
# - Normalize with mean=std=0.5 and max_pixel_value=255 maps pixels from [0, 255] to [-1, 1];
# - ToTensorV2 converts the HWC array to a CHW tensor.
# additional_targets sends "image0" (the second image) through the same pipeline as "image".
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

# The constructor's signature is (zebra_dir, horse_dir, transform). Pass keywords so the
# two domains cannot be swapped by argument order.
ds = CycleganDataset(zebra_dir=zebra_dir, horse_dir=horse_dir, transform=transforms)