Getting the maximum accuracy for a binary probabilistic classifier in scikit-learn

2024/10/12 20:21:10

Is there any built-in function to get the maximum accuracy for a binary probabilistic classifier in scikit-learn?

E.g. to get the maximum F1-score I do:

import sklearn.metrics

# AUPRC and the maximum F1-score over all thresholds
precision, recall, thresholds = sklearn.metrics.precision_recall_curve(y_true, y_score)
auprc = sklearn.metrics.auc(recall, precision)
max_f1 = 0
for r, p, t in zip(recall, precision, thresholds):
    if p + r == 0:
        continue
    f1 = (2 * p * r) / (p + r)
    if f1 > max_f1:
        max_f1 = f1
        max_f1_threshold = t

I could compute the maximum accuracy in a similar fashion:

import numpy as np

accuracies = []
thresholds = np.arange(0, 1, 0.1)
for threshold in thresholds:
    y_pred = np.greater(y_score, threshold).astype(int)
    accuracy = sklearn.metrics.accuracy_score(y_true, y_pred)
    accuracies.append(accuracy)
accuracies = np.array(accuracies)
max_accuracy = accuracies.max()
max_accuracy_threshold = thresholds[accuracies.argmax()]

but I wonder whether there is any built-in function.

Answer
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve

# roc_curve returns the candidate thresholds; its ROC points treat a sample
# as positive when score >= threshold, so the same comparison is used here.
fpr, tpr, thresholds = roc_curve(y_true, probs)
accuracy_scores = []
for thresh in thresholds:
    accuracy_scores.append(accuracy_score(y_true, [m >= thresh for m in probs]))
accuracies = np.array(accuracy_scores)
max_accuracy = accuracies.max()
max_accuracy_threshold = thresholds[accuracies.argmax()]
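
The loop over thresholds can probably be avoided, since the accuracy at each threshold follows from the ROC values themselves: correct predictions = true positives + true negatives = tpr * n_pos + (1 - fpr) * n_neg. Below is a minimal sketch of that idea, not a scikit-learn built-in. The helper name max_accuracy_from_roc and the toy data are made up for illustration, and it assumes the labels are 0/1 and that a sample is predicted positive when its score is >= the threshold.

```python
import numpy as np
from sklearn.metrics import roc_curve


def max_accuracy_from_roc(y_true, y_score):
    """Return (best accuracy, threshold), predicting 1 when score >= threshold.

    Hypothetical helper; assumes y_true contains only 0 and 1.
    """
    y_true = np.asarray(y_true)
    # drop_intermediate=False keeps every distinct score as a candidate threshold.
    fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)
    n_pos = np.sum(y_true == 1)
    n_neg = y_true.size - n_pos
    # Correct predictions = true positives + true negatives.
    accuracies = (tpr * n_pos + (1.0 - fpr) * n_neg) / y_true.size
    best = int(np.argmax(accuracies))
    return float(accuracies[best]), float(thresholds[best])


# Toy data, invented for this example.
y_true = np.array([0, 0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])
best_accuracy, best_threshold = max_accuracy_from_roc(y_true, y_score)
print(best_accuracy, best_threshold)
```

If several thresholds tie for the best accuracy, np.argmax returns the first one in the order roc_curve lists them, which is the highest threshold.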