Hi, I don’t understand why we have to extend `nn.Module` in this case:
```python
class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

model = MnistModel()
```
I understand why the reshaping is done: the tensors come in with shape `[128, 1, 28, 28]`, and the linear layer needs 2 dimensions, i.e. `[128, 784]`.
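For reference, here is a quick standalone check of that reshape with a dummy tensor standing in for a batch of MNIST images (assuming `torch` is installed; the numbers just mirror the shapes above):

```python
import torch

# A fake batch: 128 grayscale 28x28 images, like one batch from the DataLoader.
batch = torch.randn(128, 1, 28, 28)

# -1 tells reshape to infer the first (batch) dimension from the
# total number of elements, so only the 784 is fixed.
flat = batch.reshape(-1, 784)

print(batch.shape)  # torch.Size([128, 1, 28, 28])
print(flat.shape)   # torch.Size([128, 784])
```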
Now, the code shown below throws an error as the images are not reshaped.
```python
for images, labels in train_loader:
    print(images.shape)
    outputs = model(images)
    break
```
Instead of extending `nn.Module`, why don’t we do something like this?
```python
for images, labels in train_loader:
    print(labels)
    print(images.shape)
    images = images.reshape(128, 784)
    print(images.shape)
    outputs = model(images)
    print(outputs)
    break
```
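One difference I did notice between my version and the `forward` method: I hard-code `128` while the class uses `-1`. A small test of what that means for a final batch that is smaller than 128 (dummy tensor, `torch` assumed; 96 is just an example leftover size):

```python
import torch

# If the dataset size isn't a multiple of 128, the last batch
# from the DataLoader is smaller, e.g. 96 images.
last_batch = torch.randn(96, 1, 28, 28)

# reshape(-1, 784) infers the batch dimension, so it still works.
flat = last_batch.reshape(-1, 784)
print(flat.shape)  # torch.Size([96, 784])

# reshape(128, 784) demands exactly 128 * 784 elements and
# raises a RuntimeError for this smaller batch.
try:
    last_batch.reshape(128, 784)
    failed = False
except RuntimeError:
    failed = True
print("reshape(128, 784) failed:", failed)
```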
When I checked, this does give me output. Could you help me understand this concept, and whether my approach is right?