Assignment #2 - problem with data types


I’m now finishing the 2nd assignment and have run into the following problem.

When I run this code:

epochs = 20
lr = 1e-3
history1 = fit(epochs, lr, model, train_loader, val_loader)

I get the following error from line #3:
RuntimeError: Found dtype Double but expected Float
The problem comes from the ‘train_loader’ variable, but when I check its dtype, it’s float64, not double. What could be the problem?

Thank you

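float64 essentially means double: it is the 64-bit, double-precision float, twice the width of the 32-bit float32 (that’s where the name comes from). In PyTorch error messages “Double” is float64 and “Float” is float32, so the error says the model expects float32 data. It should be float32.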
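A minimal sketch of where the mismatch comes from, assuming the data was loaded through NumPy (or pandas), whose default float dtype is float64. The `nn.Linear(3, 1)` model is a hypothetical stand-in for the assignment's model; any layer created with default settings has float32 parameters.

```python
import numpy as np
import torch
import torch.nn as nn

# "Double" and "Float" in PyTorch error messages are aliases for these dtypes.
assert torch.float64 == torch.double
assert torch.float32 == torch.float

# Assumption: the data came from NumPy/pandas, whose default float dtype is float64.
features = torch.from_numpy(np.random.rand(4, 3))
print(features.dtype)  # torch.float64

# Hypothetical stand-in model: parameters are created as float32 by default.
model = nn.Linear(3, 1)
print(model.weight.dtype)  # torch.float32

# Feeding float64 inputs to float32 weights raises a RuntimeError.
# (The exact wording of the message differs between PyTorch versions.)
try:
    model(features)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```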

How can I convert the DataLoader’s data from float64 to float32?

Thank you

You can convert a tensor with its to() method (the DataLoader already yields tensors).

It accepts the dtype you want, torch.float32 in this case.
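A sketch of two ways to apply `to(torch.float32)`, assuming the data is wrapped in a `TensorDataset`; the random `inputs`/`targets` are hypothetical placeholders for the assignment's data. Converting the targets as well as the inputs is usually necessary, since the loss can also complain about a float64 target. `tensor.float()` is a shorthand for the same conversion.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical float64 data, e.g. converted from NumPy arrays.
inputs = torch.from_numpy(np.random.rand(10, 3))
targets = torch.from_numpy(np.random.rand(10, 1))

# Option 1: convert once, before building the dataset and the DataLoader.
dataset = TensorDataset(inputs.to(torch.float32), targets.to(torch.float32))
train_loader = DataLoader(dataset, batch_size=5, shuffle=True)
batch_dtypes = [(xb.dtype, yb.dtype) for xb, yb in train_loader]
print(batch_dtypes[0])  # (torch.float32, torch.float32)

# Option 2: keep the float64 data and convert each batch inside the training loop.
raw_loader = DataLoader(TensorDataset(inputs, targets), batch_size=5)
converted = []
for xb, yb in raw_loader:
    xb, yb = xb.to(torch.float32), yb.to(torch.float32)
    converted.append((xb.dtype, yb.dtype))
```

Option 1 is usually simpler, because the batches that reach `fit` are already float32 and the training loop does not need to change.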