
Handy Functions to deal with tensors using PyTorch

Here, we'll discuss a few functions that operate on tensors and make programming with PyTorch more intuitive and easier.

About PyTorch:
PyTorch is a library for deep learning that runs on both CPUs and GPUs. It supports tensor computation and building deep neural networks. It is made up of several components, such as torch, torch.nn, torch.utils, torch.autograd, torch.jit and torch.multiprocessing, each used for a different purpose. PyTorch aims to give researchers maximum flexibility in how they use their computing resources when developing deep learning models.
PyTorch provides a data structure called Tensor, which can be stored on either the CPU or the GPU. A Tensor is an n-dimensional array. Unlike a Python tuple, it is mutable: its values can be changed in place. The biggest difference between a NumPy array and a PyTorch Tensor is that a PyTorch Tensor can run on either the CPU or the GPU. To run operations on the GPU, move the Tensor to a CUDA device, for example with .to("cuda") or .cuda().
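A minimal sketch of the two Tensor properties described above: a tensor's values can be changed in place, and the same tensor can be moved to a GPU. It assumes nothing about your hardware and simply falls back to the CPU when no CUDA GPU is available.

```python
import torch

# A tensor is created on the CPU by default
t = torch.tensor([1.0, 2.0, 3.0])
print(t.device)  # cpu

# Tensors are mutable: in-place methods (trailing underscore) and
# item assignment change the existing data instead of making a new tensor
t.add_(1.0)      # [2., 3., 4.]
t[0] = 10.0      # [10., 3., 4.]
print(t.tolist())  # [10.0, 3.0, 4.0]

# Move the tensor to the GPU only if CUDA is available on this machine
device = "cuda" if torch.cuda.is_available() else "cpu"
t_dev = t.to(device)
print(t_dev.device)
```

Checking torch.cuda.is_available() first keeps the script runnable on machines without a GPU.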

A short introduction to the chosen functions:

  • function 1: torch.take(input, index)
    This function returns the elements of a tensor at the specified indices; the input is treated as if it were a one-dimensional tensor.
    It takes two parameters: input -> the tensor, and index -> a tensor holding the indices of interest.

  • function 2: torch.squeeze(input, dim=None, out=None)
    This function returns a tensor with all dimensions of size 1 removed from input, e.g. a tensor of shape 2*1*3 becomes 2*3. If dim is given, only that dimension is squeezed, and only if its size is 1.

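  • function 3: torch.bernoulli(input, *, generator=None, out=None)
    This function draws binary random numbers (0 or 1) from a Bernoulli distribution. Each element of input is the probability of drawing a 1 at that position, so its values must lie in [0, 1]; the output has the same shape as input.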

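  • function 4: torch.conj(input, out)
    This function returns the element-wise complex conjugate of a tensor. The conjugate of a complex number has the same real part and an imaginary part of opposite sign, e.g. the conjugate of 3 + 4j is 3 - 4j. It has the same magnitude, mirrored across the real axis.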

  • function 5: torch.erf(input)
    This function computes the Gauss error function, erf(x) = (2/sqrt(pi)) * ∫_0^x exp(-t^2) dt, for each element of the input tensor.

  • function 6: torch.lerp(start, end, weight, out)
    This function performs element-wise linear interpolation between start and end using weight: out = start + weight * (end - start).
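The functions listed above can be tried together in one short script. This is a minimal sketch with input values chosen only for illustration; torch.manual_seed just makes the random Bernoulli draw repeatable, and the complex example assumes a PyTorch version with complex tensor support.

```python
import torch

torch.manual_seed(0)  # make the random Bernoulli draw reproducible

# torch.squeeze: drop every dimension of size 1 (2*1*3 -> 2*3)
x = torch.zeros(2, 1, 3)
squeezed = torch.squeeze(x)
print(squeezed.shape)  # torch.Size([2, 3])

# torch.bernoulli: each entry of probs is the probability of drawing a 1
probs = torch.tensor([0.0, 0.5, 1.0])
draws = torch.bernoulli(probs)
print(draws)  # first entry is always 0., last entry is always 1.

# torch.conj: same real part, imaginary part with the opposite sign
z = torch.tensor([1 + 2j, 3 - 4j])
z_conj = torch.conj(z)
print(z_conj.real, z_conj.imag)  # real parts 1, 3; imaginary parts -2, 4

# torch.erf: Gauss error function, applied element-wise
e = torch.erf(torch.tensor([0.0, 1.0]))
print(e)  # erf(0) = 0, erf(1) is about 0.8427

# torch.lerp: start + weight * (end - start), here with a scalar weight
start = torch.tensor([0.0, 10.0])
end = torch.tensor([10.0, 20.0])
mid = torch.lerp(start, end, 0.5)
print(mid.tolist())  # [5.0, 15.0]
```

With weight 0.5 the interpolation lands exactly halfway between start and end; a weight of 0 returns start and a weight of 1 returns end.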

# Import torch and other required modules
import torch

Function 1 - torch.take()

This function takes a tensor and a tensor of indices as input and returns the elements at those indices, as shown below:

# Example 1 
t = torch.tensor([[1, 2], [3, 4.]])
torch.take(t,torch.tensor([0,2]))
tensor([1., 3.])

In the example above, the input is a 2*2 matrix with elements 1., 2., 3., 4., created with torch.tensor. Because torch.take treats the input as a one-dimensional tensor flattened in row-major order, i.e. [1., 2., 3., 4.], indices 0 and 2 select the elements 1. and 3. Hence the output is tensor([1., 3.]).
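To make the flattening step concrete, here is a small sketch that compares torch.take with ordinary indexing on the flattened tensor, reusing the matrix from Example 1. The second call also shows that the output takes the shape of the index tensor; the index values there are chosen only for illustration.

```python
import torch

t = torch.tensor([[1, 2], [3, 4.]])
idx = torch.tensor([0, 2])

# torch.take indexes into the row-major flattened view of the input
taken = torch.take(t, idx)
flat = t.flatten()   # [1., 2., 3., 4.]
manual = flat[idx]   # ordinary indexing on the 1-D tensor

print(torch.equal(taken, manual))  # True

# The output has the same shape as the index tensor
grid = torch.take(t, torch.tensor([[0, 3], [1, 2]]))
print(grid.tolist())  # [[1.0, 4.0], [2.0, 3.0]]
```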