Learn practical skills, build real-world projects, and advance your career
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from sklearn.model_selection import train_test_split
from fastai.vision import *
from fastai.metrics import error_rate

# Seed NumPy's global random generator with a fixed value so that anything
# drawing from np.random afterwards is repeatable from run to run.
np.random.seed(2)
class CustomImageList(ImageList):
    """ImageList variant whose items are rows of pixel values, not file names.

    Both alternate constructors replace the items created by `from_df` with
    per-row pixel arrays (values divided by 255.0). `open` then reshapes one
    such row to 28x28 and repeats it across three channels.
    """

    @staticmethod
    def _scaled_pixel_rows(df, imgIdx):
        """Return one array per row of `df`, taken from columns `imgIdx:` and divided by 255.0."""
        pixel_columns = df.iloc[:, imgIdx:]
        scaled_rows = pixel_columns.apply(lambda row: row.values / 255.0, axis=1)
        return scaled_rows.values

    def open(self, fn):
        """Build an `Image` from `fn`, which here is a flat pixel array rather than a path."""
        plane = fn.reshape(28, 28)
        # The same plane stacked three times along a new last axis -> (28, 28, 3).
        stacked = np.stack([plane, plane, plane], axis=-1)
        tensor = pil2tensor(stacked, dtype=np.float32)
        return Image(tensor)

    @classmethod
    def from_csv_custom(cls, path:PathOrStr, csv_name:str, imgIdx:int=1, header:str='infer', **kwargs)->'ItemList':
        """Read `path/csv_name` with pandas and build a list whose items are its pixel rows.

        `imgIdx` is the index of the first pixel column; `header` is forwarded
        to `pd.read_csv`; remaining keyword arguments go to `from_df`.
        """
        frame = pd.read_csv(Path(path)/csv_name, header=header)
        item_list = super().from_df(frame, path=path, cols=0, **kwargs)
        # Swap the items `from_df` derived from column 0 for the scaled pixel rows.
        item_list.items = cls._scaled_pixel_rows(frame, imgIdx)
        return item_list

    @classmethod
    def from_df_custom(cls, path:PathOrStr, df:DataFrame, imgIdx:int=1, header:str='infer', **kwargs)->'ItemList':
        """Build a list from an in-memory DataFrame whose items are its pixel rows.

        `imgIdx` is the index of the first pixel column. `header` is accepted
        but not used by this method; remaining keyword arguments go to `from_df`.
        """
        item_list = super().from_df(df, path=path, cols=0, **kwargs)
        # Swap the items `from_df` derived from column 0 for the scaled pixel rows.
        item_list.items = cls._scaled_pixel_rows(df, imgIdx)
        return item_list
# Test items: with imgIdx=0 every column of test.csv is read as pixel data.
test = CustomImageList.from_csv_custom(path='../input', csv_name='test.csv', imgIdx=0)

# Training data, assembled one pipeline stage per statement.
# With imgIdx=1 the pixel data starts at the second column of train.csv.
data = CustomImageList.from_csv_custom(path='../input', csv_name='train.csv', imgIdx=1)
data = data.split_by_rand_pct(.2)                      # random split with a 0.2 validation fraction
data = data.label_from_df(cols='label')                # targets are read from the 'label' column
data = data.add_test(test, label=0)                    # attach the test items, passing label=0 for them
data = data.transform(get_transforms(do_flip=False))   # default transforms with flipping turned off
data = data.databunch(bs=128, num_workers=0)
data = data.normalize(imagenet_stats)

# Preview a 3-row grid of samples from the assembled data.
data.show_batch(rows=3, figsize=(5, 5))
# [Notebook Image: output rendered by the cell above]
# Show the type of the assembled data object and its `c` attribute, one per line.
print(type(data), data.c, sep='\n')
# Output:
# <class 'fastai.vision.data.ImageDataBunch'>
# 10