Indexing a numpy array with a numpy array of indexes [duplicate]

2024/10/14 8:28:52

I have a 3D numpy array data and another array pos of indexes. Each index is itself a numpy array (one row of pos), which makes pos a 2D array:

import numpy as np

data = np.arange(8).reshape(2, 2, -1)
#array([[[0, 1],
#        [2, 3]],
#
#       [[4, 5],
#        [6, 7]]])
pos = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])
#array([[1, 1, 0],
#       [0, 1, 0],
#       [1, 0, 0]])

I want to select and/or mutate the elements from data using the indexes from pos. I can do the selection using a for loop or a list comprehension:

[data[tuple(i)] for i in pos]
#[6, 2, 4]
data[tuple(i for i in pos.T)]
#array([6, 2, 4])

But this does not seem like the NumPy way. Is there a vectorized NumPy solution to this problem?

Answer

You can split pos into three separate index arrays, one per axis, and index with them:

>>> i, j, k = pos.T
>>> data[i, j, k]
array([6, 2, 4])

Here, the number of columns in pos corresponds to the number of dimensions of data. As long as you're dealing with 3D arrays, getting i, j, and k will never get more complicated than this.
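The column-per-axis idea is not specific to 3D: an N-column pos simply yields N index arrays. The sketch below reuses data and pos from the question and also tries a hypothetical 4D array arr with random coordinates idx (both made up for illustration). It compares the column split against the loop version and against an alternative technique, np.ravel_multi_index, which converts each coordinate row into a flat C-order offset into the raveled array.

```python
import numpy as np

# Question's arrays: each row of pos is one (i, j, k) coordinate into data.
data = np.arange(8).reshape(2, 2, -1)
pos = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])

# Alternative route: convert each coordinate row to a flat (C-order) offset,
# then index the flattened array with those offsets.
flat = np.ravel_multi_index(tuple(pos.T), data.shape)
print(flat)                # [6 2 4]
print(data.ravel()[flat])  # [6 2 4]

# Same column-per-axis idea on a hypothetical 4-D array with random coordinates.
rng = np.random.default_rng(0)
arr = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx = np.stack([rng.integers(0, n, size=5) for n in arr.shape], axis=1)  # shape (5, 4)

by_columns = arr[tuple(idx.T)]                         # one index array per axis
by_loop = np.array([arr[tuple(row)] for row in idx])   # loop version, for comparison
by_flat = arr.ravel()[np.ravel_multi_index(tuple(idx.T), arr.shape)]
print(np.array_equal(by_columns, by_loop), np.array_equal(by_columns, by_flat))  # True True
```

The flat-offset route can be handy when the same positions are reused across arrays of identical shape, but for one-off selection the tuple of columns is the more direct spelling.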

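On Python 3.5+, you can shorten this by unpacking the columns into a tuple, which also works for any number of dimensions:

>>> data[(*pos.T,)]
array([6, 2, 4])
>>> data[tuple(pos.T)]  # equivalent
array([6, 2, 4])

The index must be a tuple, not a list. The older spelling data[[*pos.T]] relied on NumPy treating a list of arrays as a multidimensional index. That was deprecated in NumPy 1.15, and since NumPy 1.23 such a list is interpreted as a single array index along the first axis.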
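The question also asks about mutating the selected elements. Because the tuple of columns is an ordinary NumPy advanced index, it can also appear on the left-hand side of an assignment. Here is a minimal sketch using data and pos from the question. The extra arrays counts, buffered and dup are hypothetical, added only to show one caveat: with repeated coordinates, an in-place operator such as += touches each location once. np.add.at (unbuffered in-place addition) accumulates every occurrence.

```python
import numpy as np

data = np.arange(8).reshape(2, 2, -1)
pos = np.array([[1, 1, 0], [0, 1, 0], [1, 0, 0]])
idx = tuple(pos.T)  # one index array per axis

# Assignment through the same advanced index writes into data in place.
data[idx] = -1
print(data.ravel().tolist())  # [0, 1, -1, 3, -1, 5, -1, 7]

# In-place arithmetic works the same way.
data[idx] += 10
print(data[idx].tolist())  # [9, 9, 9]

# Hypothetical index rows where (0, 0, 0) appears twice.
dup = np.array([[0, 0, 0], [0, 0, 0], [1, 1, 1]])

# Buffered fancy-index +=: the duplicate location is incremented only once.
buffered = np.zeros_like(data)
buffered[tuple(dup.T)] += 1
print(buffered[0, 0, 0])  # 1

# np.add.at is unbuffered, so every occurrence is accumulated.
counts = np.zeros_like(data)
np.add.at(counts, tuple(dup.T), 1)
print(counts[0, 0, 0], counts[1, 1, 1])  # 2 1
```

If the rows of pos are known to be unique, plain assignment or += is enough. np.add.at only matters when the same coordinate can appear more than once.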
https://en.xdnf.cn/q/69434.html
