
Five tensor functions one should know

Lesson 01 - Assignment

PyTorch is a Python-based library for machine learning. torch.Tensor is its fundamental data structure for storing multi-dimensional data. A tensor is a multi-dimensional matrix whose elements all share a single data type. Tensors come with many built-in functions for basic operations, and five of them are covered in this notebook:

  • view
  • squeeze (unsqueeze)
  • masked_select
  • where
  • backward
# Import torch and other required modules
import torch
import numpy as np
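
The claim that a tensor holds elements of a single data type can be checked directly. This is a small sketch. It assumes the default floating-point dtype has not been changed with torch.set_default_dtype. When a list mixes Python ints and floats, torch.tensor promotes every element to one common dtype.

```python
import torch

# Mixed ints and floats are promoted to one floating-point dtype
t = torch.tensor([1, 2.5, 3])
print(t.dtype)            # torch.float32 (the default float dtype)

# All-integer input produces a 64-bit integer tensor
i = torch.tensor([[1, 2], [3, 4]])
print(i.dtype)            # torch.int64
print(i.dim(), i.shape)   # 2 torch.Size([2, 2])
```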

Function 1 - view()

view() returns a tensor with a different shape but the same data and the same number of elements as the original. It does not copy or modify the underlying data. The returned tensor shares storage with the original, so a change made through one is visible through the other.

view(*shape) → Tensor

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
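
The sketch below shows that view() shares storage instead of copying. A write made through the view appears in the original tensor, and both tensors report the same data pointer. The variable names are only for illustration.

```python
import torch

a = torch.zeros(2, 3)
b = a.view(6)            # same 6 elements, new shape; no copy is made

b[0] = 42.0              # write through the view...
print(a[0, 0])           # ...and the original sees it: tensor(42.)

# Both tensors start at the same memory address
print(a.data_ptr() == b.data_ptr())  # True
```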

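# Example 1 - working
a = torch.rand(3,5)    # 3x5 tensor of random values (15 elements)
print(a.size())
b = a.view(15)         # flatten to a 1D tensor of 15 elements
print(b.size())
c = a.view(-1,3)       # -1 is inferred: 15 / 3 = 5 rows
print(c.size())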
torch.Size([3, 5])
torch.Size([15])
torch.Size([5, 3])

b is created by a.view(15), which reshapes the 2D tensor a into a 1D tensor of 15 elements. c is created by a.view(-1, 3): the second dimension is given as 3, and -1 tells PyTorch to infer the first dimension from the total number of elements, here 15 / 3 = 5.
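
The following sketch explores the -1 inference in more detail. It shows that the shape can also be passed as a single tuple. It also shows that view() raises a RuntimeError when the known dimensions do not divide the element count evenly. The exact error message may vary between PyTorch versions, so the sketch only prints it.

```python
import torch

a = torch.rand(3, 5)            # 15 elements in total

# -1 asks view() to infer that dimension: 15 / 3 = 5
c = a.view(-1, 3)
print(c.shape)                  # torch.Size([5, 3])

# The shape may also be passed as one tuple
d = a.view((5, 3))
print(torch.equal(c, d))        # True

# At most one dimension may be -1, and the known
# dimensions must divide the element count evenly
failed = False
try:
    a.view(-1, 4)               # 15 is not divisible by 4
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```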