Learn practical skills, build real-world projects, and advance your career
Updated 3 years ago
Ajinkya Patil
Introduction to PyTorch tensor operations
PyTorch offers users a variety of optimization functions to handle tensors for deep learning utilizing GPU's and CPU's. I'll take you through some of the methods offered by the torch api.
- Finding the max value and its index in a tensor
- Splitting Tensors
- Gathering elements along axes
- Stacking Tensors
- Applying Masks
Before we begin, let's install and import PyTorch
!pip install numpy torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Requirement already satisfied: numpy in /opt/conda/lib/python3.8/site-packages (1.19.2)
Requirement already satisfied: torch==1.7.0+cpu in /opt/conda/lib/python3.8/site-packages (1.7.0+cpu)
Requirement already satisfied: torchvision==0.8.1+cpu in /opt/conda/lib/python3.8/site-packages (0.8.1+cpu)
Requirement already satisfied: torchaudio==0.7.0 in /opt/conda/lib/python3.8/site-packages (0.7.0)
Requirement already satisfied: future in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (0.18.2)
Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (3.7.4.3)
Requirement already satisfied: dataclasses in /opt/conda/lib/python3.8/site-packages (from torch==1.7.0+cpu) (0.6)
Requirement already satisfied: pillow>=4.1.1 in /opt/conda/lib/python3.8/site-packages (from torchvision==0.8.1+cpu) (8.0.0)
# Import torch and other required modules
import torch
Function 1 - Finding the max value and its index in a tensor
The torch function amax is used to get the maximum value of the tensor while the function argmax is used to get the index of the maximum value of the tensor.
Specifying
dim = 0 finds the maximum value in each column(along the rows), dim = 1 finds the maximum value in each row( along the columns), dim = [0,1] finds the overall maximum value,
(~ functions for minimum)
Lets look at a simple 1D tensor and find the max value and its index.
# Example 1 - working
tensor5 = torch.tensor([1., 2., 3., 4., 5.])
# rand_like returns a tensor that has the SAME SIZE as the input tensor with values distributed from [0,1) ) means last value is not inclusive
sample_tensor1 = torch.rand_like(tensor5)
## NUMEL -> number of elements in tensor
tensor_len = torch.numel(sample_tensor1)
print("number of elements in tensor :",tensor_len)
print("sample tensor1 :",sample_tensor1)
#amax -> finds the maximum value of each slice of the tensor in the given dimension
row_max = torch.amax(input=sample_tensor1, dim=0)
max_index_row = torch.argmax(input=sample_tensor1, dim=0) # use argmax to find INDEX of max value
print("\nmax value of {val} is found at index {ind}".format(val=row_max, ind=max_index_row))
number of elements in tensor : 5
sample tensor1 : tensor([0.2545, 0.4109, 0.3918, 0.5876, 0.7228])
max value of 0.7227811813354492 is found at index 4