# Learn practical skills, build real-world projects, and advance your career
# Created 3 years ago
import os
import torch
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image
# Root of the training split; the per-domain folders ("horse", "zebra") live under it.
data_dir = './cyclegan/data/train/'

# Peek at the first few horse file names. The path is derived from data_dir
# instead of repeating the literal, so the two cannot drift apart.
print(os.listdir(os.path.join(data_dir, 'horse'))[:10])

# Bare expression: in a notebook this renders the image as the cell output;
# when run as a plain script the result is simply discarded.
Image.open(os.path.join(data_dir, 'horse', 'n02381460_1001.jpg'))

# Recorded notebook output of the print() above:
# ['n02381460_1001.jpg', 'n02381460_1002.jpg', 'n02381460_1003.jpg',
#  'n02381460_1006.jpg', 'n02381460_1008.jpg', 'n02381460_1009.jpg',
#  'n02381460_1011.jpg', 'n02381460_1014.jpg', 'n02381460_1019.jpg',
#  'n02381460_102.jpg']
class CycleganDataset(Dataset):
    """Unpaired zebra/horse image dataset for CycleGAN training.

    The two domains generally contain different numbers of images, so the
    dataset length is the size of the larger domain. Indices wrap around
    (modulo) in the smaller one, so every index in ``range(len(self))`` is
    valid.

    Args:
        zebra_dir: Directory holding the zebra-domain image files.
        horse_dir: Directory holding the horse-domain image files.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)``. It must return a mapping
            with ``"image"`` and ``"image0"`` keys (e.g. an albumentations
            ``Compose`` with ``additional_targets={"image0": "image"}``).

    Raises:
        ValueError: If either directory contains no regular files.
    """

    def __init__(self, zebra_dir, horse_dir, transform=None):
        self.zebra_dir = zebra_dir
        self.horse_dir = horse_dir
        self.transform = transform
        self.zebra_images = self._list_files(zebra_dir)
        self.horse_images = self._list_files(horse_dir)
        # Fail fast: with an empty domain every __getitem__ call would fail.
        for directory, names in ((zebra_dir, self.zebra_images), (horse_dir, self.horse_images)):
            if not names:
                raise ValueError(f"no image files found in {directory!r}")
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.length_dataset = max(self.zebra_len, self.horse_len)

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside *directory*.

        Sorting makes the zebra/horse pairing reproducible (os.listdir order is
        arbitrary). Sub-directories such as ".ipynb_checkpoints" are skipped so
        they are never handed to PIL.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return self.length_dataset

    def _pair_paths(self, index):
        """Return the (zebra_path, horse_path) pair for dataset position *index*."""
        # Wrap each domain independently; otherwise the shorter list raises
        # IndexError for indices beyond its own length.
        zebra_name = self.zebra_images[index % self.zebra_len]
        horse_name = self.horse_images[index % self.horse_len]
        return (
            os.path.join(self.zebra_dir, zebra_name),
            os.path.join(self.horse_dir, horse_name),
        )

    @staticmethod
    def _load_rgb(path):
        """Load the image at *path* as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        zebra_path, horse_path = self._pair_paths(index)
        zebra_img = self._load_rgb(zebra_path)
        horse_img = self._load_rgb(horse_path)
        if self.transform is not None:
            # Both images go through one call, so the same random augmentation
            # parameters are applied to each.
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations["image"]
            horse_img = augmentations["image0"]
        return zebra_img, horse_img
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Per-domain image folders; os.path.join works whether or not data_dir ends
# with a path separator.
horse_dir = os.path.join(data_dir, 'horse')
zebra_dir = os.path.join(data_dir, 'zebra')

# Resize to 256x256, randomly flip, then scale pixels from [0, 255] to [-1, 1]
# (Normalize computes (x - 0.5*255) / (0.5*255)), and convert to tensors.
# additional_targets routes the second image ("image0") through the same
# pipeline with the same random parameters as the first ("image").
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

# Keyword arguments: CycleganDataset's signature is (zebra_dir, horse_dir,
# transform), and the previous positional call swapped the two domains.
ds = CycleganDataset(zebra_dir=zebra_dir, horse_dir=horse_dir, transform=transforms)