What loss.detach() means

What does loss.detach() mean, compared to plain loss, in the two versions of validation_step below?
```python
def validation_step(self, batch):
    images, labels = batch
    out = self(images)                    # forward pass: generate predictions
    loss = F.cross_entropy(out, labels)   # compare predictions with targets
    acc = accuracy(out, labels)           # accuracy helper defined elsewhere
    return {'val_loss': loss.detach(), 'val_acc': acc}
```

```python
def validation_step(self, batch):
    images, labels = batch
    out = self(images)                    # forward pass: generate predictions
    loss = F.cross_entropy(out, labels)   # compare predictions with targets
    acc = accuracy(out, labels)           # accuracy helper defined elsewhere
    return {'val_loss': loss, 'val_acc': acc}
```

Well, the loss measures how far the model's predictions are from the targets.

In the line `loss = F.cross_entropy(out, labels)` above, the loss is computed from the predicted outputs (out) and the target labels using the cross_entropy function.
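
As a quick illustration, here is a minimal sketch with made-up logits and labels (not the model's actual outputs):

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 4 samples, 3 classes.
out = torch.randn(4, 3)              # raw logits, as a model would produce
labels = torch.tensor([0, 2, 1, 0])  # target class indices

loss = F.cross_entropy(out, labels)  # a single scalar tensor
print(loss)                          # value varies with the random logits
```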

For more information, refer to the notebook:

Now, for the loss.detach() method:

In PyTorch, the loss tensor is attached to the computation graph that autograd builds for computing gradients while training the model. When we only want to record or display the loss, as in validation, we don't need those gradients, so .detach() returns a tensor with the same value that is cut off from the graph. This also stops the stored value from keeping the whole graph alive in memory.
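
Here is a minimal sketch of what .detach() does; the tiny matmul "model" is just an assumption so that the loss actually carries a graph:

```python
import torch
import torch.nn.functional as F

w = torch.randn(3, 3, requires_grad=True)  # stand-in for model weights
out = torch.randn(4, 3) @ w                # predictions that depend on w
labels = torch.tensor([0, 2, 1, 0])

loss = F.cross_entropy(out, labels)
print(loss.requires_grad, loss.grad_fn is None)      # True False

logged = loss.detach()
print(logged.requires_grad, logged.grad_fn is None)  # False True
# logged has the same value as loss but is cut off from the graph,
# so storing it per batch does not keep the whole graph in memory.
```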

For more info on .detach() and the with torch.no_grad() context manager, you can visit:
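
For comparison, a short sketch of with torch.no_grad(): inside the block, autograd builds no graph at all, so nothing needs detaching afterwards:

```python
import torch

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(4, 3)

with torch.no_grad():
    out = x @ w
print(out.requires_grad)  # False: no graph was recorded
```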
