Setting results of torch.gather(...) calls

2024/9/19 10:11:30

I have a 2D PyTorch tensor of shape n by m. I want to index the second dimension using a list of indices (which could be done with torch.gather) and then also set new values at the indexed positions.

Example:

data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

I want to select the specified indices per row (which would be [1,5,7]) but then also set these values to another number, e.g. 42.

I can select the desired columns row wise by doing:

data.gather(1, indices)
# tensor([[1],
#         [5],
#         [7]])
data.gather(1, indices)[:] = 42  # **This does NOT work**, since the result of gather
                                 # does not use the same storage as the original tensor

which is fine, but I would like to change these values now, and have the change also affect the data tensor.

I can do what I want to achieve using this, but it seems to be very un-pythonic:

max_index = torch.max(indices)
for i in range(0, max_index + 1):
    mask = (indices == i).nonzero(as_tuple=True)[0]
    data[mask, i] = 42
print(data)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])

Any hints on how to do that more elegantly?

Answer

What you are looking for is Tensor.scatter_ (the in-place variant of torch.scatter) with the value option.

Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

With 2D tensors as input and dim=1, the operation is:
self[i][index[i][j]] = src[i][j]
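
As a rough illustration of that rule, the sketch below compares scatter_ with src against an explicit double loop. The tensors target, index and src are hypothetical examples chosen for this sketch, not values from the question.

```python
import torch

# Hypothetical example tensors, chosen only to illustrate the dim=1 rule.
target = torch.zeros(2, 4, dtype=torch.long)
index = torch.tensor([[0, 3], [2, 1]])
src = torch.tensor([[10, 11], [12, 13]])

# Reference version of the rule self[i][index[i][j]] = src[i][j], using plain loops.
expected = target.clone()
for i in range(index.shape[0]):
    for j in range(index.shape[1]):
        expected[i, index[i, j]] = src[i, j]

# The same update performed by scatter_ in a single in-place call.
target.scatter_(1, index, src)

assert torch.equal(target, expected)
print(target)
# tensor([[10,  0,  0, 11],
#         [ 0, 13, 12,  0]])
```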

This excerpt does not mention the value parameter, though.


With value=42, and dim=1, this will have the following effect on data:

data[i][index[i][j]] = 42

Here applied in-place:

>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42,  2],
        [ 3,  4, 42],
        [ 6, 42,  8]])
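
For reference, here is a self-contained sketch that reproduces the result with the tensors from the question. It works on clones so the original data stays untouched. The second variant, using integer-array indexing with torch.arange, is not part of the answer above; it is shown only as an alternative that should give the same result when there is exactly one index per row.

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])             # shape (3, 3)
indices = torch.tensor([1, 2, 1], dtype=torch.long).unsqueeze(-1)  # shape (3, 1)

# Variant 1: in-place scatter_ with a scalar value, as in the answer.
scattered = data.clone()
scattered.scatter_(dim=1, index=indices, value=42)

# Variant 2 (an alternative, not from the answer): pair each row number
# with its column index and assign through integer-array indexing.
rows = torch.arange(data.shape[0])
indexed = data.clone()
indexed[rows, indices.squeeze(-1)] = 42

assert torch.equal(scattered, indexed)
print(scattered)
# tensor([[ 0, 42,  2],
#         [ 3,  4, 42],
#         [ 6, 42,  8]])
```

scatter_ also handles several indices per row (an index tensor of shape (n, k)), which the arange variant as written does not.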