Learn practical skills, build real-world projects, and advance your career
import torch

Merging Tensors: 5 functions you should be aware of

"PyTorch is an optimized tensor library for deep learning using GPUs and CPUs."

-from PyTorch Documentation

When it comes to machine learning, we have dealt with pandas DataFrame and ways to merge them. We could see those as merging of 2 dimensional data structures. But since tensors can be of 3 dimensions or even more, it is essential to know the ways in which they can be merged.

Here we try and look at 5 functions you should know of to understand different ways of merging tensors.

Function 1 - torch.cat

Function 2 - torch.stack

Function 3 - torch.unsqueeze

Function 4 - torch.hstack

Function 5 - torch.vstack

N.B:

Before we dive into merging, it's best to remember the following:
dimensions are numbered, starting from 0, similar to python indexing, But merging two 3 D tensors can be tricky and

Merging along

dimension 0 is like merging two tensors channel-wise visually

dimension 1 is like merging two tensors row-wise visually

dimension 2 is like merging two tensors column-wise visually

Hence a tensor with shape (2,3,4) would look like it has 2 channels, each containing 3 rows and 4 column (tensor t displayed below for reference) and a tensor with shape (4,1,3) looks like it has 4 channels, each containing a 1 row and 3 columns.

t =torch.tensor([[[4, 4, 4, 4],
         [7, 7, 7, 7],
         [7, 7, 7, 7]],

        [[4, 4, 4, 4],
         [7, 7, 7, 7],
         [7, 7, 7, 7]]])
print(t)
print(t.size())
tensor([[[4, 4, 4, 4], [7, 7, 7, 7], [7, 7, 7, 7]], [[4, 4, 4, 4], [7, 7, 7, 7], [7, 7, 7, 7]]]) torch.Size([2, 3, 4])

#Function 1 - torch.cat

*torch.cat(tensors, dim=0, , out=None) → Tensor

Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.

This is one of the most common ways in which two tensors can be merged.

Note: In order to make the size of tensors apparent, I'll be using torch.full to create tensors.