# Updated 4 years ago
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import torch
import torchvision
import tarfile
import torchvision
from torchvision.datasets.utils import download_url
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import torch.nn.functional as F
# Root directory of the flower images; ImageFolder (used below) treats each
# sub-directory as one class.
base_dir = "../input/flowers-recognition/flowers/flowers"
# Peek at the class sub-directories (the notebook displays this list). The
# listing order is filesystem-dependent; ImageFolder sorts class names itself.
os.listdir(path=base_dir)
# Output: ['daisy', 'rose', 'dandelion', 'sunflower', 'tulip']
# Preprocessing + augmentation pipeline run on every image loaded through it.
# NOTE(review): the random flips/rotation below apply to every sample of any
# dataset built with this transform, including any validation split carved out
# later (random_split is imported) — confirm that is intended.
transformer = transforms.Compose([
    # Fixed 224x224 spatial size.
    transforms.Resize((224, 224)),
    # Random augmentations, each drawn independently per sample.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),  # angle sampled from [-30, 30] degrees
    # PIL image -> float tensor (C, H, W) scaled to [0, 1] ...
    ToTensor(),
    # ... then per-channel standardisation with the usual ImageNet mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Labelled dataset: one class per sub-directory of base_dir; every image is
# passed through `transformer` when it is loaded.
dataset = ImageFolder(root=base_dir, transform=transformer)
# Number of images discovered (the notebook displays this value).
len(dataset)
# Output: 4323
import matplotlib.pyplot as plt
def show_example(img, label, mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225), classes=None):
    """Print a sample's class name and display its image.

    Args:
        img: image tensor in channel-first (C, H, W) layout, e.g. a sample
            produced by ``transformer`` (ToTensor followed by Normalize).
        label: integer class index into ``classes``.
        mean, std: per-channel statistics that were used to normalise ``img``.
            They are inverted before display so colours look natural. The
            defaults match the Normalize step of ``transformer``. Pass
            ``mean=None`` or ``std=None`` to display ``img`` as-is.
        classes: sequence of class names indexed by ``label``. Defaults to
            ``dataset.classes``.
    """
    if classes is None:
        classes = dataset.classes
    print('Label: ', classes[label], "("+str(label)+")")

    img = img.detach().cpu()
    if mean is not None and std is not None:
        # Undo Normalize: x_orig = x_norm * std + mean, broadcast per channel.
        mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
        img = img * std_t + mean_t
    # imshow expects float RGB data in [0, 1]; clamp explicitly instead of
    # relying on its clipping (which also emits a warning), then move
    # channels last (H, W, C) as imshow requires.
    plt.imshow(img.clamp(0, 1).permute(1, 2, 0))