Five tensor functions one should know
Lesson 01 - Assignment
PyTorch is a Python-based library for machine learning. Its fundamental data structure is torch.Tensor, a multi-dimensional matrix whose elements all share a single data type (tensors are usually created with the factory function torch.tensor). torch.Tensor comes with many built-in methods for basic operations; five of them are shown in this notebook:
- view
- squeeze (unsqueeze)
- masked_select
- where
- backward
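To make the "single data type" point concrete, here is a small sketch (the variable names are illustrative only). Creating a tensor from nested lists fixes one dtype for all elements, and mixing integers with floats promotes everything to PyTorch's default floating-point dtype, assumed here to be the standard float32 default. Methods such as sum() and t() are then called directly on the tensor object.

```python
import torch

# All elements of a tensor share one dtype
ints = torch.tensor([[1, 2], [3, 4]])
mixed = torch.tensor([1, 2.5])  # int and float mixed -> promoted to float

print(ints.dtype, ints.shape)  # torch.int64 torch.Size([2, 2])
print(mixed.dtype)             # torch.float32 (the default float dtype)

# Built-in methods are called directly on the tensor
print(ints.sum().item())       # 10
print(ints.t())                # transpose of the 2x2 tensor
```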
# Import torch and other required modules
import torch
import numpy as np
Function 1 - view()
view() changes the shape of a tensor. tensor.view() returns a new tensor with the same data and the same number of elements as the original; only the shape differs. view() neither copies nor modifies the underlying data: the returned tensor shares storage with the original, so an in-place change to one is visible in the other.
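The shared-storage behaviour can be checked with a short sketch (tensor sizes chosen arbitrarily). Writing through the view changes the original, both tensors report the same data_ptr(), and calling clone() is one way to get an independent copy when sharing is not wanted.

```python
import torch

a = torch.zeros(2, 3)
b = a.view(6)          # same 6 elements, new 1D shape

b[0] = 7.0             # write through the view
print(a[0, 0].item())  # 7.0 -> the original tensor sees the change
print(a.data_ptr() == b.data_ptr())  # True: both start at the same memory

# clone() copies the data, so changes no longer propagate
c = a.view(6).clone()
c[1] = -1.0
print(a[0, 1].item())  # 0.0 -> original unaffected
```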
view(*shape) → Tensor
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view
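The *shape in the signature means the target sizes can be passed as separate integers; in practice a tuple or a torch.Size is accepted as well. The sketch below assumes those three call forms are interchangeable and checks that they produce identical results.

```python
import torch

a = torch.rand(3, 5)

# Three equivalent ways to request the shape (5, 3)
v1 = a.view(5, 3)
v2 = a.view((5, 3))
v3 = a.view(torch.Size([5, 3]))

print(v1.shape == v2.shape == v3.shape)             # True
print(torch.equal(v1, v2) and torch.equal(v2, v3))  # True
```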
# Example 1 - working
a = torch.rand(3,5)
print(a.size())
b = a.view(15)
print(b.size())
c = a.view(-1,3)
print(c.size())
torch.Size([3, 5])
torch.Size([15])
torch.Size([5, 3])
b = a.view(15) reshapes the 2D tensor a into a 1D tensor of 15 elements. To create c, we pass -1 as the first dimension and 3 as the second; PyTorch infers the dimension marked -1 from the total number of elements, here 15 / 3 = 5.
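The -1 rule has limits, and view() also depends on memory layout. The sketch below uses a small hypothetical helper, fails(), to check which calls raise RuntimeError: the inferred size must divide evenly, only one dimension may be -1, and a non-contiguous tensor such as a transpose generally cannot be flattened with view(). In that last case reshape() (or contiguous().view()) is the usual workaround, at the cost of a possible copy.

```python
import torch

def fails(fn):
    """Hypothetical helper: return True if fn() raises RuntimeError."""
    try:
        fn()
    except RuntimeError:
        return True
    return False

a = torch.rand(3, 5)

# -1 is inferred as numel / product of the other sizes
print(a.view(-1, 3).shape)     # torch.Size([5, 3])
print(a.view(3, -1, 1).shape)  # torch.Size([3, 5, 1])

# 15 is not divisible by 4, so no whole-number size can be inferred
print(fails(lambda: a.view(-1, 4)))   # True

# Only one dimension may be -1
print(fails(lambda: a.view(-1, -1)))  # True

# A transposed tensor is not contiguous, so view(15) is rejected
t = a.t()
print(t.is_contiguous())              # False
print(fails(lambda: t.view(15)))      # True

# reshape() copies when it has to; contiguous().view() does the same explicitly
flat = t.reshape(15)
print(torch.equal(flat, t.contiguous().view(15)))  # True
```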