How to get ROC curve for decision tree?

2024/11/17 1:30:24

I am trying to compute the ROC curve and the AUROC (area under the ROC curve) for a decision tree. My code was something like:

clf.fit(x,y)
y_score = clf.fit(x,y).decision_function(test[col])
pred = clf.predict_proba(test[col])
print(sklearn.metrics.roc_auc_score(actual,y_score))
fpr,tpr,thre = sklearn.metrics.roc_curve(actual,y_score)

Output:

AttributeError: 'DecisionTreeClassifier' object has no attribute 'decision_function'

Basically, the error occurs when computing y_score. Could someone explain what y_score is and how to solve this problem?
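For context: roc_curve and roc_auc_score expect y_score to be a continuous score for the positive class (higher means "more likely positive"), not a hard label, because the ROC curve is traced by sweeping a threshold over those scores. For a decision tree the usual choice is the positive-class column of predict_proba. A minimal binary sketch, assuming synthetic data from make_classification in place of the asker's x, y, test[col] and actual; max_depth=4 is an arbitrary illustrative value:

```python
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic binary data standing in for the asker's x, y and test set (assumption)
X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# Limiting depth leaves some leaves impure, so predict_proba returns
# values between 0 and 1 rather than only hard 0/1 outputs
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
clf.fit(X_train, y_train)  # fitting once is enough

# Column 1 of predict_proba is the probability of clf.classes_[1] (the positive class)
y_score = clf.predict_proba(X_test)[:, 1]

auc_value = roc_auc_score(y_test, y_score)
fpr, tpr, thresholds = roc_curve(y_test, y_score)
print(auc_value)
```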

Answer

First of all, DecisionTreeClassifier has no decision_function method.

Judging from the structure of your code, you saw this example.

In that example the classifier is not a decision tree but a OneVsRestClassifier, which supports the decision_function method.

You can see the available attributes and methods of DecisionTreeClassifier here.
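Because some estimators expose decision_function and others (such as trees) expose only predict_proba, one way to avoid the AttributeError is to check which method exists before computing scores. A sketch, assuming a binary problem; positive_class_scores is a hypothetical helper written for illustration, and the models are scored on their training data only to keep it short:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.tree import DecisionTreeClassifier


def positive_class_scores(clf, X):
    """Hypothetical helper: continuous positive-class scores from a fitted
    binary classifier, usable as y_score in roc_curve / roc_auc_score."""
    if hasattr(clf, "decision_function"):
        # Linear models, SVMs, etc.: signed distance to the decision boundary
        return clf.decision_function(X)
    # Trees expose only class probabilities; column 1 is the positive class
    return clf.predict_proba(X)[:, 1]


X, y = make_classification(n_samples=400, random_state=1)

tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
logreg = LogisticRegression().fit(X, y)

print(hasattr(tree, "decision_function"))    # False
print(hasattr(logreg, "decision_function"))  # True
print(roc_auc_score(y, positive_class_scores(tree, X)))
print(roc_auc_score(y, positive_class_scores(logreg, X)))
```

Either kind of score works for ROC analysis, since only the ranking of the scores matters, not their scale.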

One possible approach is to binarize the classes and then compute the AUC for each class:

Example:

from sklearn import datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Binarize the labels: one 0/1 column per class (one-vs-rest)
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5, random_state=0)

# DecisionTreeClassifier supports multi-output targets, so predict() returns
# an (n_samples, n_classes) array of 0/1 labels here.
# Note: these are hard labels, not scores, so each ROC curve has only one threshold point.
classifier = DecisionTreeClassifier()
y_score = classifier.fit(X_train, y_train).predict(X_test)

# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# ROC AUC for a specific class, here class 2
print(roc_auc[2])

Result

0.94852941176470573
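Since the example above scores the test set with predict(), each per-class ROC "curve" is built from hard 0/1 labels and has a single interior point. A sketch of the same one-vs-rest procedure using predict_proba instead, assuming it is acceptable to train on the original three-class labels and binarize only y_test; max_depth=2 is an arbitrary illustrative choice so that leaves stay impure and probabilities are not all 0 or 1:

```python
from sklearn import datasets
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=.5, random_state=0)

# Train on the original 3-class labels (single-output tree)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train, y_train)

# One probability column per class, in the order of clf.classes_
y_score = clf.predict_proba(X_test)

# Binarize only the test labels, to build one-vs-rest ROC curves
y_test_bin = label_binarize(y_test, classes=clf.classes_)

roc_auc = dict()
for i in range(len(clf.classes_)):
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr, tpr)

# Micro-average: pool every (sample, class) pair into one binary problem
fpr_micro, tpr_micro, _ = roc_curve(y_test_bin.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr_micro, tpr_micro)

# Macro-averaged one-vs-rest AUC computed directly by scikit-learn
macro_ovr = roc_auc_score(y_test, y_score, multi_class="ovr")
print(roc_auc, macro_ovr)
```

The macro_ovr value is the unweighted mean of the three per-class AUCs, so it can serve as a quick cross-check of the loop.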
https://en.xdnf.cn/q/71277.html
