Find diagonal sums in numpy (faster)

2024/10/7 8:22:33

I have some board numpy arrays like this:

array([[0, 0, 0, 1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 1],
       [0, 1, 0, 1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 1, 0],
       [1, 0, 0, 0, 0, 1, 0, 0]])

I'm using the following code to find the sum of the elements on each diagonal of the board, for offsets -7 to 7, and the same for its mirrored (row-reversed) copy, board[::-1]:

n = 8
rate = [b.diagonal(i).sum() for b in (board, board[::-1]) for i in range(-n+1, n)]
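For reference, here is the comprehension above as a self-contained script using the board from the question. The function name `diagonal_sums_loop` is a hypothetical label added only for illustration.

```python
import numpy as np

# The 8x8 board from the question.
board = np.array([[0, 0, 0, 1, 0, 0, 0, 0],
                  [1, 0, 0, 0, 0, 1, 0, 1],
                  [0, 0, 0, 0, 0, 0, 0, 1],
                  [0, 1, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0, 0, 0, 1],
                  [0, 0, 0, 0, 1, 0, 0, 0],
                  [0, 0, 1, 0, 0, 0, 1, 0],
                  [1, 0, 0, 0, 0, 1, 0, 0]])


def diagonal_sums_loop(board):
    """Sum every diagonal of `board` and of its row-reversed copy.

    Offsets run from -(n-1) to n-1, so the result has 2 * (2n - 1) entries:
    first the diagonals of `board`, then the diagonals of board[::-1]
    (which are the anti-diagonals of `board`).
    """
    n = board.shape[0]
    return [b.diagonal(i).sum()
            for b in (board, board[::-1])
            for i in range(-n + 1, n)]


rate = diagonal_sums_loop(board)
print(len(rate))  # 30 = 2 boards x 15 offsets
```

Index 7 of `rate` is offset 0 of the original board, i.e. its main diagonal.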

After some profiling, I found that this operation takes about 2/3 of the overall running time, and it seems to be due to two factors:

  • The .diagonal method builds a new array instead of a view (it looks like numpy 1.7 will have a new .diag method to solve that)
  • The iteration is done in Python, inside the list comprehension
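For what it's worth, the ndarray.diagonal documentation states that since NumPy 1.9 the method returns a read-only view rather than a copy, so on a current release the copy itself is probably not the cost. A quick way to check this on your own installation (an illustration, not part of the original question):

```python
import numpy as np

board = np.arange(64).reshape(8, 8)
d = board.diagonal(1)

# On NumPy >= 1.9 this is a view into board's buffer, not a copy ...
print(np.shares_memory(board, d))  # True
# ... and the view is read-only, so assigning into it raises ValueError.
print(d.flags.writeable)           # False
```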

So, are there any methods to find these sums faster (possibly in numpy's C layer)?


After some more tests, I was able to cut the total time by 7.5x by caching this operation... Maybe I was looking at the wrong bottleneck?
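The question does not show how the caching was done. One possible approach, sketched here purely as an assumption, is to memoize the sums on the board's raw bytes with functools.lru_cache, since NumPy arrays are not hashable themselves. The names `cached_rate` and `rate_for` are hypothetical.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=None)
def cached_rate(board_bytes, n, dtype_str):
    """Compute the diagonal sums for a square board given as raw bytes.

    Arrays are unhashable, so the cache key is (bytes, side length, dtype).
    """
    board = np.frombuffer(board_bytes, dtype=dtype_str).reshape(n, n)
    return tuple(int(b.diagonal(i).sum())
                 for b in (board, board[::-1])
                 for i in range(-n + 1, n))


def rate_for(board):
    # tobytes() serializes in C order, so boards with equal contents give equal keys.
    return cached_rate(board.tobytes(), board.shape[0], board.dtype.str)


board = np.eye(8, dtype=np.int64)
first = rate_for(board)
second = rate_for(board.copy())  # same contents, so this is served from the cache
print(cached_rate.cache_info().hits)  # 1
```

This only pays off if the same boards really recur; serializing each board to bytes has its own cost.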


One more thing:

I just found the .trace method, which replaces the diagonal(i).sum() call, and... there wasn't much improvement in performance (about 2 to 4%).
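ndarray.trace accepts the same offset argument as diagonal, so the swap described above looks roughly like this. It only replaces the per-diagonal sum; the Python-level loop over 30 (board, offset) pairs remains, which would be consistent with the small gain reported.

```python
import numpy as np

rng = np.random.default_rng(0)
board = rng.integers(0, 2, size=(8, 8))
n = board.shape[0]

via_diagonal = [b.diagonal(i).sum() for b in (board, board[::-1]) for i in range(-n + 1, n)]
via_trace = [b.trace(i) for b in (board, board[::-1]) for i in range(-n + 1, n)]

# Both produce the same 30 sums; only the per-diagonal call changes.
print(via_diagonal == via_trace)  # True
```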

So the problem must be the comprehension. Any ideas?

Answer

There's a possible solution using stride_tricks. This is based in part on the plethora of information available in the answers to this question, but the problem is just different enough, I think, not to count as a duplicate. Here's the basic idea, applied to a square matrix -- see below for a function implementing the more general solution.

>>> cols = 8
>>> a = numpy.arange(cols * cols).reshape((cols, cols))
>>> fill = numpy.zeros((cols - 1) * cols, dtype='i8').reshape((cols - 1, cols))
>>> stacked = numpy.vstack((a, fill, a))
>>> major_stride, minor_stride = stacked.strides
>>> strides = major_stride, minor_stride * (cols + 1)
>>> shape = (cols * 2 - 1, cols)
>>> numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
array([[ 0,  9, 18, 27, 36, 45, 54, 63],
       [ 8, 17, 26, 35, 44, 53, 62,  0],
       [16, 25, 34, 43, 52, 61,  0,  0],
       [24, 33, 42, 51, 60,  0,  0,  0],
       [32, 41, 50, 59,  0,  0,  0,  0],
       [40, 49, 58,  0,  0,  0,  0,  0],
       [48, 57,  0,  0,  0,  0,  0,  0],
       [56,  0,  0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0,  0,  7],
       [ 0,  0,  0,  0,  0,  0,  6, 15],
       [ 0,  0,  0,  0,  0,  5, 14, 23],
       [ 0,  0,  0,  0,  4, 13, 22, 31],
       [ 0,  0,  0,  3, 12, 21, 30, 39],
       [ 0,  0,  2, 11, 20, 29, 38, 47],
       [ 0,  1, 10, 19, 28, 37, 46, 55]])
>>> diags = numpy.lib.stride_tricks.as_strided(stacked, shape, strides)
>>> diags.sum(axis=1)
array([252, 245, 231, 210, 182, 147, 105,  56,   7,  21,  42,  70, 105,
       147, 196])
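The rows of diags are not in the same order as the question's range(-n+1, n): working through the stride arithmetic, row r < n holds offset -r and row r >= n holds offset 2n-1-r (so row 8 is offset 7 and row 14 is offset 1). A small sketch, rebuilding the session above so it runs on its own, that permutes the sums into the question's order and checks them against trace:

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

cols = 8
a = np.arange(cols * cols).reshape((cols, cols))
fill = np.zeros((cols - 1, cols), dtype=a.dtype)
stacked = np.vstack((a, fill, a))
major_stride, minor_stride = stacked.strides
diags = as_strided(stacked, (cols * 2 - 1, cols), (major_stride, minor_stride * (cols + 1)))
sums = diags.sum(axis=1)

# Row index for each offset k in range(-cols + 1, cols):
#   k <= 0 -> row -k             (rows cols-1 ... 0)
#   k >  0 -> row 2*cols - 1 - k (rows 2*cols-2 ... cols)
order = np.concatenate((np.arange(cols - 1, -1, -1),
                        np.arange(2 * cols - 2, cols - 1, -1)))
in_question_order = sums[order]

expected = [a.trace(k) for k in range(-cols + 1, cols)]
print(np.array_equal(in_question_order, expected))  # True
```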

Of course, I have no idea how fast this will actually be. But I bet it will be faster than a Python list comprehension.
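The as_strided documentation warns that a wrong shape or stride can make the view point at invalid memory, so it may be worth knowing another fully vectorized option. This one is not from the original answer: label every cell with its diagonal offset j - i and let np.bincount sum the cells per label. A sketch for a square board; the helper name `diagonal_sums_bincount` is hypothetical, and it returns sums in the question's order (as floats, because bincount with weights returns float64).

```python
import numpy as np


def diagonal_sums_bincount(board):
    """Sums of all diagonals of `board` and of board[::-1].

    Returns 2 * (2n - 1) values in the same order as the question's list
    comprehension (offsets -(n-1) .. n-1, original board first).
    """
    n = board.shape[0]
    i, j = np.indices((n, n))
    # Offset j - i ranges over -(n-1) .. n-1; shift it to 0 .. 2n-2 so it
    # can be used as a bincount label. Label 0 is offset -(n-1).
    labels = (j - i + n - 1).ravel()
    out = [np.bincount(labels, weights=b.ravel(), minlength=2 * n - 1)
           for b in (board, board[::-1])]
    return np.concatenate(out)


board = np.random.default_rng(1).integers(0, 2, size=(8, 8))
fast = diagonal_sums_bincount(board)
slow = [b.diagonal(k).sum() for b in (board, board[::-1]) for k in range(-7, 8)]
print(np.array_equal(fast, slow))  # True
```

If the board size is fixed, `labels` depends only on n and could be computed once outside the hot loop.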

For convenience, here's a fully general diagonals function. It assumes you want to move the diagonal along the longest axis:

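import numpy

def diagonals(a):
    rows, cols = a.shape
    # Move the diagonal along the longest axis: transpose wide inputs.
    if cols > rows:
        a = a.T
        rows, cols = a.shape
    # Zero block between two copies of a, so short diagonals are zero-padded.
    fill = numpy.zeros(((cols - 1), cols), dtype=a.dtype)
    stacked = numpy.vstack((a, fill, a))
    # Stepping cols + 1 elements (one row plus one column) walks down a diagonal.
    major_stride, minor_stride = stacked.strides
    strides = major_stride, minor_stride * (cols + 1)
    shape = (rows + cols - 1, cols)
    return numpy.lib.stride_tricks.as_strided(stacked, shape, strides)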
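The "longest axis" behaviour is easy to miss: when cols > rows, diagonals works on a.T, so its rows are diagonals of the transpose. A sketch that checks this on a 3x5 array. The function body is copied from the answer (with `np` as the alias) so the block runs on its own, and the row-to-offset formula in the comments is derived from the stride arithmetic rather than stated in the answer.

```python
import numpy as np


def diagonals(a):
    # Same logic as the answer's function.
    rows, cols = a.shape
    if cols > rows:
        a = a.T
        rows, cols = a.shape
    fill = np.zeros((cols - 1, cols), dtype=a.dtype)
    stacked = np.vstack((a, fill, a))
    major_stride, minor_stride = stacked.strides
    strides = major_stride, minor_stride * (cols + 1)
    shape = (rows + cols - 1, cols)
    return np.lib.stride_tricks.as_strided(stacked, shape, strides)


a = np.arange(15).reshape(3, 5)
t = a.T  # diagonals() works on the transpose because cols > rows
rows, cols = t.shape
# Row r < rows holds offset -r of t; row r >= rows holds offset rows + cols - 1 - r.
offsets = [-r if r < rows else rows + cols - 1 - r for r in range(rows + cols - 1)]
expected = [t.trace(k) for k in offsets]
sums = diagonals(a).sum(axis=1)
print(np.array_equal(sums, expected))  # True
```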
https://en.xdnf.cn/q/70261.html
