Draw a random element in numpy

2024/10/8 10:57:15

I have an array of element probabilities, let's say [0.1, 0.2, 0.5, 0.2]. The array sums up to 1.0.

Using plain Python or numpy, I want to draw elements in proportion to their probabilities: the first element about 10% of the time, the second 20%, the third 50%, and so on. The draw should return the index of the element drawn.

I came up with this:

def draw(probs):
    cumsum = numpy.cumsum(probs / sum(probs))  # sum up to 1.0, just in case
    return len(numpy.where(numpy.random.rand() >= cumsum)[0])

It works, but it seems too convoluted; there must be a better way. Thanks.

Answer
import numpy as np

def random_pick(choices, probs):
    '''
    >>> a = ['Hit', 'Out']
    >>> b = [.3, .7]
    >>> random_pick(a,b)
    '''
    cutoffs = np.cumsum(probs)
    idx = cutoffs.searchsorted(np.random.uniform(0, cutoffs[-1]))
    return choices[idx]

How it works:

In [22]: import numpy as np
In [23]: probs = [0.1, 0.2, 0.5, 0.2]

Compute the cumulative sum:

In [24]: cutoffs = np.cumsum(probs)
In [25]: cutoffs
Out[25]: array([ 0.1,  0.3,  0.8,  1. ])

Compute a uniformly distributed random number in the half-open interval [0, cutoffs[-1]):

In [26]: np.random.uniform(0, cutoffs[-1])
Out[26]: 0.9723114393023948

Use searchsorted to find the index where the random number would be inserted into cutoffs:

In [27]: cutoffs.searchsorted(0.9723114393023948)
Out[27]: 3
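
To see how searchsorted maps a number to an index, it can help to try a few fixed values instead of random ones. This is a minimal sketch; the sample values are hand-picked here purely for illustration and are not from the original answer:

```python
import numpy as np

probs = [0.1, 0.2, 0.5, 0.2]
cutoffs = np.cumsum(probs)  # approximately [0.1, 0.3, 0.8, 1.0]

# Hand-picked values, one inside each interval between cutoffs (illustrative only).
samples = np.array([0.05, 0.25, 0.5, 0.97])

# Each value maps to the index of the first cutoff that is >= the value.
indices = cutoffs.searchsorted(samples)
print(indices)  # [0 1 2 3]
```

The wider the interval between two neighbouring cutoffs, the more likely a uniform random number falls into it, which is why index 2 (width 0.5) comes up most often.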

Return choices[idx], where idx is that index.
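
As a sanity check, the same idea can be vectorized to draw many indices at once and compare the observed frequencies with probs. The sketch below is not part of the original answer: the helper name draw_indices, the seed, and the sample size are assumptions made for illustration. It also shows NumPy's Generator.choice with the p argument, which is a different, built-in way to draw weighted indices directly.

```python
import numpy as np

def draw_indices(probs, size, rng):
    """Draw `size` indices with probability proportional to `probs` (hypothetical helper)."""
    cutoffs = np.cumsum(probs)
    # Uniform numbers in [0, cutoffs[-1]); searchsorted maps each one to an index.
    u = rng.uniform(0, cutoffs[-1], size=size)
    return cutoffs.searchsorted(u)

rng = np.random.default_rng(0)  # fixed seed, chosen arbitrarily for repeatability
probs = [0.1, 0.2, 0.5, 0.2]
n = 100_000

# Cumulative-sum approach from the answer, applied to n draws at once.
idx = draw_indices(probs, n, rng)
freq = np.bincount(idx, minlength=len(probs)) / n
print(freq)  # close to probs, up to sampling noise

# Built-in alternative: pick among indices 0..len(probs)-1 with weights p.
idx2 = rng.choice(len(probs), size=n, p=probs)
freq2 = np.bincount(idx2, minlength=len(probs)) / n
print(freq2)  # also close to probs
```

Note that Generator.choice expects p to sum to 1, whereas the cumulative-sum version scales by cutoffs[-1] and so also tolerates unnormalized weights.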

