Numpy Array Broadcasting with different dimensions

2024/11/17 11:21:51

I am a little confused by NumPy's broadcasting rules. Suppose you want to perform an axis-wise scalar product on a higher-dimensional array to reduce its dimension by one (basically a weighted summation along one axis):

from numpy import *

A = ones((3,3,2))
v = array([1,2])
B = zeros((3,3))

# V01: this works
B[0,0] = v.dot(A[0,0])

# V02: this works
B[:,:] = v[0]*A[:,:,0] + v[1]*A[:,:,1]

# V03: this doesn't
B[:,:] = v.dot(A[:,:])

Why does V03 not work?

Cheers

Answer

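np.dot(a, b) sums over the last axis of a and the second-to-last axis of b. In V03, v has shape (2,) and A has shape (3, 3, 2), so dot tries to pair the length-2 axis of v with the second-to-last axis of A, which has length 3, and fails. For the particular case in your question, you could always go with (a here is your array A):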

>>> a.dot(v)
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])
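To make the axis pairing concrete, here is a small sketch using the arrays from the question. It only assumes that NumPy is imported as np; the names failed, weighted and explicit are made up for illustration.

```python
import numpy as np

A = np.ones((3, 3, 2))
v = np.array([1, 2])

# v.dot(A) pairs v's only axis (length 2) with A's
# second-to-last axis (length 3), so the shapes do not align.
try:
    v.dot(A)
    failed = False
except ValueError:
    failed = True

# A.dot(v) pairs A's last axis (length 2) with v's only axis (length 2).
weighted = A.dot(v)

# The explicit weighted sum from V02, for comparison.
explicit = v[0] * A[:, :, 0] + v[1] * A[:, :, 1]

print(failed, weighted.shape, np.allclose(weighted, explicit))
```

The only thing that changes between the failing and the working call is which axes dot tries to match, not the data itself.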

If you want to keep the v.dot(a) order, you need to get the axis into position, which can easily be achieved with np.rollaxis:

>>> v.dot(np.rollaxis(a, 2, 1))
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])
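As a side note, the NumPy documentation now recommends np.moveaxis over np.rollaxis. A sketch of the same trick with moveaxis follows; the (3, 4, 2) test array is my own choice, picked so that a transposed or mismatched result would be noticed.

```python
import numpy as np

# Non-uniform, non-square test array (an assumption for illustration).
a = np.arange(24.0).reshape(3, 4, 2)
v = np.array([1.0, 2.0])

# Move the length-2 axis from position 2 to position 1: (3, 4, 2) -> (3, 2, 4).
moved = np.moveaxis(a, 2, 1)

# v's axis is now paired with the second-to-last axis of `moved`.
result = v.dot(moved)

print(moved.shape, result.shape)
print(np.allclose(result, a.dot(v)))
```

For this particular axis move, moveaxis(a, 2, 1) and rollaxis(a, 2, 1) produce the same view, so the two spellings are interchangeable here.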

I don't like np.dot too much, unless it is for the obvious matrix or vector multiplication, because it is very strict about the output dtype when using the optional out parameter. Joe Kington has mentioned it already, but if you are going to be doing this type of thing, get used to np.einsum: once you get the hang of the syntax, it will cut the time you spend worrying about reshaping things to a minimum:

>>> a = np.ones((3, 3, 2))
>>> np.einsum('i, jki', v, a)
array([[ 3.,  3.,  3.],
       [ 3.,  3.,  3.],
       [ 3.,  3.,  3.]])
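If the implicit form looks cryptic, the output subscripts can be spelled out. The sketch below compares the implicit call, the explicit 'i,jki->jk' form and np.tensordot on a non-uniform array; the test array and variable names are assumptions for illustration.

```python
import numpy as np

a = np.arange(24.0).reshape(3, 4, 2)
v = np.array([1.0, 2.0])

# Implicit mode: the repeated label i is summed, the remaining
# labels j and k form the output in alphabetical order.
implicit = np.einsum('i, jki', v, a)

# Explicit mode: the same contraction with the output written out.
explicit = np.einsum('i,jki->jk', v, a)

# tensordot: contract axis 0 of v with axis 2 of a.
via_tensordot = np.tensordot(v, a, axes=(0, 2))

print(np.allclose(implicit, explicit), np.allclose(explicit, via_tensordot))
```

All three compute out[j, k] = sum over i of v[i] * a[j, k, i], which is the weighted summation along the last axis that the question asks for.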

Not that it is too relevant in this case, but it is also ridiculously fast:

In [4]: %timeit a.dot(v)
100000 loops, best of 3: 2.43 us per loop

In [5]: %timeit v.dot(np.rollaxis(a, 2, 1))
100000 loops, best of 3: 4.49 us per loop

In [7]: %timeit np.tensordot(v, a, axes=(0, 2))
100000 loops, best of 3: 14.9 us per loop

In [8]: %timeit np.einsum('i, jki', v, a)
100000 loops, best of 3: 2.91 us per loop
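If you want to repeat the comparison outside IPython, something along these lines should work with the standard timeit module. The loop counts and the candidates/results names are my own choices, and the numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

a = np.ones((3, 3, 2))
v = np.array([1, 2])

# The four expressions timed above, wrapped as callables.
candidates = {
    "a.dot(v)": lambda: a.dot(v),
    "v.dot(np.rollaxis(a, 2, 1))": lambda: v.dot(np.rollaxis(a, 2, 1)),
    "np.tensordot(v, a, axes=(0, 2))": lambda: np.tensordot(v, a, axes=(0, 2)),
    "np.einsum('i, jki', v, a)": lambda: np.einsum('i, jki', v, a),
}

number = 10000  # calls per timing run (arbitrary choice)
results = {}
for label, func in candidates.items():
    # Best of 3 runs, converted to microseconds per call.
    best = min(timeit.repeat(func, number=number, repeat=3))
    results[label] = best / number * 1e6
    print(f"{label:35s} {results[label]:.2f} us per loop")
```

For arrays this small the timings mostly reflect call overhead, so treat the ranking as a rough guide only.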
