Updated 3 years ago
PyTorch Tensor Functions That Will Save Time
PyTorch is an open-source machine learning library for Python. A tensor is a scalar, vector, matrix, or n-dimensional data container, similar to NumPy’s ndarray. Let's look at 5 PyTorch tensor functions. They are as follows:
- torch.where(condition, x, y)
- torch.unbind(input, dim=0)
- torch.roll(input, shifts, dims=None)
- torch.tril(input, diagonal=0, *, out=None)
- select(dim, index)
Before we begin, let's install and import PyTorch
import torch
Function 1 - torch.where(condition, x, y)
This returns a tensor of elements selected from either x or y, depending on condition. The operation is defined element-wise as:
out_i = x_i if condition_i is true, otherwise y_i
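As a minimal sketch of the element-wise selection described above, the snippet below uses small hand-picked tensors so the result is easy to check by eye. The values and the names x, y, and result are made up for illustration and are not part of the original walkthrough.

```python
import torch

# Hypothetical values chosen for illustration; they are not from the article.
x = torch.tensor([1.0, -2.0, 3.0, -4.0])
y = torch.zeros(4)

# Where the condition (x > 0) is True, take the element from x;
# otherwise take the element at the same position from y.
result = torch.where(x > 0, x, y)
print(result)  # tensor([1., 0., 3., 0.])
```

The condition, x, and y must be broadcastable to a common shape, which is why same-shaped tensors are used here.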
Example - 1
a1 = torch.randn(3, 3)
b1 = torch.ones([3, 3], dtype=torch.float)
print('Tensor a1:\n', a1)
print('Tensor b1:\n', b1)
Tensor a1:
 tensor([[-0.4091,  0.6474, -0.5680],
        [-0.3016,  1.7412,  0.6511],
        [ 0.3040,  0.3834, -0.4173]])
Tensor b1:
 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])