Sign In

Classification of Food-101 datasets

As the name suggests, Food-101 has 101 classes.
Let's have a quick overview of the dataset:

  • No. of food categories (classes): 101
  • Total no. of images: 101,000 (1000 images/class)
  • Training images/class: 750
  • Test images/class: 250
  • Rescaled Image size (maximum): (512x512) pixels

N.B.: The training images were not cleaned, i.e. they contain some noise (such as intense colors or wrong labels).


In [1]:
!pip install jovian -q
In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
In [3]:
import jovian
In [4]:
# pytorch imports
import torch
import torchvision
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader  # fixed: source line was "from import DataLoader"
from torch import nn
from torch import optim
from torch.autograd import Variable

import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
import collections
from shutil import copy
from shutil import copytree, rmtree
import random
from tqdm import tqdm_notebook as tqdm
import math
import time
from IPython.core.debugger import set_trace
In [ ]:
bs = 64  # mini-batch size for the training DataLoader
epochs = 3  # number of passes over the training set
# Train on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Per-channel (mean, std) used to normalize inputs for ImageNet-pretrained models.
imagenet_stats = [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]
In [ ]:
FOOD_PATH = "./food-101"  # root of the extracted Food-101 dataset
IMG_PATH = FOOD_PATH+"/images"  # 101 class-named folders of raw images
MODEL_PATH = 'model_data/'  # destination for checkpoints / saved weights
In [ ]:
# True when a CUDA device is available (is_available already returns a bool).
gpu = torch.cuda.is_available()
In [ ]:
# Checkpoint file used by the cyclical-learning-rate experiments below.
filename = MODEL_PATH+'clr.pth'

Helper Functions

In [ ]:
def pp_(*args, n_dash=120):
    """Pretty-print each argument followed by a rule of *n_dash* dashes.

    NOTE(review): the loop body was lost in this transcript; reconstructed
    from the cell outputs below, which show each value separated by a
    120-dash line.
    """
    for arg in args:
        print(arg)
        print("-" * n_dash)
In [ ]:
def list_dir(path="./"):
    """Return the directory entries of *path* (current directory by default)."""
    return os.listdir(path)
In [ ]:
def cal_mean_std(train_data):
    """Per-channel mean and std of a (N, H, W, C) uint8 image stack, scaled to [0, 1]."""
    pixel_axes = (0, 1, 2)
    mean = np.mean(train_data, axis=pixel_axes) / 255
    std = np.std(train_data, axis=pixel_axes) / 255
    return mean, std
In [ ]:
def save_checkpoint(model, is_best, filename='model_data/checkpoint.pth'):
    """Save checkpoint if a new best is achieved.

    NOTE(review): the torch.save call and else-branch were garbled in this
    transcript; reconstructed from the standard checkpointing snippet.
    """
    if is_best:
        print("=> Saving a new best")
        torch.save(model.state_dict(), filename)  # save checkpoint
    else:
        print("=> Validation Accuracy did not improve")
# from fastai library
def load_checkpoint(model, filename = 'model_data/checkpoint.pth'):
    """Load a checkpoint into *model* in place, tolerating weight-norm renames.

    Keys saved as ``n`` that the model expects as ``n_raw`` (fastai's
    weight-norm convention) are renamed before loading.
    NOTE(review): the final load_state_dict call was lost in this transcript;
    without it the function had no effect on *model*.
    """
    # map_location keeps CPU-only machines working with GPU-saved checkpoints
    sd = torch.load(filename, map_location=lambda storage, loc: storage)
    names = set(model.state_dict().keys())
    for n in list(sd.keys()):
        if n not in names and n+'_raw' in names:
            if n+'_raw' not in sd: sd[n+'_raw'] = sd[n]
            del sd[n]
    model.load_state_dict(sd)
In [ ]:
def save_model(model, path):
    """Persist *model*'s learned parameters (state_dict) to *path*."""
    torch.save(model.state_dict(), path)


def load_model(model, path):
    """Load parameters written by save_model into *model*, in place."""
    model.load_state_dict(torch.load(path))
In [ ]:
def calc_iters(dataset, num_epochs, bs):
    """Total optimizer iterations over *num_epochs* epochs at batch size *bs* (truncated)."""
    total_samples = len(dataset) * num_epochs
    return int(total_samples / bs)
In [ ]:
def accuracy(output, target, is_test=False):
    """Fold one batch into the running accuracy and return it (percent).

    Relies on module-level counters ``total`` and ``correct`` that the
    training loop must define/reset before the first call.
    """
    global total
    global correct
    batch_size = output.shape[0]
    total += batch_size
    # class prediction = argmax over the logits dimension
    _, pred = torch.max(output, 1)
    # NOTE(review): the body of the original `if is_test:` branch was lost in
    # this transcript (as written it was a syntax error); `is_test` is kept
    # for interface compatibility but currently has no effect — restore the
    # original behavior before relying on it.
    correct += (pred == target).sum()
    return 100 * correct / total
