Computing AUC and ROC curve from multi-class data in scikit-learn (sklearn)?

2024/10/2 20:24:26

I am trying to use the scikit-learn module to compute AUC and plot ROC curves for the output of three different classifiers to compare their performance. I am very new to this topic, and I am struggling to understand how the data I have should be passed to the roc_curve and auc functions.

For each item within the testing set, I have the true value and the output of each of the three classifiers. The classes are ['N', 'L', 'W', 'T']. In addition, I have a confidence score for each value output from the classifiers. How do I pass this information to the roc_curve function?

Do I need to label_binarize my input data? How do I convert a list of [class, confidence] pairs output by the classifiers into the y_score expected by roc_curve?

Thank you for any help! Good resources about ROC curves would also be helpful.

Answer

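You need to binarize the true labels with the label_binarize function (one 0/1 column per class). You can then compute and plot one ROC curve per class, treating each class as a one-vs-rest problem.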
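As a quick illustration of what label_binarize produces, here is a minimal sketch using the class names from the question. The true labels in y_true are made up for the example and are not from the asker's data.

```python
from sklearn.preprocessing import label_binarize

classes = ['N', 'L', 'W', 'T']        # class names taken from the question
y_true = ['N', 'W', 'T', 'L', 'N']    # made-up true labels, for illustration only

# One 0/1 column per class, in the same order as `classes`.
y_true_bin = label_binarize(y_true, classes=classes)
print(y_true_bin)
```

Column j of the result is the binary "is this sample class j?" target that roc_curve expects as its first argument.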

Example using Iris data:

import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from sklearn.multiclass import OneVsRestClassifier
from itertools import cycle
plt.style.use('ggplot')

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Binarize the output: one 0/1 column per class
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=0)

# One-vs-rest linear SVM; decision_function returns one score column per class
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True, random_state=0))
y_score = classifier.fit(X_train, y_train).decision_function(X_test)

# Compute the ROC curve and AUC for each class separately
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot one curve per class, plus the chance diagonal
colors = cycle(['blue', 'red', 'green'])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=1.5,
             label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=1.5)
plt.xlim([-0.05, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class data')
plt.legend(loc="lower right")
plt.show()
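The Iris example gets a full score matrix directly from the classifier. The question instead describes a list of [class, confidence] pairs, so only the score of the predicted class is known. The sketch below shows one possible way to build a y_score matrix from such pairs. It assumes the confidence is probability-like (between 0 and 1) and spreads the remaining mass evenly over the other classes, which is an approximation chosen here for illustration, not something scikit-learn prescribes. The helper pairs_to_scores and all data values are hypothetical. If the classifiers can output a score for every class, using those scores directly is preferable.

```python
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

classes = ['N', 'L', 'W', 'T']  # class names taken from the question

# Made-up data: true labels and (predicted class, confidence) pairs.
y_true = ['N', 'L', 'W', 'T', 'N', 'W', 'L', 'T']
predictions = [('N', 0.9), ('L', 0.6), ('N', 0.5), ('T', 0.8),
               ('N', 0.7), ('W', 0.9), ('W', 0.4), ('T', 0.6)]


def pairs_to_scores(pairs, classes):
    """Hypothetical helper: turn (class, confidence) pairs into a score matrix.

    Assumption: confidence lies in [0, 1]. The predicted class gets the
    confidence; the remainder is split evenly among the other classes.
    """
    n_classes = len(classes)
    column = {name: j for j, name in enumerate(classes)}
    scores = np.empty((len(pairs), n_classes))
    for i, (label, conf) in enumerate(pairs):
        scores[i, :] = (1.0 - conf) / (n_classes - 1)
        scores[i, column[label]] = conf
    return scores


y_true_bin = label_binarize(y_true, classes=classes)
y_score = pairs_to_scores(predictions, classes)

# One-vs-rest ROC curve and AUC per class, as in the Iris example.
roc_auc = {}
for j, name in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_true_bin[:, j], y_score[:, j])
    roc_auc[name] = auc(fpr, tpr)
    print(name, roc_auc[name])
```

Running this once per classifier gives per-class AUC values that can be compared across the three classifiers. Because the non-predicted classes all share the same filled-in score, the resulting curves are coarser than ones built from true per-class scores.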
