Hello, everyone.
I found some strange behavior when subclassing an ndarray.

    import numpy as np

    class fooarray(np.ndarray):
        def __new__(cls, input_array, *args, **kwargs):
            obj = np.asarray(input_array).view(cls)
            return obj

        def __init__(self, *args, **kwargs):
            return

        def __array_finalize__(self, obj):
            return

    a = fooarray(np.random.randn(3, 5))
    b = np.random.randn(3, 5)
    a_sum = np.sum(a, axis=0, keepdims=True)
    b_sum = np.sum(b, axis=0, keepdims=True)
    print(a_sum.ndim)  # 1
    print(b_sum.ndim)  # 2

As you can see, the keepdims argument doesn't work for my subclass fooarray: the result loses one of its axes. How can I avoid this problem? Or, more generally, how can I subclass numpy ndarray correctly?
np.sum accepts a variety of objects as input: not only ndarrays, but also lists, generators, and np.matrix instances, for example. The keepdims parameter obviously makes no sense for lists or generators. It is not appropriate for np.matrix instances either, since an np.matrix always has 2 dimensions. If you look at the call signature of np.matrix.sum, you see that the method has no keepdims parameter:
Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)
So some subclasses of ndarray may have sum methods that do not take a keepdims parameter. This is an unfortunate violation of the Liskov substitution principle and the origin of the pitfall you encountered.
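If you want to check this on your own installation, here is a small sketch. It assumes your NumPy still ships np.matrix with the signature quoted above; the variable names are only for illustration.

```python
import inspect
import numpy as np

# np.matrix.sum is written in Python, so its signature can be inspected.
matrix_params = list(inspect.signature(np.matrix.sum).parameters)
matrix_accepts_keepdims = "keepdims" in matrix_params

# Build a small matrix purely for demonstration.
m = np.asmatrix(np.ones((3, 5)))

# A matrix reduction stays 2-D even though keepdims is never passed.
matrix_sum_shape = m.sum(axis=0).shape

# Passing keepdims to the matrix method is rejected with a TypeError.
try:
    m.sum(axis=0, keepdims=True)
    matrix_keepdims_error = None
except TypeError as exc:
    matrix_keepdims_error = exc

print(matrix_params)
print(matrix_accepts_keepdims)
print(matrix_sum_shape)
print(type(matrix_keepdims_error).__name__)
```

The TypeError is exactly the exception a generic caller has to avoid if it wants to support np.matrix.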
Now if you look at the source code for np.sum, you see that it is a delegating function which tries to determine what to do based on the type of the first argument.
If the type of the first argument is not ndarray, it drops the keepdims parameter. It does this because passing the keepdims parameter to np.matrix.sum would raise an exception.
So because np.sum tries to delegate in the most general way, making no assumptions about which arguments a subclass of ndarray may take, it drops the keepdims parameter when passed a fooarray.
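To make the delegation concrete, here is a rough sketch of the logic described above. It is not the actual NumPy source: delegating_sum is a hypothetical stand-in written only to illustrate the "exact ndarray or not" branch.

```python
import numpy as np

def delegating_sum(a, axis=None, keepdims=False):
    """Hypothetical stand-in for the delegation described above."""
    if type(a) is not np.ndarray:
        method = getattr(a, "sum", None)
        if method is not None:
            # Unknown type with its own sum method: keepdims is dropped,
            # because the method (e.g. np.matrix.sum) might not accept it.
            return method(axis=axis)
        # Lists and similar inputs are converted to a plain ndarray first.
        a = np.asarray(a)
    # Exact ndarray: keepdims is safe to forward.
    return np.ndarray.sum(a, axis=axis, keepdims=keepdims)

class fooarray(np.ndarray):
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

plain = np.ones((3, 5))
sub = fooarray(np.ones((3, 5)))

plain_ndim = delegating_sum(plain, axis=0, keepdims=True).ndim
sub_ndim = delegating_sum(sub, axis=0, keepdims=True).ndim
print(plain_ndim, sub_ndim)
```

With this sketch the plain array keeps both dimensions, while the fooarray takes the subclass branch and comes back one-dimensional, which matches the behavior in the question.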
The workaround is to not use np.sum, but to call a.sum instead. This is more direct anyway, since np.sum is merely a delegating function.

    import numpy as np

    class fooarray(np.ndarray):
        def __new__(cls, input_array, *args, **kwargs):
            obj = np.asarray(input_array, *args, **kwargs).view(cls)
            return obj

    a = fooarray(np.random.randn(3, 5))
    b = np.random.randn(3, 5)
    a_sum = a.sum(axis=0, keepdims=True)
    b_sum = np.sum(b, axis=0, keepdims=True)
    print(a_sum.ndim)  # 2
    print(b_sum.ndim)  # 2

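As a follow-up illustration of why the kept axis matters, the sketch below uses the method call to normalize each column by broadcasting. The input values and the names col_sums and normalized are my own choices for the example, not part of the original code.

```python
import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

# Deterministic positive values so the column sums are never zero.
a = fooarray(np.arange(15, dtype=float).reshape(3, 5) + 1.0)

# Calling the method directly forwards keepdims, so the result has shape (1, 5).
col_sums = a.sum(axis=0, keepdims=True)

# The retained axis lets the division broadcast across the rows.
normalized = a / col_sums

print(type(col_sums).__name__, col_sums.shape)
print(normalized.shape)
```

Each column of normalized should then sum to 1, and the intermediate results remain fooarray instances.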