
5 Basic PyTorch Functions

PyTorch is an open source machine learning library based on the Torch library, used for applications such as computer vision and natural language processing, primarily developed by Facebook's AI Research lab.

It's a Python-based scientific computing package aimed at two audiences:

  1. those who want a replacement for NumPy that can use the power of GPUs
  2. those who want a deep learning research platform that provides maximum flexibility and speed

The five functions covered here are:

  • Function 1 - torch.split
  • Function 2 - torch.arange
  • Function 3 - torch.randperm
  • Function 4 - torch.where
  • Function 5 - torch.remainder

Before we begin, let's install and import PyTorch.

# Windows, CPU-only build (uncomment the next line to install)
#!pip install numpy torch==1.7.0+cpu torchvision==0.8.1+cpu torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

# Import torch and other required modules
import torch

Function 1 - torch.split

torch.split splits a tensor into chunks. Each chunk is a view of the original tensor, so no data is copied.
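Because each chunk is a view, it shares memory with the original tensor, so an in-place change to a chunk also shows up in the source tensor. A minimal sketch; the tensor `x` is just an illustrative input:

```python
import torch

# An illustrative 1-D tensor with six elements: 0..5
x = torch.arange(6)

# Split into chunks of two elements each: [0, 1], [2, 3], [4, 5]
chunks = torch.split(x, 2)

# Modify the first element of the second chunk in place.
# No copy was made, so the change is visible in x as well.
chunks[1][0] = 100

print(x.tolist())  # [0, 1, 100, 3, 4, 5]
```

If you need independent chunks, call `.clone()` on each one before modifying it.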

If split_size_or_sections is an integer, the tensor is split into equally sized chunks (if possible). The last chunk is smaller if the size of the tensor along the given dimension dim is not divisible by split_size_or_sections.
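To see the integer case with a length that does not divide evenly, here is a small sketch using an illustrative 10-element tensor split into chunks of 3:

```python
import torch

# Ten elements split into chunks of size 3; 10 is not divisible by 3
t = torch.arange(10)
chunks = torch.split(t, 3)

# Three full chunks, then a smaller final chunk holding the remainder
sizes = [c.numel() for c in chunks]
print(sizes)                # [3, 3, 3, 1]
print(chunks[-1].tolist())  # [9]
```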

If split_size_or_sections is a list, the tensor is split into len(split_size_or_sections) chunks whose sizes along dim are given by the entries of split_size_or_sections. These sizes must add up to the size of the tensor along dim.
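A short sketch of the list form, again on an illustrative 10-element tensor. The section sizes `[2, 3, 5]` are an arbitrary choice; a list that does not sum to the dimension length is rejected with a `RuntimeError`:

```python
import torch

t = torch.arange(10)

# Explicit section sizes; they must add up to t.size(0), which is 10
parts = torch.split(t, [2, 3, 5])
print([p.tolist() for p in parts])  # [[0, 1], [2, 3, 4], [5, 6, 7, 8, 9]]

# Sizes that do not sum to the dimension length raise a RuntimeError
raised = False
try:
    torch.split(t, [2, 3])
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```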

Parameters:
  • tensor (Tensor) – tensor to split.

  • split_size_or_sections (int or list(int)) – size of a single chunk, or a list of sizes for each chunk.

  • dim (int) – dimension along which to split the tensor (default: 0).
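Since dim defaults to 0, a 2-D tensor is split by rows unless you pass another dimension. A minimal sketch on an illustrative 2 x 4 matrix, splitting by columns instead:

```python
import torch

# An illustrative 2 x 4 matrix: [[0, 1, 2, 3], [4, 5, 6, 7]]
m = torch.arange(8).reshape(2, 4)

# dim=1 splits along the columns: two chunks of two columns each
left, right = torch.split(m, 2, dim=1)
print(left.tolist())   # [[0, 1], [4, 5]]
print(right.tolist())  # [[2, 3], [6, 7]]

# The same operation is also available as the Tensor method m.split(2, dim=1)
left2, right2 = m.split(2, dim=1)
```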

# Example 1 - working
# A 3x3 tensor; the float literal 1. makes every element a float
a1 = torch.tensor([[1., 2, 5], [3, 4, 7], [2, 9, 6]])
a1
tensor([[1., 2., 5.],
        [3., 4., 7.],
        [2., 9., 6.]])
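As a sketch of how torch.split might be applied to a1, here are two possible calls. The split sizes are chosen purely for illustration: one splits the rows into chunks of 2, the other splits the columns into uneven sections.

```python
import torch

a1 = torch.tensor([[1., 2, 5], [3, 4, 7], [2, 9, 6]])

# Chunks of 2 rows along dim 0: 3 rows give a 2x3 chunk and a 1x3 chunk
row_chunks = torch.split(a1, 2)
print([tuple(c.shape) for c in row_chunks])  # [(2, 3), (1, 3)]

# Uneven column sections along dim 1: one column, then two columns
col_chunks = torch.split(a1, [1, 2], dim=1)
print([tuple(c.shape) for c in col_chunks])  # [(3, 1), (3, 2)]
print(col_chunks[0].tolist())                # [[1.0], [3.0], [2.0]]
```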