import torch
Merging Tensors: 5 functions you should be aware of
"PyTorch is an optimized tensor library for deep learning using GPUs and CPUs."
-from PyTorch Documentation
When it comes to machine learning, we have dealt with pandas DataFrame and ways to merge them. We could see those as merging of 2 dimensional data structures. But since tensors can be of 3 dimensions or even more, it is essential to know the ways in which they can be merged.
Here we try and look at 5 functions you should know of to understand different ways of merging tensors.
Function 1 - torch.cat
Function 2 - torch.stack
Function 3 - torch.unsqueeze
Function 4 - torch.hstack
Function 5 - torch.vstack
N.B:
Before we dive into merging, it's best to remember the following:
dimensions are numbered, starting from 0, similar to python indexing, But merging two 3 D tensors can be tricky and
Merging along
dimension 0 is like merging two tensors channel-wise visually
dimension 1 is like merging two tensors row-wise visually
dimension 2 is like merging two tensors column-wise visually
Hence a tensor with shape (2,3,4) would look like it has 2 channels, each containing 3 rows and 4 column (tensor t displayed below for reference) and a tensor with shape (4,1,3) looks like it has 4 channels, each containing a 1 row and 3 columns.
t =torch.tensor([[[4, 4, 4, 4],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[4, 4, 4, 4],
[7, 7, 7, 7],
[7, 7, 7, 7]]])
print(t)
print(t.size())
tensor([[[4, 4, 4, 4],
[7, 7, 7, 7],
[7, 7, 7, 7]],
[[4, 4, 4, 4],
[7, 7, 7, 7],
[7, 7, 7, 7]]])
torch.Size([2, 3, 4])
#Function 1 - torch.cat
*torch.cat(tensors, dim=0, , out=None) → Tensor
Concatenates the given sequence of seq tensors in the given dimension. All tensors must either have the same shape (except in the concatenating dimension) or be empty.
This is one of the most common ways in which two tensors can be merged.
Note: In order to make the size of tensors apparent, I'll be using torch.full to create tensors.