In [ ]:
class AvgStats(object):
    """Accumulate per-iteration (loss, precision, iteration) history.

    NOTE(review): the __init__ and append bodies were lost in this
    transcript and are reconstructed in the obvious way.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        # Drop all recorded history.
        self.losses = []
        self.precs = []
        self.its = []

    def append(self, loss, prec, it):
        # Record one observation from the training/validation loop.
        self.losses.append(loss)
        self.precs.append(prec)
        self.its.append(it)
In [ ]:
def freeze(model):
    """Freeze children 0-6 of *model* entirely, plus the first two
    sub-children of child 7; leave everything else trainable.

    NOTE(review): both else-branches were lost in this transcript (the two
    "was frozen"/"was not frozen" prints appeared back to back); restored.
    """
    child_counter = 0
    for name, child in model.named_children():
        if child_counter < 7:
            print("name ",name, "child ",child_counter," was frozen")
            for param in child.parameters():
                param.requires_grad = False
        elif child_counter == 7:
            # Partially freeze child 7: only its first two sub-modules.
            children_of_child_counter = 0
            for children_of_child in child.children():
                if children_of_child_counter < 2:
                    for param in children_of_child.parameters():
                        param.requires_grad = False
                    print("name ",name, 'child ', children_of_child_counter, 'of child',child_counter,' was frozen')
                else:
                    print("name ",name, 'child ', children_of_child_counter, 'of child',child_counter,' was not frozen')
                children_of_child_counter += 1
        else:
            print("name ",name, "child ",child_counter," was not frozen")
        child_counter += 1
In [ ]:
def unfreeze(model):
    """Make every parameter of *model* trainable again."""
    for _, param in model.named_parameters():
        param.requires_grad = True
In [ ]:
def print_frozen_state(model):
    """Print, per top-level child of *model*, whether each parameter is trainable."""
    for child_counter, (name, child) in enumerate(model.named_children()):
        for param in child.parameters():
            if param.requires_grad:
                print("child ", child_counter, "named:", name, " is unfrezed")
            else:
                print("child ", child_counter, "named:", name, " is frezed")
In [ ]:
def update_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        group['lr'] = lr
In [ ]:
def update_mom(optimizer, mom):
    """Overwrite the momentum of every parameter group of *optimizer*."""
    for group in optimizer.param_groups:
        group['momentum'] = mom

Dataset Preparation

Create a class to download & prepare train and validation data

In [ ]:
class FOOD101():
    """Prepare the Food-101 data: download/extract the archive, build the
    torchvision transforms, and construct Datasets and DataLoaders.

    NOTE(review): several lines of this class were lost in the transcript
    (download commands, an else-branch, and the Resize/Crop/ToTensor
    transform steps) — restore them from the original notebook before
    running this code.
    """
    def __init__(self):
        # Placeholders, filled in later by get_dataset().
        self.train_ds, self.valid_ds, self.train_cls, self.valid_cls = [None]*4
        # ImageNet channel statistics used for input normalization
        # (module-level `imagenet_stats`).
        self.imgenet_mean = imagenet_stats[0]
        self.imgenet_std = imagenet_stats[1]
    def get_data_extract(self):
        # Download and unpack the dataset unless ./food-101 already exists.
        # NOTE(review): the download (wget) lines and the else-branch are
        # missing here — the "Downloading..." prints below clearly belong to
        # a lost else-branch, and `!tar` is IPython-only syntax.
        if "food-101" in os.listdir():
            print("Dataset already exists")
            print("Downloading the data...")
            print("Dataset downloaded!")
            print("Extracting data..")
            !tar xzvf food-101.tar.gz
            print("Extraction done!")
    def _get_tfms(self):
        # Build the train/valid transform pipelines.
        # NOTE(review): resize/augmentation and ToTensor steps appear to be
        # missing; Normalize alone cannot be applied to PIL images.
        train_tfms = transforms.Compose([
            transforms.Normalize(self.imgenet_mean, self.imgenet_std)])
        valid_tfms = transforms.Compose([
            transforms.Normalize(self.imgenet_mean, self.imgenet_std)])        
        return train_tfms, valid_tfms            
    def get_dataset(self,root_dir='./food-101/'):
        # Build ImageFolder datasets from the class-named folders.
        # NOTE(review): TRAIN_PATH/VALID_PATH are module-level globals defined
        # in a lost cell; `root_dir` is currently unused.
        train_tfms, valid_tfms = self._get_tfms() # transformations
        self.train_ds = datasets.ImageFolder(root=TRAIN_PATH, transform=train_tfms)
        self.valid_ds = datasets.ImageFolder(root=VALID_PATH, transform=valid_tfms)        
        self.train_classes = self.train_ds.classes
        self.valid_classes = self.valid_ds.classes

        # Train and valid must expose the identical class list.
        assert self.train_classes==self.valid_classes
        return self.train_ds, self.valid_ds, self.train_classes

    def get_dls(self, train_ds, valid_ds, bs, **kwargs):
        # Shuffled train loader at batch size bs; deterministic valid loader
        # at half the batch size.
        return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
               DataLoader(valid_ds, batch_size=bs//2, shuffle=False, **kwargs))
food = FOOD101()    

Download the data from source and extract

Download data
In [ ]:
Dataset already exists
In [ ]:
!ls food-101/
README.txt images license_agreement.txt meta train valid
In [ ]:
pp_(list_dir(FOOD_PATH), list_dir(IMG_PATH), list_dir(META_PATH))
['meta', 'images', '.ipynb_checkpoints', 'train', 'valid', 'license_agreement.txt', 'README.txt'] ------------------------------------------------------------------------------------------------------------------------ ['gnocchi', 'cheese_plate', 'ramen', 'ice_cream', 'donuts', 'cup_cakes', 'french_onion_soup', 'bruschetta', 'poutine', 'grilled_cheese_sandwich', 'grilled_salmon', 'chicken_curry', 'baby_back_ribs', 'crab_cakes', 'caprese_salad', 'strawberry_shortcake', 'garlic_bread', 'tuna_tartare', 'hummus', 'chocolate_mousse', 'tiramisu', 'huevos_rancheros', 'french_toast', 'samosa', 'peking_duck', 'beet_salad', 'spring_rolls', 'pizza', 'macarons', 'fish_and_chips', 'beignets', 'panna_cotta', 'beef_carpaccio', 'red_velvet_cake', 'foie_gras', 'clam_chowder', 'pulled_pork_sandwich', 'hamburger', 'lobster_bisque', 'spaghetti_bolognese', 'frozen_yogurt', 'macaroni_and_cheese', 'omelette', 'caesar_salad', 'croque_madame', 'chocolate_cake', 'pho', 'nachos', 'falafel', 'fried_calamari', 'chicken_wings', 'club_sandwich', 'mussels', 'lobster_roll_sandwich', 'risotto', 'breakfast_burrito', 'ravioli', 'edamame', 'chicken_quesadilla', 'baklava', 'bibimbap', 'pancakes', 'bread_pudding', 'guacamole', 'greek_salad', 'miso_soup', 'eggs_benedict', 'fried_rice', 'pad_thai', 'shrimp_and_grits', 'hot_and_sour_soup', 'deviled_eggs', 'cannoli', 'oysters', 'escargots', 'tacos', 'apple_pie', 'seaweed_salad', 'paella', 'carrot_cake', 'filet_mignon', 'churros', 'sushi', 'takoyaki', 'waffles', 'ceviche', 'cheesecake', 'dumplings', 'lasagna', 'sashimi', 'hot_dog', 'prime_rib', 'onion_rings', 'steak', 'creme_brulee', 'pork_chop', 'spaghetti_carbonara', 'french_fries', 'scallops', 'beef_tartare', 'gyoza'] ------------------------------------------------------------------------------------------------------------------------ ['test.json', 'train.txt', 'classes.txt', 'labels.txt', 'test.txt', 'train.json'] 
------------------------------------------------------------------------------------------------------------------------

meta folder contains the text files - train.txt and test.txt
train.txt contains the list of images that belong to training set
test.txt contains the list of images that belong to test set
classes.txt contains the list of all classes of food

In [ ]:
!head food-101/meta/train.txt
apple_pie/1005649 apple_pie/1014775 apple_pie/1026328 apple_pie/1028787 apple_pie/1043283 apple_pie/1050519 apple_pie/1057749 apple_pie/1057810 apple_pie/1072416 apple_pie/1074856
In [ ]:
!head food-101/meta/classes.txt
apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad beignets bibimbap bread_pudding breakfast_burrito
Split the image data into train and test using train.txt and test.txt
In [ ]:
# Helper method to split dataset into train and test folders
# Helper method to split dataset into train and test folders
def prepare_data(filepath, src, dest):
    """Copy images listed in *filepath* from *src* into per-class folders under *dest*.

    *filepath* is a Food-101 meta file (one "class_name/image_id" per line);
    each image is copied to dest/class_name/image_id.jpg.
    NOTE(review): the os.makedirs line was lost in this transcript — without
    it the `if not os.path.exists(...)` guard had no body.
    """
    classes_images = defaultdict(list)
    with open(filepath, 'r') as txt:
        paths = [read.strip() for read in txt.readlines()]
        for p in paths:
            food = p.split('/')
            classes_images[food[0]].append(food[1] + '.jpg')

    for food in classes_images.keys():
        print("\nCopying images into ",food)
        # Create the class folder on first use.
        if not os.path.exists(os.path.join(dest, food)):
            os.makedirs(os.path.join(dest, food))
        for i in classes_images[food]:
            copy(os.path.join(src,food,i), os.path.join(dest,food,i))
    print("Copying Done!")
In [ ]:
# # Prepare train dataset by copying images from food-101/images to food-101/train using the file train.txt
# print("Creating train data...")
# prepare_data(META_PATH+'train.txt', IMG_PATH, TRAIN_PATH)
In [ ]:
# # Prepare validation data by copying images from food-101/images to food-101/valid using the file test.txt
# print("Creating validation data...")
# prepare_data(META_PATH+'test.txt', IMG_PATH, VALID_PATH)
In [ ]:
# Check how many files are in the train folder
print("Total number of samples in train folder")
!find food-101/train -type d -or -type f -printf '.' | wc -c
Total number of samples in train folder 75750
In [ ]:
# Check how many files are in the test folder
print("Total number of samples in validation folder")
!find food-101/valid -type d -or -type f -printf '.' | wc -c
Total number of samples in validation folder 25250
Create Datasets & DataLoaders
In [ ]:
# Build the train/valid ImageFolder datasets and the shared class list.
train_ds, valid_ds, classes =  food.get_dataset()
num_classes = len(classes)  # 101 food categories (see output below)
['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'] ------------------------------------------------------------------------------------------------------------------------ 101 ------------------------------------------------------------------------------------------------------------------------
In [ ]:
# DataLoaders: shuffled train at bs, valid at bs//2; 2 worker processes each.
train_dl, valid_dl = food.get_dls(train_ds, valid_ds, bs=bs, num_workers=2)

Wrap all inside databunch

In [ ]:
class DataBunch():
    """Lightweight container pairing a train and a valid DataLoader.

    ``c`` is the number of target classes (None when unknown).
    Fix: ``train_ds``/``valid_ds`` are accessed as attributes later in the
    notebook (``data.valid_ds``), so they must be properties — the
    ``@property`` decorators were evidently lost in transcription.
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c

    @property
    def train_ds(self):
        # Underlying dataset of the training loader.
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        # Underlying dataset of the validation loader.
        return self.valid_dl.dataset

    def __repr__(self):
        return str(self.__class__.__name__)+" obj (train & valid DataLoaders)"
