Purpose of torchvision.transforms.ToTensor()

What does transform=ToTensor() achieve in this command of the template for the Fashion-MNIST dataset?

dataset = FashionMNIST(root='data/', download=True, transform=ToTensor())

My understanding is that some of the datasets in torchvision.datasets are by default returned as PIL images instead of PyTorch tensors, and transform=ToTensor() ensures that the elements of the dataset get converted to torch.Tensor objects first.

In the case of Fashion-MNIST there seems to be no difference, as I still get torch.Tensor from:

dataset = torchvision.datasets.FashionMNIST(root="data/", download=True, train=True)
type(dataset.data[0])

Am I missing something else that transform=ToTensor() is doing?

ToTensor() not only converts PIL images to tensors, it also converts NumPy arrays and takes care of permuting the dimensions: a NumPy image is (H x W x C), whereas PyTorch expects (C x H x W). For uint8 input it also scales the values from [0, 255] to [0.0, 1.0].
