Sklearn Decision Rules for Specific Class in Decision tree

2024/9/20 22:34:43

I am creating a decision tree. My data has the following form:

X1  | X2 | X3 | ... | X50 | Y
------------------------------
1   | 5  | 7  | ... | 0   | 1
1.5 | 34 | 81 | ... | 0   | 1
4   | 21 | 21 | ... | 1   | 0
65  | 34 | 23 | ... | 1   | 1

I am trying to run the following code:

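from sklearn.tree import DecisionTreeClassifier

X_train = data.iloc[:, 0:50]  # the 50 feature columns X1..X50
y_train = data.iloc[:, 50]    # the target column Y
clf = DecisionTreeClassifier(criterion="entropy", random_state=100, max_depth=8, min_samples_leaf=15)
clf.fit(X_train, y_train)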

What I want is the decision rules that predict a specific class (in this case "0"). For example:

when X1 > 4 && X5> 78 && X50 =100 Then Y = 0 ( Probability =84%)
When X4 = 56 && X39 < 100 Then Y = 0 ( Probability = 93%)
...

So basically I want all the leaf nodes that predict class Y = "0", the decision rules attached to them, and the probability of Y = 0 at each of them. I also want to print those decision rules in the format shown above.

I am not interested in the decision rules that predict Y = 1.

Thanks, any help would be appreciated.

Answer

Based on http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html

This assumes that the probabilities equal the proportion of each class in a node. For example, if a leaf holds 68 instances of class 0 and 15 of class 1 (i.e. the value in tree_ is [68, 15]), the probabilities are [0.81927711, 0.18072289].
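As a quick check of that arithmetic, here is a minimal sketch that normalizes the class counts of a single leaf. The array is typed in by hand to mimic the shape of one entry of clf.tree_.value (one row per output, one column per class); it is not read from a fitted tree.

```python
import numpy as np

# Hand-written class counts for one leaf: 68 samples of class 0, 15 of class 1.
# Shape (n_outputs, n_classes), like a single entry of clf.tree_.value.
value = np.array([[68.0, 15.0]])

# Divide each row by its total so the entries sum to 1.
probs = value / value.sum(axis=1, keepdims=True)
print(probs)
```

The same normalization works whether the stored values are raw counts or already fractions, since dividing by the row sum leaves proportions unchanged.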

Generate a simple tree with 4 features and 2 classes:

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split  # replaces the removed sklearn.cross_validation
from sklearn.tree import _tree

X, y = make_classification(n_informative=3, n_features=4, n_samples=200, n_redundant=1, random_state=42, n_classes=2)
feature_names = ['X0','X1','X2','X3']
Xtrain, Xtest, ytrain, ytest = train_test_split(X,y, random_state=42)
clf = DecisionTreeClassifier(max_depth=2)
clf.fit(Xtrain, ytrain)

Visualize it:

from io import StringIO  # sklearn.externals.six is gone from recent scikit-learn releases
from sklearn import tree
import pydot
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data)
graph = pydot.graph_from_dot_data(dot_data.getvalue())[0]  # graph_from_dot_data returns a list
graph.write_jpeg('1.jpeg')
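If pydot or Graphviz is not available, a plain-text view of the same structure may be enough. The sketch below uses sklearn.tree.export_text (available in scikit-learn 0.21 and later); it refits a tree on the full generated data with random_state=0, which is an assumption made only to keep the example self-contained, so its splits need not match the picture above.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Same synthetic data as in the answer; the classifier settings are illustrative.
X, y = make_classification(n_informative=3, n_features=4, n_samples=200,
                           n_redundant=1, random_state=42, n_classes=2)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# Render the tree as indented text, one line per split or leaf.
text = export_text(clf, feature_names=['X0', 'X1', 'X2', 'X3'])
print(text)
```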


Create a function for printing a condition for one instance:

node_indicator = clf.decision_path(Xtrain)
n_nodes = clf.tree_.node_count
feature = clf.tree_.feature
threshold = clf.tree_.threshold
leave_id = clf.apply(Xtrain)


def value2prob(value):
    return value / value.sum(axis=1).reshape(-1, 1)


def print_condition(sample_id):
    print("WHEN", end=' ')
    # ids of the nodes this sample passes through, from root to leaf
    node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                        node_indicator.indptr[sample_id + 1]]
    for n, node_id in enumerate(node_index):
        if leave_id[sample_id] == node_id:
            # leaf reached: print the predicted class and its probability
            values = clf.tree_.value[node_id]
            probs = value2prob(values)
            print('THEN Y={} (probability={}) (values={})'.format(
                probs.argmax(), probs.max(), values))
            continue
        if n > 0:
            print('&& ', end='')
        if (Xtrain[sample_id, feature[node_id]] <= threshold[node_id]):
            threshold_sign = "<="
        else:
            threshold_sign = ">"
        if feature[node_id] != _tree.TREE_UNDEFINED:
            print(
                "%s %s %s" % (
                    feature_names[feature[node_id]],
                    #Xtrain[sample_id,feature[node_id]] # actual value
                    threshold_sign,
                    threshold[node_id]),
                end=' ')

Call it on the first row:

>>> print_condition(0)
WHEN X1 > -0.2662498950958252 && X0 > -1.1966443061828613 THEN Y=1 (probability=0.9672131147540983) (values=[[ 2. 59.]])

Call it on all rows where predicted value is zero:

for i in (clf.predict(Xtrain) == 0).nonzero()[0]:
    print_condition(i)
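Calling print_condition per row prints the same rule once for every training row that lands in a given leaf. If the goal is one rule per leaf, a possible alternative is to walk the tree structure directly. The sketch below is not part of the original answer: leaf_rules is a hypothetical helper name, it relies on the children_left, children_right, feature, threshold and value arrays of clf.tree_, and it refits the classifier on the full generated data with random_state=0 just to be self-contained.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier


def leaf_rules(clf, feature_names, target_class=0):
    """Hypothetical helper: (rule, probability) for each leaf predicting target_class."""
    tree = clf.tree_
    class_idx = list(clf.classes_).index(target_class)
    rules = []

    def walk(node_id, conditions):
        left = tree.children_left[node_id]
        right = tree.children_right[node_id]
        if left == right:  # a leaf has both children set to -1
            value = tree.value[node_id][0]
            probs = value / value.sum()  # class proportions in this leaf
            if probs.argmax() == class_idx:
                rules.append((" && ".join(conditions), probs[class_idx]))
            return
        name = feature_names[tree.feature[node_id]]
        thr = tree.threshold[node_id]
        walk(left, conditions + ["%s <= %s" % (name, thr)])   # left branch: feature <= threshold
        walk(right, conditions + ["%s > %s" % (name, thr)])   # right branch: feature > threshold

    walk(0, [])  # node 0 is the root
    return rules


X, y = make_classification(n_informative=3, n_features=4, n_samples=200,
                           n_redundant=1, random_state=42, n_classes=2)
feature_names = ['X0', 'X1', 'X2', 'X3']
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

rules = leaf_rules(clf, feature_names, target_class=0)
for rule, prob in rules:
    print("WHEN %s THEN Y=0 (probability=%.3f)" % (rule, prob))
```

Each rule appears once, and the probability is the proportion of class 0 among the training samples in that leaf, under the same assumption stated at the top of the answer.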