In [ ]:
# Bundle the two loaders plus the class count into a single object.
data = DataBunch(train_dl, valid_dl, c=num_classes)
In [ ]:
pp_(data.valid_ds, data.c, data.train_ds)
Dataset ImageFolder Number of datapoints: 25250 Root location: ./food-101/valid ------------------------------------------------------------------------------------------------------------------------ 101 ------------------------------------------------------------------------------------------------------------------------ Dataset ImageFolder Number of datapoints: 75750 Root location: ./food-101/train ------------------------------------------------------------------------------------------------------------------------
In [ ]:
# a = None
# f"Showing one random image from each {'Validation' if a else 'Train'} classes"
Visualize random image from each of the 101 classes
In [ ]:
# This can be used to print predictions too
def show_ds(trainset, classes, validset=None, cols=6, rows=17, preds=None, is_pred=False, is_valid=False):
    """Plot a cols x rows grid of random dataset images, denormalized for display.

    When is_pred is true, images are drawn from *validset* and titled
    "predicted/actual" using *preds*; otherwise they come from *trainset*.
    Fixes: the original body referenced an undefined ``testset`` (the
    parameter is ``validset``) and the else-branch for the train case was
    lost in this transcript.
    """
    fig = plt.figure(figsize=(25,25))
    # y=0.92 keeps the suptitle clear of the top row of axes.
    fig.suptitle(f"Showing one random image from each {'Validation' if is_valid else 'Train'} classes", y=0.92, fontsize=24)
    imgenet_mean = imagenet_stats[0]
    imgenet_std = imagenet_stats[1]

    for i in range(1, cols * rows + 1):
        fig.add_subplot(rows, cols, i)
        if is_pred and validset:
            img_xy = np.random.randint(len(validset))
            np_img = validset[img_xy][0].numpy()
        else:
            img_xy = np.random.randint(len(trainset))
            np_img = trainset[img_xy][0].numpy()
        # CHW tensor -> HWC image, then undo the ImageNet normalization.
        img = np.transpose(np_img, (1, 2, 0))
        img = img * imgenet_std + imgenet_mean
        if is_pred:
            # title = "predicted/actual" class names
            plt.title(classes[int(preds[img_xy])] + "/" + classes[validset[img_xy][1]])
        img = np.clip(img, 0, 1)
        plt.imshow(img, interpolation='nearest')
In [ ]:
# Visualize one random training image per grid cell, one row per few classes.
show_ds(data.train_ds, classes)