How is the forward function within the MNIST model class automatically invoked when we only pass the images? Shouldn’t the function be called as model.forward(image)? The code works even when it is called as model(image). How does that happen?
The nn.Module class has a __call__ method defined, and that method is what calls forward. It also takes care of calling any additional hooks that may have been defined (before/after running the actual forward method).
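To make this concrete, here is a heavily simplified, pure-Python sketch of the idea. It is not PyTorch’s actual implementation (the real one handles many more cases, and this sketch ignores hook return values); the class names TinyModule and Doubler are hypothetical, and the two registration methods are only loosely modeled on PyTorch’s register_forward_pre_hook and register_forward_hook.

```python
class TinyModule:
    """Hypothetical stand-in for nn.Module: __call__ wraps forward with hooks."""

    def __init__(self):
        self._pre_hooks = []   # run before forward, receive the inputs
        self._post_hooks = []  # run after forward, receive inputs and output

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._post_hooks.append(hook)

    def forward(self, *args):
        raise NotImplementedError  # subclasses provide the actual computation

    def __call__(self, *args):
        for hook in self._pre_hooks:
            hook(self, args)
        output = self.forward(*args)  # forward is invoked here, for you
        for hook in self._post_hooks:
            hook(self, args, output)
        return output


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


log = []
m = Doubler()
m.register_forward_pre_hook(lambda mod, inputs: log.append(("pre", inputs)))
m.register_forward_hook(lambda mod, inputs, out: log.append(("post", out)))

result = m(21)  # same syntax as model(image)
print(result)   # 42
print(log)      # [('pre', (21,)), ('post', 42)]
```

Note that Doubler never defines __call__; writing m(21) still ends up in forward, with the hooks running around it.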
I have the same question: I am unable to understand why the forward function is invoked automatically.
Also, I couldn’t understand the answer posted here by @Sebgolos, as I am new to Python and even to programming. If anyone could post another answer, it would be very helpful.
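```python
class Example:
    # ... some important methods (placeholders below) ...
    def pre_important_stuff(self, *args):
        print("pre-processing")   # stands in for work done before forward

    def post_important_stuff(self, *args):
        print("post-processing")  # stands in for work done after forward

    def forward(self):
        print("hehe, forward is invoked")

    def __call__(self, *args):
        self.pre_important_stuff(*args)
        self.forward()  # run forward, no need to do anything like obj.forward()
        self.post_important_stuff(*args)


obj = Example()
obj.__call__()  # invoke the __call__ method explicitly
obj()           # shortcut for the above (invokes the __call__ method too)
```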
forward is called automatically by the __call__ method, which in turn can be invoked with a shortcut: obj() instead of obj.__call__().
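A small sketch of that shortcut in plain Python (the Greeter class is just an illustration): when you write g(...), Python looks up __call__ on the object’s class and runs it, so all three calls below do the same thing.

```python
class Greeter:
    def __call__(self, name):
        return f"hello, {name}"


g = Greeter()
a = g("Ada")                    # the shortcut
b = g.__call__("Ada")           # explicit method call
c = type(g).__call__(g, "Ada")  # roughly what Python does under the hood
print(a == b == c)              # True
```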
Why is it done like this?
Well, if you look at my example, you can see that the __call__ method also calls some other methods besides forward. If you called forward directly, you would skip these “important” functions as well.
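Here is a minimal sketch of that difference (the Counted class and its counter are hypothetical, just to make the skipped step visible). The PyTorch documentation for Module.forward makes the same point: you should call the module instance rather than forward, because the instance call runs the registered hooks while calling forward directly silently ignores them.

```python
class Counted:
    def __init__(self):
        self.calls_to_extra_steps = 0

    def important_stuff(self):
        # stands in for the hooks/bookkeeping that __call__ performs
        self.calls_to_extra_steps += 1

    def forward(self, x):
        return x + 1

    def __call__(self, x):
        self.important_stuff()
        return self.forward(x)


m = Counted()
m(1)          # goes through __call__: the extra step runs
m.forward(1)  # direct call: the extra step is skipped
print(m.calls_to_extra_steps)  # 1
```

Both calls return the same value here; the only difference is whether the extra work ran.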
OK, I understand this now, @Sebgolos.
Just one more question, to be clear: does this __call__ method call every function that has been defined in the class?
If so, does this mean that every function in that class gets executed, even the ones we don’t need?
For instance, if we execute out = self(images), is it going to execute every function defined in the class, such as self.epoch_end(), along with the others?
No, it only calls whatever is coded up inside the __call__ method.
It’s really just like any other method; the only difference is that it can be invoked with a shortcut.
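A quick sketch to check this (the Trainer class and its ran list are hypothetical; epoch_end is borrowed from the question): each method records itself when it runs, and only the method that __call__ actually calls shows up after t(3).

```python
class Trainer:
    def __init__(self):
        self.ran = []  # records which methods actually executed

    def forward(self, x):
        self.ran.append("forward")
        return x

    def epoch_end(self):
        self.ran.append("epoch_end")  # never reached via __call__

    def __call__(self, x):
        return self.forward(x)  # __call__ only calls forward


t = Trainer()
t(3)
print(t.ran)   # ['forward']
t.epoch_end()  # other methods run only when you call them yourself
print(t.ran)   # ['forward', 'epoch_end']
```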
OK, I got this. One more thing, though, @Sebgolos:
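```python
import torch.nn as nn

input_size = 28 * 28  # each MNIST image is 28x28 pixels, i.e. 784 values
num_classes = 10      # digits 0-9


class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)  # flatten each image into a 784-vector
        out = self.linear(xb)
        return out


model = MnistModel()
```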
In the above code, we are not coding the forward function inside the __call__ method, or are we?
Right, you are not, but you inherit from nn.Module, which has this method defined (among many others, BTW).
By inheriting, you also inherit all of its methods (that’s why you can use model.parameters() without defining that method in your model). Making a new model then becomes easy, because you only have to take care of methods like forward; the rest is provided by the nn.Module class.
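A minimal pure-Python sketch of that inheritance (BaseModule and LinearModel are hypothetical, and the parameters rule below is a toy assumption, not how nn.Module finds parameters): the subclass defines only its state and forward, yet both model(...) and model.parameters() work because they come from the base class.

```python
class BaseModule:
    """Hypothetical, simplified stand-in for nn.Module."""

    def __call__(self, *args):
        # defined once in the base class, inherited by every model
        return self.forward(*args)

    def parameters(self):
        # toy rule (an assumption for this sketch): every float attribute
        # set on the instance counts as a parameter
        return [v for v in vars(self).values() if isinstance(v, float)]


class LinearModel(BaseModule):
    # the subclass only defines its state and forward
    def __init__(self):
        self.weight = 3.0
        self.bias = 1.0

    def forward(self, x):
        return self.weight * x + self.bias


model = LinearModel()
print(model(2.0))          # 7.0, via the inherited __call__
print(model.parameters())  # [3.0, 1.0], also inherited
```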
OK, thanks. This was very helpful, and so were the links in your first answer.