Learn practical skills, build real-world projects, and advance your career

Ajinkya Patil

Introduction to PyTorch tensor operations

PyTorch offers users a variety of optimization functions to handle tensors for deep learning utilizing GPU's and CPU's. I'll take you through some of the methods offered by the torch api.

  • Finding the max value and its index in a tensor
  • Splitting Tensors
  • Gathering elements along axes
  • Stacking Tensors
  • Applying Masks

Before we begin, let's install and import PyTorch

!pip install numpy torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (1.19.2) Requirement already satisfied: torch==1.7.0+cpu in /opt/conda/lib/python3.8/site-packages (1.7.0+cpu) Requirement already satisfied: torchvision==0.8.1+cpu in /opt/conda/lib/python3.8/site-packages (0.8.1+cpu) Requirement already satisfied: torchaudio==0.7.0 in /opt/conda/lib/python3.8/site-packages (0.7.0) Requirement already satisfied: future in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (0.18.2) Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (3.7.4.3) Requirement already satisfied: dataclasses in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (0.6) Requirement already satisfied: pillow>=4.1.1 in /opt/conda/lib/python3.8/site-packages (from torchvision==0.8.1+cpu) (8.0.0)
# Import torch and other required modules
import torch

Function 1 - Finding the max value and its index in a tensor

The torch function amax is used to get the maximum value of the tensor while the function argmax is used to get the index of the maximum value of the tensor.

Specifying

dim = 0 finds the maximum value in each column(along the rows), dim = 1 finds the maximum value in each row( along the columns), dim = [0,1] finds the overall maximum value,

(~ functions for minimum)

Lets look at a simple 1D tensor and find the max value and its index.

# Example 1 - working

tensor5 = torch.tensor([1., 2., 3., 4., 5.])

# rand_like returns a tensor that has the SAME SIZE as the input tensor with values distributed from [0,1) ) means last value is not inclusive
sample_tensor1 = torch.rand_like(tensor5)

## NUMEL  -> number of elements in tensor
tensor_len = torch.numel(sample_tensor1)

print("number of elements in tensor :",tensor_len)
print("sample tensor1 :",sample_tensor1)

#amax -> finds the maximum value of each slice of the tensor in the given dimension
row_max = torch.amax(input=sample_tensor1, dim=0)

max_index_row = torch.argmax(input=sample_tensor1, dim=0) # use argmax to find INDEX of max value
print("\nmax value of {val} is found at index {ind}".format(val=row_max, ind=max_index_row))
number of elements in tensor : 5 sample tensor1 : tensor([0.2545, 0.4109, 0.3918, 0.5876, 0.7228]) max value of 0.7227811813354492 is found at index 4