
PyTorch Tensor Functions That Will Save Time

PyTorch is an open-source machine learning library for Python. A tensor is a scalar, vector, matrix, or n-dimensional data container, similar to NumPy's ndarray. Let's look at five PyTorch tensor functions that can save time. They are as follows:

  • torch.where(condition, x, y)
  • torch.unbind(input, dim=0)
  • torch.roll(input, shifts, dims=None)
  • torch.tril(input, diagonal=0, *, out=None)
  • Tensor.select(dim, index)
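For a quick feel of what each of the five functions does, here is a minimal sketch that calls every one of them on the same small, fixed 3x3 tensor, so the results are easy to check by hand. The tensor t and its values are illustrative choices, not part of the original article.

```python
import torch

# A small, fixed 3x3 integer tensor (values chosen only for illustration)
t = torch.arange(1, 10).reshape(3, 3)
# tensor([[1, 2, 3],
#         [4, 5, 6],
#         [7, 8, 9]])

# torch.where: keep elements greater than 4, replace the rest with 0
print(torch.where(t > 4, t, torch.zeros_like(t)))
# tensor([[0, 0, 0],
#         [0, 5, 6],
#         [7, 8, 9]])

# torch.unbind: split along dim 0 into a tuple of row tensors
rows = torch.unbind(t, dim=0)
print(rows)
# (tensor([1, 2, 3]), tensor([4, 5, 6]), tensor([7, 8, 9]))

# torch.roll: shift every row right by one position along dim 1, wrapping around
print(torch.roll(t, shifts=1, dims=1))
# tensor([[3, 1, 2],
#         [6, 4, 5],
#         [9, 7, 8]])

# torch.tril: keep the lower triangle (diagonal=0 keeps the main diagonal too)
print(torch.tril(t))
# tensor([[1, 0, 0],
#         [4, 5, 0],
#         [7, 8, 9]])

# Tensor.select: take index 1 along dim 1, i.e. the middle column
print(t.select(1, 1))
# tensor([2, 5, 8])
```

Note that select removes the chosen dimension, so the result is 1-D rather than a 3x1 column.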

Before we begin, make sure PyTorch is installed (for example, with pip install torch), then import it:

import torch

Function 1 - torch.where(condition, x, y)

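This function returns a tensor of elements selected from either x or y, depending on condition. The operation is defined element-wise as:

$$\text{out}_i = \begin{cases} x_i & \text{if } \text{condition}_i \\ y_i & \text{otherwise} \end{cases}$$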
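To see the element-wise rule in action, here is a minimal sketch with small hand-picked tensors. The names x, y, condition, and z are illustrative and simply mirror the symbols in the definition: wherever condition is True the output takes the value from x, otherwise from y.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = torch.tensor([10.0, 20.0, 30.0, 40.0])
# Boolean mask: True at positions 0 and 2
condition = torch.tensor([True, False, True, False])

# Positions 0 and 2 come from x, positions 1 and 3 come from y
out = torch.where(condition, x, y)
print(out)
# tensor([ 1., 20.,  3., 40.])

# A common pattern: build the mask from the data itself,
# here replacing every non-positive value with zero
z = torch.tensor([-1.5, 0.5, -0.25, 2.0])
clipped = torch.where(z > 0, z, torch.zeros_like(z))
print(clipped)
```

Building the mask from the data (z > 0) avoids writing an explicit Python loop over the elements.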

Example - 1

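# a1: 3x3 tensor of random values from a standard normal distribution (your values will differ)
a1 = torch.randn(3, 3)
# b1: 3x3 tensor filled with ones, stored as float32
b1 = torch.ones([3, 3], dtype=torch.float)
print('Tensor a1:\n', a1)
print('Tensor b1:\n', b1)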
Tensor a1:
 tensor([[-0.4091,  0.6474, -0.5680],
        [-0.3016,  1.7412,  0.6511],
        [ 0.3040,  0.3834, -0.4173]])
Tensor b1:
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])