vectorized radix sort with numpy - can it beat np.sort?

2024/10/7 0:19:47

Numpy doesn't yet have a radix sort, so I wondered whether it was possible to write one using pre-existing numpy functions. So far I have the following, which does work, but is about 10 times slower than numpy's quicksort.

[Image: line profiler output]

Test and benchmark:

import numpy as np

a = np.random.randint(0, 10**8, 10**6)
assert np.all(radix_sort(a) == np.sort(a))
%timeit np.sort(a)
%timeit radix_sort(a)

The mask_b loop can be at least partially vectorized by broadcasting the & across all the masks at once and using cumsum with the axis argument, but that ends up being a pessimization, presumably due to the increased memory footprint.

If anyone can see a way to improve on what I have, I'd be interested to hear it, even if it's still slower than np.sort; this is more a case of intellectual curiosity and interest in NumPy tricks.

Note that you can implement a fast counting sort easily enough, though that's only relevant for small integer data.
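For reference, a minimal counting sort using only existing NumPy functions might look like the sketch below. It assumes a 1-D array of non-negative integers with a small maximum value; `counting_sort` is a hypothetical name, not a NumPy function.

```python
import numpy as np

def counting_sort(a):
    """Counting sort for non-negative integers with a small maximum value.

    np.bincount tallies how often each value 0..max(a) occurs, and
    np.repeat writes each value out that many times, in ascending order.
    Time and memory grow with max(a), so this only pays off for small ranges.
    """
    counts = np.bincount(a)                          # counts[v] = occurrences of v
    values = np.arange(counts.size, dtype=a.dtype)   # every value from 0 to max(a)
    return np.repeat(values, counts)

a = np.random.randint(0, 256, 10**6)
assert np.array_equal(counting_sort(a), np.sort(a))
```

Because the output size is driven by the counts rather than by comparisons, this is linear in len(a) + max(a), which is why it stops being attractive once max(a) approaches the 1e8 range used in the benchmark.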

Edit 1: Taking np.arange(n) out of the loop helps a little, but that's not very exciting.

Edit 2: The cumsum was actually redundant (oops!), but this simpler version only helps marginally with performance:

def radix_sort(a):
    bit_len = int(np.max(a)).bit_length()
    n = len(a)
    cached_arange = np.arange(n)
    idx = np.empty(n, dtype=int)  # fully overwritten each iteration
    for mask_b in range(bit_len):
        is_one = (a & 2**mask_b).astype(bool)
        n_ones = np.sum(is_one)
        n_zeros = n - n_ones
        idx[~is_one] = cached_arange[:n_zeros]
        idx[is_one] = cached_arange[:n_ones] + n_zeros
        # next three lines just do: a[idx] = a, but correctly
        new_a = np.empty(n, dtype=a.dtype)
        new_a[idx] = a
        a = new_a
    return a

Edit 3: Rather than looping over single bits, you can loop over two or more at a time if you construct idx in multiple steps. Using 2 bits helps a little; I've not tried more:

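# is_k marks elements whose current 2-bit digit equals k; offsets keep buckets in order
idx[is_zero] = np.arange(n_zeros)
idx[is_one] = np.arange(n_ones) + n_zeros
idx[is_two] = np.arange(n_twos) + n_zeros + n_ones
idx[is_three] = np.arange(n_threes) + n_zeros + n_ones + n_twos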
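The multi-step idx construction generalizes to any number of bits per pass: each bucket's destinations start where the previous buckets end. A sketch under the assumption of a 1-D array of non-negative integers; `radix_sort_idx` and its `bits` parameter are hypothetical names, not code from the question.

```python
import numpy as np

def radix_sort_idx(a, bits=2):
    """LSD radix sort handling `bits` bits per pass by building idx bucket by bucket.

    Assumes `a` is a 1-D array of non-negative integers.
    """
    a = np.asarray(a)
    n = len(a)
    if n == 0:
        return a.copy()
    cached_arange = np.arange(n)
    idx = np.empty(n, dtype=np.intp)        # destination of each element
    radix = 1 << bits
    bit_len = int(a.max()).bit_length()
    for shift in range(0, bit_len, bits):
        digit = (a >> shift) & (radix - 1)  # current `bits`-bit digit
        offset = 0
        for d in range(radix):
            in_bucket = digit == d
            count = int(np.count_nonzero(in_bucket))
            # boolean assignment keeps the original order inside the bucket (stability)
            idx[in_bucket] = cached_arange[:count] + offset
            offset += count
        new_a = np.empty_like(a)
        new_a[idx] = a                       # scatter: element i moves to position idx[i]
        a = new_a
    return a
```

With bits=2 this performs exactly the four assignments shown above on each pass; larger values trade fewer passes for more boolean masks per pass.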

Edits 4 and 5: Going to 4 bits seems best for the input I'm testing. Also, you can get rid of the idx step entirely. It's now only about 5 times, rather than 10 times, slower than np.sort (source available as a gist).
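The gist itself isn't reproduced here, so how the idx step was removed is not shown. One possibility (an assumption, not necessarily the author's code) is to gather each bucket with a boolean mask and concatenate the buckets in order; `radix_sort_concat` is a hypothetical name.

```python
import numpy as np

def radix_sort_concat(a, bits=4):
    """LSD radix sort without an explicit idx array (assumes non-negative ints).

    Each pass gathers the elements of every `bits`-bit bucket with a boolean
    mask and concatenates the buckets in ascending order. Boolean indexing
    preserves the original order inside each bucket, which keeps the sort stable.
    """
    a = np.asarray(a)
    radix = 1 << bits
    bit_len = int(a.max()).bit_length() if a.size else 0
    for shift in range(0, bit_len, bits):
        digit = (a >> shift) & (radix - 1)
        a = np.concatenate([a[digit == d] for d in range(radix)])
    return a
```

This does one full scan of `digit` per bucket, which is the same access pattern as the repeat/extract approach below, just expressed as a Python-level loop over buckets.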


Edit 6: This is a tidied-up version of the above, but it's also a tiny bit slower. 80% of the time is spent in repeat and extract; if only there were a way to broadcast the extract :( ...

nax = np.newaxis

def radix_sort(a, batch_m_bits=3):
    bit_len = int(np.max(a)).bit_length()
    batch_m = 2**batch_m_bits
    mask = 2**batch_m_bits - 1
    val_set = np.arange(batch_m, dtype=a.dtype)[:, nax]
    for _ in range((bit_len-1)//batch_m_bits + 1):  # ceil-division
        a = np.extract((a & mask)[nax, :] == val_set,
                       np.repeat(a[nax, :], batch_m, axis=0))
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a

Edits 7 & 8: Actually, you can broadcast the extract using as_strided from numpy.lib.stride_tricks, but it doesn't seem to help much performance-wise.
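The as_strided code from these edits isn't reproduced here. A sketch of the idea, assuming the same setup as the Edit 6 function: a stride of 0 along the first axis gives a (batch_m, n) view in which every row aliases `a`, replacing np.repeat without copying. `radix_sort_strided` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def radix_sort_strided(a, batch_m_bits=3):
    """Edit 6's extract-based sort with np.repeat swapped for a zero-stride view."""
    a = np.asarray(a)
    if a.size == 0:
        return a.copy()
    bit_len = int(a.max()).bit_length()
    batch_m = 2**batch_m_bits
    mask = batch_m - 1
    val_set = np.arange(batch_m, dtype=a.dtype)[:, np.newaxis]
    for _ in range((bit_len - 1) // batch_m_bits + 1):  # ceil-division
        # stride 0 on axis 0: all batch_m rows share the memory of `a`
        repeated = as_strided(a, shape=(batch_m, a.size),
                              strides=(0, a.strides[0]), writeable=False)
        a = np.extract((a & mask)[np.newaxis, :] == val_set, repeated)
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a
```

The view itself is free, but np.extract ravels its input, and ravelling a non-contiguous view makes a contiguous copy, so the repeat is effectively still performed.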


Initially this made sense to me on the grounds that extract will be iterating over the whole array batch_m times, so the total number of cache lines requested by the CPU will be the same as before (it's just that by the end of the process it has requested each cache line batch_m times). However, the reality is that extract is not clever enough to iterate over arbitrarily strided arrays and has to expand out the array before beginning, i.e. the repeat ends up being done anyway. In fact, having looked at the source for extract, I now see that the best we can do with this approach is:

a = a[np.flatnonzero((a & mask)[nax, :] == val_set) % len(a)]

which is marginally slower than extract. However, if len(a) is a power of two, we can replace the expensive mod operation with & (len(a) - 1), which does end up being a bit faster than the extract version (now about 4.9x np.sort for a=randint(0, 1e8, 2**20)). I suppose we could make this work for non-power-of-two lengths by zero-padding and then cropping the extra zeros once the sort is done... however, this would be a pessimisation unless the length was already close to a power of two.
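Put together, the flatnonzero version with the & shortcut might look like this. It is a sketch assuming the Edit 6 setup (mask, val_set) and a non-empty array whose length is a power of two; `radix_sort_pow2` is a hypothetical name.

```python
import numpy as np

def radix_sort_pow2(a, batch_m_bits=3):
    """flatnonzero-based LSD radix sort for arrays whose length is a power of two."""
    a = np.asarray(a)
    n = a.size
    if n == 0 or (n & (n - 1)) != 0:
        raise ValueError("length must be a power of two")
    bit_len = int(a.max()).bit_length()
    batch_m = 2**batch_m_bits
    mask = batch_m - 1
    val_set = np.arange(batch_m, dtype=a.dtype)[:, np.newaxis]
    for _ in range((bit_len - 1) // batch_m_bits + 1):  # ceil-division
        hits = (a & mask)[np.newaxis, :] == val_set  # (batch_m, n) bucket membership
        # flat index = bucket * n + position, so position = flat % n = flat & (n - 1)
        a = a[np.flatnonzero(hits) & (n - 1)]
        val_set <<= batch_m_bits
        mask <<= batch_m_bits
    return a
```

flatnonzero walks the membership matrix row by row, so all bucket-0 positions come out first (in original order), then bucket 1, and so on; masking off the row part of each flat index recovers the position in `a`.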

Answer

I had a go with Numba to see how fast a radix sort could be. The key to good performance with Numba (often) is to write out all the loops, which is very instructive. I ended up with the following:

from numba import jit

@jit
def radix_loop(nbatches, batch_m_bits, bitsums, a, out):
    mask = (1 << batch_m_bits) - 1
    for shift in range(0, nbatches*batch_m_bits, batch_m_bits):
        # set bit sums to zero
        for i in range(bitsums.shape[0]):
            bitsums[i] = 0

        # determine bit sums
        for i in range(a.shape[0]):
            j = (a[i] & mask) >> shift
            bitsums[j] += 1

        # take the cumsum of the bit sums
        cumsum = 0
        for i in range(bitsums.shape[0]):
            temp = bitsums[i]
            bitsums[i] = cumsum
            cumsum += temp

        # sorting loop
        for i in range(a.shape[0]):
            j = (a[i] & mask) >> shift
            out[bitsums[j]] = a[i]
            bitsums[j] += 1

        # prepare next iteration
        mask <<= batch_m_bits
        # can't use `temp` here because of numba internal types
        temp2 = a
        a = out
        out = temp2

    return a

Of the four inner loops, it's easy to see that the fourth one (the scatter into out, which depends on the running per-bucket counters) is what makes this hard to vectorize with NumPy.
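To make that concrete: the first three loops map directly onto np.bincount and np.cumsum. The fourth needs each element's rank within its bucket; one way to get that in pure NumPy (an illustrative sketch, not part of the answer above) is a one-hot cumsum, which costs O(n * radix) memory, the same kind of footprint the question reports as a pessimization. `radix_pass` and `radix_sort_numpy` are hypothetical names.

```python
import numpy as np

def radix_pass(a, shift, bits):
    """One LSD pass written with NumPy primitives instead of explicit loops."""
    radix = 1 << bits
    digit = (a >> shift) & (radix - 1)
    # loops 1-2: zero the bucket counts and tally them
    counts = np.bincount(digit, minlength=radix)
    # loop 3: exclusive cumsum gives each bucket's starting offset
    starts = np.cumsum(counts) - counts
    # loop 4: rank of each element within its bucket via a one-hot cumsum
    onehot = digit[:, np.newaxis] == np.arange(radix)  # shape (n, radix)
    rank = np.cumsum(onehot, axis=0)[np.arange(a.size), digit] - 1
    out = np.empty_like(a)
    out[starts[digit] + rank] = a
    return out

def radix_sort_numpy(a, bits=4):
    """LSD radix sort of non-negative integers built from radix_pass."""
    a = np.asarray(a)
    bit_len = int(a.max()).bit_length() if a.size else 0
    for shift in range(0, bit_len, bits):
        a = radix_pass(a, shift, bits)
    return a
```

The (n, radix) intermediate is exactly what the Numba loop avoids by keeping one running counter per bucket.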

One way to cheat around that problem is to pull in a particular C++ function from Scipy: scipy.sparse.coo.coo_tocsr. It does pretty much the same inner loops as the Python function above, so it can be abused to write a faster "vectorized" radix sort in Python. Maybe something like:

import numpy as np

try:
    from scipy.sparse._sparsetools import coo_tocsr  # private SciPy routine
except ImportError:
    from scipy.sparse.coo import coo_tocsr  # location used in the original answer

def radix_step(radix, keys, bitsums, a, w):
    coo_tocsr(radix, 1, a.size, keys, a, a, bitsums, w, w)
    return w, a

def scipysparse_radix_perbyte(a):
    # coo_tocsr internally works with system int and upcasts
    # anything else. We need to copy anyway to not mess with
    # original array. Also take into account endianness...
    a = a.astype('<i', copy=True)
    bitlen = int(a.max()).bit_length()
    radix = 256
    work = np.empty_like(a)
    _ = np.empty(radix+1, int)
    for i in range((bitlen-1)//8 + 1):
        keys = a.view('u1')[i::a.itemsize].astype(int)
        a, work = radix_step(radix, keys, _, a, work)
    return a

EDIT: Optimized the function a little bit; see the edit history.

One inefficiency of LSB radix sorting like the above is that the array is completely shuffled in RAM a number of times, which means the CPU cache isn't used very well. To try to mitigate this effect, one could opt to first do a pass of MSB radix sort, to put items in roughly the right block of RAM, before sorting every resulting group with an LSB radix sort. Here's one implementation:

def scipysparse_radix_hybrid(a, bbits=8, gbits=8):
    """
    Parameters
    ----------
    a : Array of non-negative integers to be sorted.
    bbits : Number of bits in radix for LSB sorting.
    gbits : Number of bits in radix for MSB grouping.
    """
    a = a.copy()
    bitlen = int(a.max()).bit_length()
    work = np.empty_like(a)

    # Group values by single iteration of MSB radix sort:
    # Casting to np.int_ to get rid of python BigInt
    ngroups = np.int_(2**gbits)
    group_offset = np.empty(ngroups + 1, int)
    shift = max(bitlen-gbits, 0)
    a, work = radix_step(ngroups, a>>shift, group_offset, a, work)
    bitlen = shift
    if not bitlen:
        return a

    # LSB radix sort each group:
    agroups = np.split(a, group_offset[1:-1])
    # Mask off high bits to not undo the grouping..
    gmask = (1 << shift) - 1
    nbatch = (bitlen-1) // bbits + 1
    radix = np.int_(2**bbits)
    _ = np.empty(radix + 1, int)
    for agi in agroups:
        if not agi.size:
            continue
        mask = (radix - 1) & gmask
        wgi = work[:agi.size]
        for shift in range(0, nbatch*bbits, bbits):
            keys = (agi & mask) >> shift
            agi, wgi = radix_step(radix, keys, _, agi, wgi)
            mask = (mask << bbits) & gmask
        if nbatch % 2:
            # Copy result back in to `a`
            wgi[...] = agi
    return a

Timings (with best performing settings for each on my system):

def numba_radix(a, batch_m_bits=8):
    a = a.copy()
    bit_len = int(a.max()).bit_length()
    nbatches = (bit_len-1)//batch_m_bits + 1
    work = np.zeros_like(a)
    bitsums = np.zeros(2**batch_m_bits + 1, int)
    srtd = radix_loop(nbatches, batch_m_bits, bitsums, a, work)
    return srtd

a = np.random.randint(0, 10**8, 10**6)
%timeit numba_radix(a, 9)
# 10 loops, best of 3: 76.1 ms per loop
%timeit np.sort(a)
#10 loops, best of 3: 115 ms per loop
%timeit scipysparse_radix_perbyte(a)
#10 loops, best of 3: 95.2 ms per loop
%timeit scipysparse_radix_hybrid(a, 11, 6)
#10 loops, best of 3: 75.4 ms per loop

Numba performs very well, as expected. And with some clever application of existing C extensions, it's also possible to beat numpy.sort. IMO, at the level of optimization you've already reached, it's worth also considering add-ons to NumPy, but I wouldn't really consider the implementations in my answer "vectorized": the bulk of the work is done in an external dedicated function.

One other thing that strikes me is the sensitivity to the choice of radix. For most of the settings I tried my implementations were still slower than numpy.sort, so in practice some sort of heuristic would be required to offer good performance across the board.
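One simple alternative to a hand-written heuristic is to time a few candidate radix widths on a sample of the input and keep the fastest. A sketch under stated assumptions: `pick_radix_bits` and its default candidate list are hypothetical, and `sort_fn` is any radix sort that takes the bit width as its second argument (numba_radix above has that shape).

```python
import timeit
import numpy as np

def pick_radix_bits(sort_fn, a, candidates=(4, 6, 8, 11), sample_size=2**16, repeat=3):
    """Return the candidate bit width for which sort_fn(sample, bits) ran fastest."""
    rng = np.random.default_rng(0)
    sample = rng.choice(a, size=min(sample_size, a.size), replace=False)
    best_bits, best_time = None, float("inf")
    for bits in candidates:
        # best of `repeat` single runs, to reduce timing noise
        t = min(timeit.repeat(lambda: sort_fn(sample, bits), number=1, repeat=repeat))
        if t < best_time:
            best_bits, best_time = bits, t
    return best_bits
```

A small sample fits in cache far more easily than the full array, and cache behaviour is exactly what the hybrid MSB/LSB approach targets, so a width chosen this way is a starting point rather than a guaranteed optimum for the full-size input.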
