Plotting Precision-Recall curve when using cross-validation in scikit-learn

2024/10/14 15:28:25

I'm using cross-validation to evaluate the performance of a classifier with scikit-learn and I want to plot the Precision-Recall curve. I found an example on scikit-learn's website that plots the PR curve, but it doesn't use cross-validation for the evaluation.

How can I plot the Precision-Recall curve in scikit-learn when using cross-validation?

I did the following, but I'm not sure if it's the correct way to do it (pseudocode):

    for each k-fold:
        precision, recall, _ = precision_recall_curve(y_test, probs)
        mean_precision += precision
        mean_recall += recall
    mean_precision /= num_folds
    mean_recall /= num_folds
    plt.plot(recall, precision)

What do you think?

Edit:

It doesn't work because the precision and recall arrays have different sizes after each fold.

Anyone?

Answer

Instead of recording the precision and recall values after each fold, store the predictions on the test samples after each fold. Next, collect all the test (i.e. held-out) predictions and compute precision and recall.

    ## let test_samples[k] = test samples for the kth fold (list of list)
    ## let train_samples[k] = training samples for the kth fold (list of list)
    for k in range(0, num_folds):
        model = train(parameters, train_samples[k])
        predictions_fold[k] = predict(model, test_samples[k])

    # collect predictions
    predictions_combined = [p for preds in predictions_fold for p in preds]

    ## let predictions = rearranged predictions s.t. they are in the original order
    ## use predictions and labels to compute lists of TP, FP, FN
    ## use TP, FP, FN to compute precisions and recalls for one run of k-fold cross-validation

Under a single, complete run of k-fold cross-validation, the predictor makes one and only one prediction for each sample. Given n samples, you should have n test predictions.

(Note: These predictions are different from training predictions, because the predictor makes the prediction for each sample without having previously seen it.)
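In scikit-learn, the pooling described above could look roughly like the sketch below. The synthetic dataset, the LogisticRegression classifier and the variable names are assumptions made for illustration; they are not part of the original question. cross_val_predict returns one held-out prediction per sample, already in the original sample order, so a single precision_recall_curve call covers the whole cross-validation run.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic, imbalanced binary data (an assumption for illustration only).
X, y = make_classification(n_samples=500, weights=[0.8, 0.2], random_state=0)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
clf = LogisticRegression(max_iter=1000)

# One held-out probability per sample, returned in the original sample order.
probs = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")[:, 1]

# A single PR curve computed from all n pooled test predictions.
precision, recall, thresholds = precision_recall_curve(y, probs)
```

The resulting arrays can be passed to plt.plot(recall, precision). Because every sample is scored exactly once, there is no need to average arrays of different lengths across folds.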

Unless you are using leave-one-out cross-validation, k-fold cross-validation generally requires a random partitioning of the data. Ideally, you would do repeated (and stratified) k-fold cross-validation. Combining precision-recall curves from different rounds, however, is not straightforward, since, unlike with ROC curves, you cannot use simple linear interpolation between precision-recall points (see Davis and Goadrich 2006).
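One way to sidestep combining curves is to compute a single summary score per repetition and compare those scores. The sketch below assumes synthetic data, a LogisticRegression classifier and three repetitions, all chosen for illustration. It uses scikit-learn's average_precision_score, which is a step-wise summary of the PR curve and not the Davis-Goadrich interpolation.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic, imbalanced binary data (an assumption for illustration only).
X, y = make_classification(n_samples=300, weights=[0.8, 0.2], random_state=0)
clf = LogisticRegression(max_iter=1000)

scores = []
for repeat in range(3):
    # A different random stratified partition for every repetition.
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=repeat)
    probs = cross_val_predict(clf, X, y, cv=cv, method="predict_proba")[:, 1]
    # One score per complete run of 10-fold cross-validation.
    scores.append(average_precision_score(y, probs))

mean_score = float(np.mean(scores))
spread = float(np.std(scores))
```

The mean and spread over repetitions give a rough sense of how much the estimate depends on the random partitioning.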

I personally calculated AUC-PR using the Davis-Goadrich method for interpolation in PR space (followed by numerical integration) and compared the classifiers using the AUC-PR estimates from repeated stratified 10-fold cross-validation.
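As a rough illustration of interpolation in PR space, the sketch below follows my reading of Davis and Goadrich (2006): between two points A and B, true positives are added one at a time and false positives are added at the local rate (FP_B - FP_A) / (TP_B - TP_A). The function name interpolate_pr and the counts in the example are hypothetical and only serve to show the idea; this is not the author's actual code.

```python
def interpolate_pr(tp_a, fp_a, tp_b, fp_b, total_pos):
    """Return (recall, precision) points strictly between A and B.

    A and B are given as true/false positive counts; total_pos is the
    number of positive samples in the dataset.
    """
    if tp_b <= tp_a:
        raise ValueError("expected tp_b > tp_a")
    # False positives gained per additional true positive between A and B.
    skew = (fp_b - fp_a) / (tp_b - tp_a)
    points = []
    for x in range(1, tp_b - tp_a):
        tp = tp_a + x
        fp = fp_a + skew * x
        points.append((tp / total_pos, tp / (tp + fp)))
    return points


# Illustrative counts: A = (TP 5, FP 5), B = (TP 10, FP 30), 20 positives.
points = interpolate_pr(5, 5, 10, 30, 20)
```

With these counts the first interpolated point has recall 0.30 and precision 0.375, which lies below the straight line between A (recall 0.25, precision 0.50) and B (recall 0.50, precision 0.25). That gap is why linear interpolation overestimates the area under a PR curve.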

For a nice plot, I showed a representative PR curve from one of the cross-validation rounds.

There are, of course, many other ways of assessing classifier performance, depending on the nature of your dataset.

For instance, if the proportion of (binary) labels in your dataset is not skewed (i.e. it is roughly 50-50), you could use the simpler ROC analysis with cross-validation:

Collect predictions from each fold and construct ROC curves (as before), collect all the TPR-FPR points (i.e. take the union of all TPR-FPR tuples), then plot the combined set of points with possible smoothing. Optionally, compute AUC-ROC using simple linear interpolation and the composite trapezoid method for numerical integration.
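The ROC variant could be sketched as follows. The synthetic balanced dataset, the LogisticRegression classifier and the variable names are assumptions for illustration. sklearn.metrics.auc applies the trapezoidal rule to the given points, which matches the linear interpolation mentioned above.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold

# Synthetic, roughly balanced binary data (an assumption for illustration only).
X, y = make_classification(n_samples=400, random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

points = []     # union of (FPR, TPR) tuples across folds
fold_aucs = []  # one trapezoidal AUC-ROC estimate per fold
for train_idx, test_idx in cv.split(X, y):
    clf = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    scores = clf.predict_proba(X[test_idx])[:, 1]
    fpr, tpr, _ = roc_curve(y[test_idx], scores)
    points.extend(zip(fpr, tpr))
    fold_aucs.append(auc(fpr, tpr))

# Sort by FPR (then TPR) so the combined points can be plotted left to right.
points = np.array(sorted(points))
mean_auc = float(np.mean(fold_aucs))
```

points[:, 0] and points[:, 1] can then be plotted as a scatter or smoothed curve, with mean_auc reported alongside.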

