Why should we extend nn.Module?

Hi, I don’t understand why we have to extend the functionality of nn.Module in this case:

    class MnistModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(input_size, num_classes)

        def forward(self, xb):
            xb = xb.reshape(-1, 784)
            out = self.linear(xb)
            return out

    model = MnistModel()

I understand why the reshaping is done: the tensors come in the shape [128, 1, 28, 28] and we need 2 dimensions, that is [128, 784].
Now, the code shown below throws an error if the images are not reshaped.

    for images, labels in train_loader:
        print(images.shape)
        outputs = model(images)
        break
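For context, here is a minimal sketch of why an unflattened batch fails when fed straight into a plain linear layer (the values input_size = 784 and num_classes = 10 are assumptions matching the MNIST setup, since they are not defined in the snippets above):

```python
import torch
import torch.nn as nn

# Assumed values for the MNIST setup discussed in this thread.
input_size, num_classes = 784, 10
linear = nn.Linear(input_size, num_classes)

# A fake batch shaped like the loader's output: [128, 1, 28, 28].
images = torch.randn(128, 1, 28, 28)

# Feeding the 4-D tensor directly fails: nn.Linear multiplies along the
# last dimension, which is 28 here instead of the expected 784.
try:
    linear(images)
except RuntimeError as e:
    print("error:", e)

# Flattening to [128, 784] first makes it work.
out = linear(images.reshape(-1, 784))
print(out.shape)  # torch.Size([128, 10])
```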

Instead of extending the nn.Module, why don’t we do something like this?

    for images, labels in train_loader:
        print(labels)
        print(images.shape)
        images = images.reshape(128, 784)
        print(images.shape)
        outputs = model(images)
        print(outputs)
        break
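One caveat worth noting about this approach (my observation, not something from the thread): reshape(128, 784) hard-codes the batch size, so it breaks on the final batch whenever the dataset size is not a multiple of 128, while reshape(-1, 784) lets PyTorch infer the batch dimension. A quick sketch:

```python
import torch

# e.g. the final, smaller batch of an epoch (the batch size is assumed
# here purely for illustration).
last_batch = torch.randn(96, 1, 28, 28)

# Hard-coding the batch size fails: 96*784 elements cannot fill [128, 784].
try:
    last_batch.reshape(128, 784)
except RuntimeError as e:
    print("error:", e)

# Inferring the batch dimension with -1 works for any batch size.
flat = last_batch.reshape(-1, 784)
print(flat.shape)  # torch.Size([96, 784])
```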

When I checked, it does give me the output. Could you help me understand this concept, and whether my approach is right?

You’ve posted a link to a Colab notebook, which is not accessible to everyone.

I have edited it :slight_smile: Thank you for letting me know!

Well, if it works both ways, then it’s up to you to decide which one you prefer :stuck_out_tongue:

That said, I think it is better to pack everything “non-transform-related” into the module. You shouldn’t have to be aware that the model only accepts flattened inputs; you should care only about the inputs and outputs, and treat the model as a sort of magic black box which just spits out the results.

Also, consider this: you have a training function which accepts any model. You get inputs from the data loader, feed them into the model, get outputs, calculate the loss, back-propagate it, and optimize the model.

All of this happens for every model. But if you have a model which needs its data preprocessed somehow (even if it’s just reshaping) before it is given as input, you’re limiting this training function to the specific set of models for which such reshaping is necessary.
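The generic training function described above could be sketched like this (an illustrative sketch only; fit_one_epoch, the SGD optimizer, and the fake loader are my assumptions, not code from the thread):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fit_one_epoch(model, loader, optimizer):
    """Hypothetical generic trainer: it never inspects tensor shapes,
    so it works with any model whose forward handles its own inputs."""
    for inputs, targets in loader:
        outputs = model(inputs)               # model does any reshaping itself
        loss = F.cross_entropy(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

class MnistModel(nn.Module):
    def __init__(self, input_size=784, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        # Flattening lives inside the model, not in the training loop.
        return self.linear(xb.reshape(-1, 784))

# Usage with a tiny fake dataset; shapes mimic the MNIST loader.
fake_loader = [(torch.randn(8, 1, 28, 28), torch.randint(0, 10, (8,)))]
model = MnistModel()
fit_one_epoch(model, fake_loader, torch.optim.SGD(model.parameters(), lr=0.1))
```

Because fit_one_epoch never touches shapes, swapping in a CNN or an RNN would require no change to the training code at all.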

1 Like

Well, you said “I’ll be limiting the training function to a specific set of models”. Aren’t we limiting the models in either case? For example, let’s say I am just providing the images without reshaping. In the class definition of MnistModel, we have the following:

    class MnistModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(input_size, num_classes)

        def forward(self, xb):
            xb = xb.reshape(-1, 784)
            out = self.linear(xb)
            return out

        def training_step(self, batch):
            images, labels = batch
            out = self(images)                  # Generate predictions
            loss = F.cross_entropy(out, labels) # Calculate loss
            return loss

Now, in this case, we pass the batch of images to training_step, which indirectly calls the forward function where the images are reshaped. So, whether I reshape the images before giving them to the model or give the images directly, either way we are limiting the training function to a specific set of models. What you said only holds if the forward method does not do the reshaping, I suppose?

If you give the inputs (in whatever form) to your model, and do the necessary reshaping or other data preprocessing inside it, then where is this “limit” you are thinking of?

Take a closer look at what such a function would do:

  • get inputs and targets from dataloader
  • get outputs from the model (using inputs)
  • compute loss (using outputs and targets)
  • update model through optimizer

There’s no point where the function should care about any shape or form of any tensor.

If a model accepts images - the dataloader spits out image data.
If a model accepts data in the form of a series of values - the dataloader spits out such data.
If a model accepts sounds - the dataloader spits out sound data.

You now only feed it to the model, which gives you some sort of outputs. Those can be used to compute the loss against the targets (yes, these also come from the dataloader).
You give this loss to an optimizer, and its job is to correctly update the model.

In none of these steps does the training function actually compute anything on its own - it just passes the data around.

Thank you for clarifying :slight_smile: