Combining element-wise and matrix multiplication with multi-dimensional arrays in NumPy

2024/9/8 10:25:19

I have two multidimensional NumPy arrays, A and B, with A.shape = (K, d, N) and B.shape = (K, N, d). I would like to perform an element-wise operation over axis 0 (K), with that operation being matrix multiplication over axes 1 and 2 (d, N and N, d). So the result should be a multidimensional array C with C.shape = (K, d, d), so that C[k] = np.dot(A[k], B[k]). A naive implementation would look like this:

C = np.vstack([np.dot(A[k], B[k])[np.newaxis, :, :] for k in range(K)])

but this implementation is slow. A slightly faster approach looks like this:

C = np.dot(A, B)[:, :, 0, :]

which uses the default behaviour of np.dot on multidimensional arrays, giving me an array with shape (K, d, K, d). However, this approach does far more work than necessary: I assumed that all the entries along axis 2 are the same, so that the required answer is computed K times (see the edit below, this assumption is wrong). Asymptotically it will be slower than the first approach, but the overhead is much less. I am also aware of the following approach:

from numpy.core.umath_tests import matrix_multiply
C = matrix_multiply(A, B)

but I am not guaranteed that this function will be available. My question is thus, does NumPy provide a standard way of doing this efficiently? An answer which applies to multidimensional arrays in general would be perfect, but an answer specific to only this case would be great too.

Edit: As pointed out by @Juh_, the second approach is incorrect. The correct version is:

C = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)

but the added overhead makes it slower than the first approach, even for small matrices. The matrix_multiply approach wins by a wide margin in all my timing tests, for both small and large matrices. I'm now strongly considering using it if no better solution crops up, even if that means copying the numpy.core.umath_tests library (written in C) into my project.
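As a sanity check on the corrected expression, here is a minimal sketch that compares it against a plain per-k loop. The sizes, the seed, and the variable names (C_loop, full, C_diag) are arbitrary choices for illustration and are not part of the original question.

```python
import numpy as np

K, d, N = 4, 3, 5  # small, arbitrary sizes for illustration
rng = np.random.default_rng(0)
A = rng.standard_normal((K, d, N))
B = rng.standard_normal((K, N, d))

# Reference result: one ordinary matrix product per index k.
C_loop = np.stack([np.dot(A[k], B[k]) for k in range(K)])

# np.dot on N-d arrays pairs every A[k] with every B[m]:
# full[k, i, m, j] = sum over n of A[k, i, n] * B[m, n, j]
full = np.dot(A, B)  # shape (K, d, K, d)

# Keep only the blocks with k == m. diagonal() moves the diagonal
# to the last axis, giving shape (d, d, K), so transpose it back.
C_diag = full.diagonal(axis1=0, axis2=2).transpose(2, 0, 1)

print(np.allclose(C_loop, C_diag))
```

Note that the intermediate array holds K * d * K * d entries, of which only K * d * d are kept, which is consistent with the slowdown described above.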

Answer

A possible solution to your problem is:

C = np.sum(A[:,:,:,np.newaxis]*B[:,np.newaxis,:,:],axis=2)

However:

  1. it is quicker than the vstack approach only if K is much bigger than d and N
  2. there might be a memory issue: the above solution allocates a K x d x N x d array (i.e. all possible pairwise products, before summing). I could not actually test it with big K, d and N because I ran out of memory.
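The following sketch checks the broadcasting expression against two alternatives that are not mentioned in this answer: np.einsum and np.matmul (the latter requires NumPy 1.10 or newer). Both are standard NumPy functions, but whether they are faster for a given K, d and N is something to measure rather than assume. Sizes and names are illustrative only.

```python
import numpy as np

K, d, N = 4, 3, 5  # small, arbitrary sizes for illustration
rng = np.random.default_rng(2)
A = rng.standard_normal((K, d, N))
B = rng.standard_normal((K, N, d))

# Broadcasting version from the answer: the product has shape
# (K, d, N, d) and is fully materialised before the sum over N.
C_sum = np.sum(A[:, :, :, np.newaxis] * B[:, np.newaxis, :, :], axis=2)

# einsum: contract over the shared index n, keep k as a batch index.
C_einsum = np.einsum('kin,knj->kij', A, B)

# matmul: leading axes are treated as a stack of matrices.
C_matmul = np.matmul(A, B)

print(np.allclose(C_sum, C_einsum), np.allclose(C_sum, C_matmul))
```

Neither alternative needs the four-dimensional temporary in user code, which may help with the memory problem noted in point 2.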

btw, note that:

C = np.dot(A, B)[:, :, 0, :]

does not give the correct result. This tripped me up, because I first validated my method by comparing its results to those given by this np.dot command.
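To see why the slice is wrong, a small sketch with random data helps: fixing axis 2 at index 0 pairs every A[k] with B[0], so only the k = 0 block agrees with the intended result. The names C_ref, C_bad and matches are made up for this illustration.

```python
import numpy as np

K, d, N = 4, 3, 5  # small, arbitrary sizes for illustration
rng = np.random.default_rng(1)
A = rng.standard_normal((K, d, N))
B = rng.standard_normal((K, N, d))

# Intended result: A[k] times B[k] for each k.
C_ref = np.stack([A[k] @ B[k] for k in range(K)])

# The slice keeps only the blocks that involve B[0]:
# C_bad[k] equals A[k] @ B[0], not A[k] @ B[k].
C_bad = np.dot(A, B)[:, :, 0, :]

# Per-k comparison; with random inputs only k = 0 should match.
matches = [bool(np.allclose(C_bad[k], C_ref[k])) for k in range(K)]
print(matches)
```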

