Learn practical skills, build real-world projects, and advance your career
import os
import tarfile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader

import torchvision.transforms as tt
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid


import numpy as np

import matplotlib.pyplot as plt
PROJECT_NAME = 'resnet-learn'
DATASET_URL = 'http://files.fast.ai/data/cifar10.tgz'
download_url(DATASET_URL, './data')

with tarfile.open('./data/cifar10.tgz', 'r:gz') as tar:
    tar.extractall(path='./data')
Using downloaded and verified file: ./data/cifar10.tgz
DATASET_DIR = './data/cifar10'
TRAIN_DATA_DIR = DATASET_DIR + '/train'
TEST_DATA_DIR = DATASET_DIR + '/test'
def std_mean(dataset):
    imgs = []
    for img, _ in dataset:
        imgs.append(np.array(img) / 255)
    
    imgs = np.stack(imgs)
    
    reds, greens, blues = imgs[:,:,:,0], imgs[:,:,:,1], imgs[:,:,:,2]
    means = np.mean(reds), np.mean(greens), np.mean(blues)
    stds = np.std(reds), np.std(greens), np.std(blues)
    
    return means, stds