How to determine a numpy-array reshape strategy

2024/10/6 10:31:22

For a Python project I often find myself reshaping and rearranging n-dimensional numpy arrays. However, I have a hard time determining how to approach the problem, visualizing the outcome of the reshaping methods, and knowing whether my solution is efficient.

At the moment, when confronted with such a problem, my strategy is to start ipython, load some sample data and go by trial and error until I find a combination of transpose()s, reshape()s and swapaxes()s that gets the desired result. It gets the job done, but without a real understanding of what is going on, and it often produces code that is hard to maintain.

So, my question is about finding a strategy. How do you approach such a problem? How do you visualize an ndarray in your head when you have to get it into the desired shape? How do you arrive at the right sequence of operations?

To make answering a bit more concrete, an example to play with:

Assume you want to reshape the following 3d-array

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])

into a 2d-array where the first columns from the 3rd dimension are placed first, the 2nd columns second, and so on.

The result should look like this:

array([[ 0,  9, 18,  3, 12, 21,  6, 15, 24],
       [ 1, 10, 19,  4, 13, 22,  7, 16, 25],
       [ 2, 11, 20,  5, 14, 23,  8, 17, 26]])

PS: Any reading material on the subject would also be great!

Answer

I regularly play around with shapes in ipython. However, to make things clearer, I start with an array whose dimensions are all distinct.

import numpy as np
arr = np.arange(3*4*5).reshape(3,4,5)

That way, it's easier to identify how the axes get shifted, for example:

In [25]: arr.shape
Out[25]: (3, 4, 5)

In [26]: arr.T.shape
Out[26]: (5, 4, 3)

In [31]: arr.T.reshape(5,-1)
Out[31]:
array([[ 0, 20, 40,  5, 25, 45, 10, 30, 50, 15, 35, 55],
       [ 1, 21, 41,  6, 26, 46, 11, 31, 51, 16, 36, 56],
       [ 2, 22, 42,  7, 27, 47, 12, 32, 52, 17, 37, 57],
       [ 3, 23, 43,  8, 28, 48, 13, 33, 53, 18, 38, 58],
       [ 4, 24, 44,  9, 29, 49, 14, 34, 54, 19, 39, 59]])

whereas a different transpose (one that keeps the 3 and 4 axes in their original order) gives:

In [38]: np.transpose(arr,[2,0,1]).shape
Out[38]: (5, 3, 4)

In [39]: np.transpose(arr,[2,0,1]).reshape(5,-1)
Out[39]:
array([[ 0,  5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55],
       [ 1,  6, 11, 16, 21, 26, 31, 36, 41, 46, 51, 56],
       [ 2,  7, 12, 17, 22, 27, 32, 37, 42, 47, 52, 57],
       [ 3,  8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58],
       [ 4,  9, 14, 19, 24, 29, 34, 39, 44, 49, 54, 59]])
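Applying the same reasoning to the example from the question gives a one-liner. The sketch below assumes the sample array is np.arange(27).reshape(3, 3, 3), which matches the values 0 to 26 shown in the question. Reading the desired output, each row corresponds to one index of the last axis, and within a row the middle axis varies slowly and the first axis quickly, so the axis order needed is (last, middle, first), i.e. a full reversal.

```python
import numpy as np

# The sample array from the question: arr[a, b, c] == 9*a + 3*b + c
arr = np.arange(27).reshape(3, 3, 3)

# Wanted axis order is (c, b, a): one row per c, with b slow and a fast
# inside each row. arr.T reverses all axes, which is exactly (c, b, a),
# and reshape then flattens the trailing (b, a) pair in C order.
result = arr.T.reshape(3, -1)  # same as np.transpose(arr, (2, 1, 0)).reshape(3, -1)

print(result)
```

Writing the transpose with an explicit axis tuple, as in the comment, documents the intended order better than .T when the array has more than two dimensions.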

I also like to use 'oddly' shaped arrays like this when developing functions. That way, if I do mess up some transpose or broadcasting, dimension errors will jump out at me. Experience tells me that once I get the dimensions right, the values will also be correct. Or at least the class of errors that affect values is quite different from the class that affects dimensions.
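A minimal sketch of the kind of mistake a square test array hides and an oddly shaped one exposes. The row-centering functions here are hypothetical examples, not from the original answer: the buggy version broadcasts the row means against the columns, which a 3x3 array accepts silently but a 3x4 array rejects with a ValueError.

```python
import numpy as np

def center_rows_buggy(x):
    # Hypothetical helper. Intended: subtract each row's mean from that row.
    # Bug: x.mean(axis=1) has shape (rows,), which broadcasts against the
    # last axis (columns), not against the rows.
    return x - x.mean(axis=1)

def center_rows(x):
    # Fix: keepdims=True gives the means shape (rows, 1), so they broadcast
    # along each row as intended.
    return x - x.mean(axis=1, keepdims=True)

square = np.arange(9.0).reshape(3, 3)
odd = np.arange(12.0).reshape(3, 4)

# With a square array the bug runs without complaint and returns wrong values.
print(center_rows_buggy(square))

# With distinct dimensions the same bug fails immediately.
try:
    center_rows_buggy(odd)
except ValueError as err:
    print("caught:", err)

print(center_rows(odd))
```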

I also liberally sprinkle development code with print(arr.shape)-style statements, or even assert x.shape == y.shape assertions.
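As an illustration of that habit, here is a sketch with a hypothetical pairwise_diff function that checks its input and output shapes with plain assert statements. The expected shape (M, N, D) is written down once in a comment and then enforced, so a wrong broadcast fails at the line that caused it.

```python
import numpy as np

def pairwise_diff(a, b):
    """Hypothetical example: all differences a[i] - b[j] for two point sets."""
    # a is (M, D), b is (N, D); the result should be (M, N, D).
    assert a.ndim == 2 and b.ndim == 2, (a.shape, b.shape)
    assert a.shape[1] == b.shape[1], (a.shape, b.shape)
    out = a[:, None, :] - b[None, :, :]  # (M, 1, D) - (1, N, D) -> (M, N, D)
    assert out.shape == (a.shape[0], b.shape[0], a.shape[1]), out.shape
    return out

d = pairwise_diff(np.zeros((3, 2)), np.ones((4, 2)))
print(d.shape)  # (3, 4, 2)
```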

It also helps to label dimensions:

M, N, L = 3, 4, 5
np.empty((M,N,L))

or, as in einsum:

np.einsum('ijk,kj->i', A, B) # if A is (M,N,L), B must be (L,N)
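A small sketch combining both ideas: named, distinct sizes plus the einsum subscripts from above. The arange test data is an assumption chosen only for illustration. The einsum result is checked against the same computation written with an explicit transpose and sum, which is a cheap way to confirm you read the subscripts correctly.

```python
import numpy as np

# Distinct, named sizes make each einsum subscript easy to match to an axis.
M, N, L = 3, 4, 5
A = np.arange(M * N * L).reshape(M, N, L)  # 'ijk' -> (M, N, L)
B = np.arange(L * N).reshape(L, N)         # 'kj'  -> (L, N)

# 'ijk,kj->i': multiply A[i, j, k] * B[k, j] and sum over j and k.
out = np.einsum('ijk,kj->i', A, B)

# The same thing spelled out: B.T is (N, L), which broadcasts against the
# trailing (N, L) axes of A; then sum those two axes away.
expected = (A * B.T).sum(axis=(1, 2))

assert out.shape == (M,)
assert np.array_equal(out, expected)
```

Passing B.T (shape (N, L)) by mistake would make einsum raise an error, because the sizes no longer match the subscripts; with a square test array that mistake could go unnoticed.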

https://stackoverflow.com/a/29903842/901925 is an example of trying to understand and explain rollaxis.
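For comparison, a sketch showing that rollaxis, the newer moveaxis, and an explicit transpose all produce the (5, 3, 4) arrangement used in the ipython session above. The NumPy documentation recommends moveaxis over rollaxis, which it keeps mainly for backward compatibility.

```python
import numpy as np

arr = np.arange(3 * 4 * 5).reshape(3, 4, 5)

# Three ways to bring the last axis to the front while keeping 3 before 4:
a = np.transpose(arr, (2, 0, 1))
b = np.rollaxis(arr, 2)        # roll axis 2 backwards to position 0
c = np.moveaxis(arr, 2, 0)     # move axis 2 to position 0; others keep order

print(a.shape, b.shape, c.shape)  # (5, 3, 4) (5, 3, 4) (5, 3, 4)
assert np.array_equal(a, b) and np.array_equal(a, c)
```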

Another strategy is to look at the Python code of numpy functions. Many of them accept axis arguments, and it's instructive to see how they use those. Sometimes that particular axis is rotated to the front, or to the end. Sometimes an nd array is reshaped into a 2d array, collapsing all axes except one into a single axis. Others achieve generality by constructing and manipulating an indexing tuple. More advanced functions play with the strides as well as the shape.
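Two of these patterns can be sketched in a few lines. Both helpers below are hypothetical illustrations, not numpy's actual source: apply_along moves the target axis to the end and collapses everything else into one axis, and first_along builds an indexing tuple so the same code works for any axis. The results are checked against np.apply_along_axis and plain indexing.

```python
import numpy as np

def apply_along(func, arr, axis):
    """Hypothetical helper: apply a 1-d -> 1-d func along one axis.

    Sketch of the 'move the axis to the end, collapse the rest' pattern.
    func must return an array of the same length as its input.
    """
    moved = np.moveaxis(arr, axis, -1)            # target axis is now last
    flat = moved.reshape(-1, moved.shape[-1])     # 2-d: (all other axes, axis)
    out = np.stack([func(row) for row in flat])   # apply row by row
    out = out.reshape(moved.shape)                # restore the collapsed axes
    return np.moveaxis(out, -1, axis)             # put the axis back

def first_along(arr, axis):
    """Hypothetical helper: arr[..., 0, ...] with the 0 placed on `axis`."""
    index = [slice(None)] * arr.ndim  # a ':' for every axis
    index[axis] = 0                   # then select element 0 on the chosen axis
    return arr[tuple(index)]

arr = np.arange(3 * 4 * 5).reshape(3, 4, 5)
res = apply_along(np.cumsum, arr, axis=1)
assert np.array_equal(res, np.cumsum(arr, axis=1))
assert np.array_equal(first_along(arr, 2), arr[:, :, 0])
```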

Whether a dimension should be first or last is usually an optimization issue, and may involve tradeoffs between ease of use (broadcasting, indexing) and speed. Just keep in mind that in "C" order, the last dimension forms contiguous blocks.
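The contiguity point can be checked directly through an array's strides. A sketch, assuming an int64 array (8 bytes per element) so the stride values are predictable: a transpose only permutes the strides and shares memory with the original, but reshaping the resulting non-contiguous view has to copy the data.

```python
import numpy as np

arr = np.arange(3 * 4 * 5, dtype=np.int64).reshape(3, 4, 5)

# In C order the last axis is contiguous: one step along it moves 8 bytes
# (one int64), while one step along the first axis skips a whole 4x5 block.
print(arr.strides)                   # (160, 40, 8)
print(arr.flags['C_CONTIGUOUS'])     # True

# A transpose just reverses the strides; no data is moved.
t = arr.T
print(t.strides)                     # (8, 40, 160)
print(t.flags['C_CONTIGUOUS'])       # False
print(np.shares_memory(arr, t))      # True

# Flattening the last two axes of this view cannot be expressed with
# strides alone, so reshape returns a copy.
r = t.reshape(5, -1)
print(np.shares_memory(arr, r))      # False
```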

