Use of torch.stack()

2024/11/17 9:44:08
import torch

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
t3 = torch.tensor([7, 8, 9])
torch.stack((t1, t2, t3), dim=1)

When using torch.stack(), I can't understand how the stacking is done for different values of dim. Here the tensors are stacked as columns, but I don't understand the details of how this works. It becomes even more complicated with 2D or 3D tensors.

The snippet above returns:

tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
Answer

Imagine you have n tensors of the same shape (torch.stack requires all inputs to have identical shapes). If we stay in 3D, these correspond to volumes, namely rectangular cuboids. Stacking combines those n volumes along an additional dimension: here a 4th dimension is added to hold the n 3D volumes. This is in clear contrast with concatenation, where the volumes are combined along one of the existing dimensions, so concatenating three-dimensional tensors results in a 3D tensor.
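
To make the contrast concrete, here is a minimal sketch comparing torch.stack with torch.cat on three 3D tensors. The shape (2, 3, 4) is an arbitrary choice for illustration; only the resulting shapes matter here.

```python
import torch

# Three 3D "volumes" of identical shape (sizes chosen arbitrarily).
vols = [torch.zeros(2, 3, 4) for _ in range(3)]

# stack adds a new (4th) dimension that indexes the three volumes.
stacked = torch.stack(vols, dim=0)
print(stacked.shape)       # torch.Size([3, 2, 3, 4])

# cat joins along an existing dimension, so the result stays 3D.
concatenated = torch.cat(vols, dim=0)
print(concatenated.shape)  # torch.Size([6, 3, 4])
```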

Here is a possible representation of the stacking operation for small dimension sizes (up to three-dimensional inputs):

[Image: stacking of inputs with up to three dimensions, with the new dimension added last]

Where you choose to perform the stacking (the dim argument) defines where the new dimension is inserted in the result. In the examples above, the newly created dimension is last, which is why the idea of an "added dimension" is easiest to see there.
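
One way to see what dim does: torch.stack(tensors, dim=d) gives the same result as inserting a size-1 axis at position d in every input (with unsqueeze) and then concatenating along that axis. The sketch below checks this equivalence on the three tensors from the question. It illustrates the behaviour only and is not a claim about how PyTorch implements stack internally.

```python
import torch

ts = (torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9]))

for d in (0, 1, -1):  # valid dims for stacking 1D inputs are -2..1
    via_stack = torch.stack(ts, dim=d)
    # Insert a size-1 axis at position d in every input, then join along it.
    via_cat = torch.cat([t.unsqueeze(d) for t in ts], dim=d)
    assert torch.equal(via_stack, via_cat)
    print(d, via_stack.tolist())
# 0 [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
# 1 [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
# -1 [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
```

Since dim=-1 always refers to the new last axis, it matches the "added dimension last" picture regardless of how many dimensions the inputs have.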

In the following visualizations, we observe how tensors can be stacked along different axes, which in turn affects the shape of the resulting tensor.

  • For the 1D case, for instance, stacking can also happen on the first axis, see below:

    [Image: stacking 1D inputs on the first axis]

    With code:

    >>> import torch
    >>> x_1d = list(torch.empty(3, 2))     # 3 1D tensors of shape (2,)
    >>> torch.stack(x_1d, 0).shape         # axis=0 stacking
    torch.Size([3, 2])
    >>> torch.stack(x_1d, 1).shape         # axis=1 stacking
    torch.Size([2, 3])
    
  • Similarly for two-dimensional inputs:

    [Image: stacking 2D inputs on each of the three possible axes]

    With code:

    >>> x_2d = list(torch.empty(3, 2, 2))   # 3 2x2 squares
    >>> torch.stack(x_2d, 0).shape          # axis=0 stacking
    torch.Size([3, 2, 2])
    >>> torch.stack(x_2d, 1).shape          # axis=1 stacking
    torch.Size([2, 3, 2])
    >>> torch.stack(x_2d, 2).shape          # axis=2 stacking
    torch.Size([2, 2, 3])
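
Shapes aside, the values can be read off directly: selecting index j along the new dimension gives back input j. For the question's example with dim=1, that means column j of the result is the j-th input tensor. The sketch below checks this for the question's 1D tensors and for three 2x2 tensors filled with distinct values (chosen arbitrarily so the inputs are distinguishable).

```python
import torch

# The three 1D tensors from the question, stacked on dim=1.
ts = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
cols = torch.stack(ts, dim=1)
print(cols.tolist())  # [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
for j, t in enumerate(ts):
    assert torch.equal(cols[:, j], t)  # input j is column j

# Same rule for 2D inputs: three 2x2 tensors with distinct values.
sq = [torch.arange(4).reshape(2, 2) + 10 * k for k in range(3)]
for d in range(3):  # 2D inputs can be stacked on dim 0, 1 or 2
    out = torch.stack(sq, dim=d)
    for j, s in enumerate(sq):
        # Selecting index j along the new dimension d recovers input j.
        assert torch.equal(out.select(d, j), s)
```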
    

With this in mind, you can intuitively extend the operation to n-dimensional tensors.
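
As a sketch of that generalization: stacking k tensors of shape (d_0, ..., d_{n-1}) along dim gives a result of shape (d_0, ..., d_{dim-1}, k, d_dim, ..., d_{n-1}). The helper stacked_shape below is a hypothetical name introduced only for this illustration; it predicts the output shape and is checked against torch.stack for every valid dim, including negative ones, on 3D inputs.

```python
import torch

def stacked_shape(shape, k, dim):
    """Hypothetical helper: predicted shape of torch.stack on k tensors of `shape`."""
    n = len(shape)
    if dim < 0:
        dim += n + 1  # the output has n + 1 dims, so negative dims count from there
    if not 0 <= dim <= n:
        raise IndexError(f"dim out of range for {n}D inputs")
    # Insert k at position dim; the input dimensions keep their order.
    return tuple(shape[:dim]) + (k,) + tuple(shape[dim:])

x_3d = list(torch.empty(3, 2, 4, 5))  # 3 tensors of shape (2, 4, 5)
for d in range(-4, 4):  # valid dims for 3D inputs are -4..3
    actual = torch.stack(x_3d, d).shape
    assert actual == stacked_shape((2, 4, 5), len(x_3d), d)
    print(d, tuple(actual))
```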

https://en.xdnf.cn/q/71228.html