Learn practical skills, build real-world projects, and advance your career
Updated 4 years ago
import os
import tarfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as tt
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
PROJECT_NAME = 'resnet-learn'
DATASET_URL = 'http://files.fast.ai/data/cifar10.tgz'
download_url(DATASET_URL, './data')
with tarfile.open('./data/cifar10.tgz', 'r:gz') as tar:
tar.extractall(path='./data')
Using downloaded and verified file: ./data/cifar10.tgz
DATASET_DIR = './data/cifar10'
TRAIN_DATA_DIR = DATASET_DIR + '/train'
TEST_DATA_DIR = DATASET_DIR + '/test'
def std_mean(dataset):
imgs = []
for img, _ in dataset:
imgs.append(np.array(img) / 255)
imgs = np.stack(imgs)
reds, greens, blues = imgs[:,:,:,0], imgs[:,:,:,1], imgs[:,:,:,2]
means = np.mean(reds), np.mean(greens), np.mean(blues)
stds = np.std(reds), np.std(greens), np.std(blues)
return means, stds