Counting consecutive 1s in NumPy array

2024/11/15 16:32:34
[1, 1, 1, 0, 0, 0, 1, 1, 0, 0]

I have a NumPy array consisting of 0's and 1's like above. How can I compute a running count of consecutive 1's, as shown below? The count resets to 0 any time I encounter a 0.

[1, 2, 3, 0, 0, 0, 1, 2, 0, 0]

I can do this using a for loop, but is there a vectorized solution using NumPy?

Answer

Here's a vectorized approach -

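import numpy as np

def island_cumsum_vectorized(a):
    # Pad with a 0 on each side so every run of 1s has a detectable start and end
    a_ext = np.concatenate(( [0], a, [0] ))
    # Indices where the value changes: run starts and run ends alternate
    idx = np.flatnonzero(a_ext[1:] != a_ext[:-1])
    # At the first 0 after each run, store minus the run length to reset the sum
    a_ext[1:][idx[1::2]] = idx[::2] - idx[1::2]
    return a_ext.cumsum()[1:-1]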
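To see why the cumulative sum resets, it may help to look at the intermediate arrays for the sample input from the question. The following is only an illustrative walkthrough of the same steps; the names `starts`, `stops`, `run_lengths` and `stepped` are introduced here for readability and are not part of the answer's function.

```python
import numpy as np

a = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])

# Pad with a 0 on each side so a run touching either edge still has
# a start and an end.
a_ext = np.concatenate(([0], a, [0]))

# Positions (relative to a_ext[1:]) where the value differs from its
# predecessor. Because of the padding they alternate: start, stop, start, ...
idx = np.flatnonzero(a_ext[1:] != a_ext[:-1])
starts, stops = idx[::2], idx[1::2]
run_lengths = stops - starts

# Write -run_length at the first 0 after each run. The running sum climbs
# 1, 2, ..., n inside the run and is pulled back to 0 right after it.
a_ext[1:][stops] = -run_lengths
stepped = a_ext.copy()

result = a_ext.cumsum()[1:-1]

print(idx)          # [0 3 6 8]
print(run_lengths)  # [3 2]
print(stepped)      # [ 0  1  1  1 -3  0  0  1  1 -2  0  0]
print(result)       # [1 2 3 0 0 0 1 2 0 0]
```

Note that the negative offsets need a signed integer dtype, so a boolean or unsigned input would presumably have to be cast (for example with `a.astype(int)`) first.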

Sample run -

In [91]: a = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])

In [92]: island_cumsum_vectorized(a)
Out[92]: array([1, 2, 3, 0, 0, 0, 1, 2, 0, 0])

In [93]: a = np.array([0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1])

In [94]: island_cumsum_vectorized(a)
Out[94]: array([0, 1, 2, 3, 4, 0, 0, 0, 1, 2, 0, 0, 1])
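The question mentions a for-loop version without showing it. As a sanity check, here is a minimal sketch of what such a loop might look like (the name `island_cumsum_loop` is made up for this example), compared against the vectorized function on random 0/1 arrays.

```python
import numpy as np

def island_cumsum_loop(a):
    # Hypothetical loop baseline: keep a running count, reset it on every 0.
    out = np.zeros_like(a)
    count = 0
    for i, v in enumerate(a):
        count = count + 1 if v else 0
        out[i] = count
    return out

def island_cumsum_vectorized(a):
    # Same function as in the answer, repeated so this block runs on its own.
    a_ext = np.concatenate(([0], a, [0]))
    idx = np.flatnonzero(a_ext[1:] != a_ext[:-1])
    a_ext[1:][idx[1::2]] = idx[::2] - idx[1::2]
    return a_ext.cumsum()[1:-1]

# Compare both implementations on random inputs of varying length.
rng = np.random.default_rng(0)
for _ in range(100):
    a = rng.integers(0, 2, size=rng.integers(1, 50))
    assert np.array_equal(island_cumsum_loop(a), island_cumsum_vectorized(a))
```

Agreement on random inputs is not a proof, but it does cover edge cases such as arrays that start or end with a 1 and arrays that are all 0's or all 1's.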

Runtime test

For the timings, I'll tile OP's sample input array to larger sizes, which should hopefully make for a less opportunistic benchmark -

Small case:

In [16]: a = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])

In [17]: a = np.tile(a,10)  # Repeat OP's data 10 times

# @Paul Panzer's solution
In [18]: %timeit np.concatenate([np.cumsum(c) if c[0] == 1 else c for c in np.split(a, 1 + np.where(np.diff(a))[0])])
10000 loops, best of 3: 73.4 µs per loop

In [19]: %timeit island_cumsum_vectorized(a)
100000 loops, best of 3: 8.65 µs per loop

Bigger case:

In [20]: a = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])

In [21]: a = np.tile(a,1000)  # Repeat OP's data 1000 times

# @Paul Panzer's solution
In [22]: %timeit np.concatenate([np.cumsum(c) if c[0] == 1 else c for c in np.split(a, 1 + np.where(np.diff(a))[0])])
100 loops, best of 3: 6.52 ms per loop

In [23]: %timeit island_cumsum_vectorized(a)
10000 loops, best of 3: 49.7 µs per loop

Nah, I want a really huge case:

In [24]: a = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])

In [25]: a = np.tile(a,100000)  # Repeat OP's data 100000 times

# @Paul Panzer's solution
In [26]: %timeit np.concatenate([np.cumsum(c) if c[0] == 1 else c for c in np.split(a, 1 + np.where(np.diff(a))[0])])
1 loops, best of 3: 725 ms per loop

In [27]: %timeit island_cumsum_vectorized(a)
100 loops, best of 3: 7.28 ms per loop
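
The timings above come from IPython's %timeit. For anyone wanting to rerun the comparison as a plain script, something along these lines should work with the standard `timeit` module. This is only a sketch: `split_cumsum` is a name made up here to wrap the one-liner timed above, the loop and repeat counts are arbitrary, and absolute numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

def island_cumsum_vectorized(a):
    # Same function as in the answer, repeated so this script is self-contained.
    a_ext = np.concatenate(([0], a, [0]))
    idx = np.flatnonzero(a_ext[1:] != a_ext[:-1])
    a_ext[1:][idx[1::2]] = idx[::2] - idx[1::2]
    return a_ext.cumsum()[1:-1]

def split_cumsum(a):
    # Hypothetical wrapper around the split-based one-liner from the timings:
    # split at every value change, cumsum the chunks that consist of 1s.
    chunks = np.split(a, 1 + np.where(np.diff(a))[0])
    return np.concatenate([np.cumsum(c) if c[0] == 1 else c for c in chunks])

base = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0])
n_loops = 10  # arbitrary; kept small so the script finishes quickly

for reps in (10, 1000):
    a = np.tile(base, reps)
    # Make sure both approaches agree before timing them.
    assert np.array_equal(split_cumsum(a), island_cumsum_vectorized(a))
    t_split = min(timeit.repeat(lambda: split_cumsum(a), number=n_loops, repeat=3)) / n_loops
    t_vec = min(timeit.repeat(lambda: island_cumsum_vectorized(a), number=n_loops, repeat=3)) / n_loops
    print(f"tile x{reps}: split-based {t_split * 1e6:.1f} µs, vectorized {t_vec * 1e6:.1f} µs")
```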