Forward function in MNIST Model

How is the forward function within the MNIST model class invoked automatically when we only pass the images? Shouldn’t it be called as model.forward(image)? The code works even when it is called as model(image). How does that happen?


An nn.Module class has a __call__ method defined, and that is what runs when you write model(image).
__call__ calls your forward method for you, and it also takes care of calling any additional hooks that may have been registered (before and after the actual forward method runs).
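To see what that means in practice, here is a small sketch (not from the original thread) that uses PyTorch's register_forward_pre_hook and register_forward_hook methods. The Tiny model and the calls list are hypothetical, made up only to record the order in which things run.

```python
import torch
import torch.nn as nn

calls = []  # records what runs, in order

class Tiny(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        calls.append("forward")
        return self.linear(x)

model = Tiny()
# Both hooks return None, so they leave the input and output unchanged
model.register_forward_pre_hook(lambda module, args: calls.append("pre-hook"))
model.register_forward_hook(lambda module, args, output: calls.append("post-hook"))

x = torch.randn(3, 4)

model(x)                 # goes through nn.Module.__call__
via_call = list(calls)
print(via_call)          # ['pre-hook', 'forward', 'post-hook']

calls.clear()
model.forward(x)         # bypasses __call__, so the hooks are skipped
direct = list(calls)
print(direct)            # ['forward']
```

Calling forward directly still computes the output, but anything nn.Module does around it (such as these hooks) is silently skipped.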

WTH is __call__?


I have the same question: I can’t understand why the forward function is invoked automatically.
Also, I couldn’t understand the answer posted here by @Sebgolos, as I am new to Python and to programming in general. If anyone could post another answer, it would be very helpful.

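class Example:
    def important_method(self):
        # hypothetical stand-in for "some important methods" that must run on every call
        print("doing some important work first")

    def forward(self, *args):
        print("hehe, forward is invoked")

    def __call__(self, *args):
        self.important_method()
        return self.forward(*args)  # run forward; no need to write obj.forward()

obj = Example()  # named obj to avoid shadowing the built-in object
obj.__call__()   # invoke the __call__ method explicitly
obj()            # shortcut for the above (also invokes __call__)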

forward is called automatically by the __call__ method, which in turn can be invoked with a shortcut: just calling the instance as if it were a function.
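The shortcut is plain Python, not something specific to PyTorch: any class that defines __call__ makes its instances callable, and Python looks the method up on the class, so obj(x) behaves like type(obj).__call__(obj, x). A minimal sketch with a hypothetical Multiplier class:

```python
class Multiplier:  # hypothetical example class
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return x * self.factor

double = Multiplier(2)
print(double(21))                         # 42, the shortcut
print(double.__call__(21))                # 42, the explicit call
print(type(double).__call__(double, 21))  # 42, what Python actually does
print(callable(double))                   # True, because Multiplier defines __call__
```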

Why do it this way?

Well, if you look at my example, you can see that the __call__ method also calls other methods besides forward. If you called forward directly, you would skip these “important” functions.


OK, I understand this now, @Sebgolos.
Just one more question to be clear: does this __call__ method call every function that has been defined in the class?

If so, does that mean every function in the class gets executed, even the ones we don’t need?

For instance, if we execute out = self(images), is it going to execute every method defined in the class, i.e. self.forward(), self.training_step(), self.validation_step(), self.validation_epoch_end(), self.epoch_end(), and so on?

No, it only calls whatever is written inside the __call__ method. For nn.Module, that is forward plus any registered hooks, not training_step or the other methods.

It’s really just like any other method; the only difference is that it can be called with a shortcut.
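Applied to the question above: in a model that also defines methods like training_step or validation_step, model(images) still only runs forward. Those other methods run only when you call them yourself, and they reach forward through self(images). A sketch, where the TracedModel class and its counts dictionary are hypothetical, the method names simply mirror the ones mentioned above, and random tensors stand in for real MNIST data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TracedModel(nn.Module):  # hypothetical model that counts its method calls
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 10)
        self.counts = {"forward": 0, "training_step": 0, "validation_step": 0}

    def forward(self, xb):
        self.counts["forward"] += 1
        return self.linear(xb.reshape(-1, 784))

    def training_step(self, batch):
        self.counts["training_step"] += 1
        images, labels = batch
        return F.cross_entropy(self(images), labels)  # reaches forward via __call__

    def validation_step(self, batch):
        self.counts["validation_step"] += 1
        images, labels = batch
        return F.cross_entropy(self(images), labels)

model = TracedModel()
images = torch.randn(5, 1, 28, 28)   # dummy batch standing in for MNIST images
labels = torch.randint(0, 10, (5,))  # dummy labels

out = model(images)
after_call = dict(model.counts)
print(after_call)  # {'forward': 1, 'training_step': 0, 'validation_step': 0}

loss = model.training_step((images, labels))
after_step = dict(model.counts)
print(after_step)  # {'forward': 2, 'training_step': 1, 'validation_step': 0}
```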


OK, I got this. One more thing though, @Sebgolos:

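import torch.nn as nn

input_size = 28 * 28  # each MNIST image is 28x28 pixels
num_classes = 10      # digits 0-9

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules like self.linear
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, input_size)  # flatten each image into a 784-vector
        out = self.linear(xb)
        return out

model = MnistModel()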

In the code above, we never write a __call__ method that calls forward, do we?

Right, you don’t write one yourself, but your class inherits from nn.Module, which defines __call__ (among many other methods, BTW).

By inheriting from nn.Module you also inherit all of its methods; that’s why you can call model.parameters() without defining that method in your model. Creating a new model then becomes easy, because you only have to take care of the __init__ and forward methods, and nn.Module provides the rest.
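One way to check this is sketched below (an illustration, not from the original thread). It assumes 784 inputs and 10 classes to match 28x28 MNIST digits, and uses a random tensor as a stand-in for real images. MnistModel defines only __init__ and forward, yet __call__ and parameters() are both found on nn.Module.

```python
import torch
import torch.nn as nn

input_size = 28 * 28  # assumed: flattened 28x28 MNIST image
num_classes = 10      # assumed: digits 0-9

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, input_size)
        return self.linear(xb)

model = MnistModel()

# __call__ is not defined in MnistModel itself; it is inherited from nn.Module
print("__call__" in MnistModel.__dict__)          # False
print(MnistModel.__call__ is nn.Module.__call__)  # True

# parameters() is inherited too: the Linear layer's weight and bias
shapes = [tuple(p.shape) for p in model.parameters()]
print(shapes)                       # [(10, 784), (10,)]

images = torch.randn(8, 1, 28, 28)  # dummy batch of 8 "images"
out = model(images)                 # nn.Module.__call__ runs forward for us
print(out.shape)                    # torch.Size([8, 10])
```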


OK, thanks! This was very helpful, and so were the links in your first answer.