Classification tree in sklearn giving inconsistent answers

2024/9/28 11:25:14

I am using a classification tree from sklearn, and when I train the model twice on the same data and predict on the same test data, I get different results. I tried reproducing this on the smaller iris data set, and there it worked as expected. Here is some code:

from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()

# First fit and prediction
clf = tree.DecisionTreeClassifier()
clf.fit(iris.data, iris.target)
r1 = clf.predict_proba(iris.data)

# Refit on the same data and predict again
clf.fit(iris.data, iris.target)
r2 = clf.predict_proba(iris.data)

r1 and r2 are the same for this small example, but when I run on my own much larger data set I get differing results. Is there a reason why this would occur?

EDIT: After looking into some documentation I see that DecisionTreeClassifier has a random_state parameter which controls the starting point. By setting this value to a constant I get rid of the problem I was previously having. However, now I'm concerned that my model is not as optimal as it could be. What is the recommended method for doing this? Try a few values at random? Or are all results expected to be about the same?

Answer

The DecisionTreeClassifier works by repeatedly splitting the training data, based on the value of some feature. The Scikit-learn implementation lets you choose between a few splitting algorithms by providing a value to the splitter keyword argument.

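  • "best" visits the features in a random order and, for each one, finds the 'best' possible split according to some criterion (which you can also choose; see the method's signature and the criterion argument). It then keeps the best split among the features it examined. With the default max_features=None every feature is examined, so the random order only matters when several splits improve the criterion equally; with a smaller max_features, only a random subset of features is examined at each node.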

  • "random" chooses the feature to consider at random, as above. However, it also then tests randomly-generated thresholds on that feature (random, subject to the constraint that it's between its minimum and maximum values). This may help avoid 'quantization' errors on the tree where the threshold is strongly influenced by the exact values in the training data.
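To see what the two splitter values do in practice, here is a minimal sketch that fits a few trees on iris and prints the root split of each. The helper name root_split and the choice of seeds are only illustrative assumptions; the tree_.feature and tree_.threshold arrays are what sklearn exposes on a fitted tree.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target


def root_split(splitter, seed):
    """Fit a tree and return (feature index, threshold) of its root split.

    Hypothetical helper for illustration only.
    """
    clf = DecisionTreeClassifier(splitter=splitter, random_state=seed)
    clf.fit(X, y)
    # Node 0 is the root of the fitted tree.
    return int(clf.tree_.feature[0]), float(clf.tree_.threshold[0])


for splitter in ("best", "random"):
    for seed in range(3):
        feature, threshold = root_split(splitter, seed)
        print(
            f"splitter={splitter!r} seed={seed}: "
            f"root splits on {iris.feature_names[feature]} at {threshold:.3f}"
        )
```

Whether the root split changes from seed to seed depends on the data, so the printed values are something to inspect rather than rely on. Repeating a call with the same splitter and seed should give the same split.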

Both of these randomization methods can improve the trees' performance. There are some relevant experimental results in Lui, Ting, and Fan's (2005) KDD paper.

If you absolutely must have an identical tree every time, then I'd re-use the same random_state. Otherwise, I'd expect the trees to end up more or less equivalent every time and, in the absence of a ton of held-out data, I'm not sure how you'd decide which random tree is best.
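As a rough sketch of both suggestions, the snippet below first checks that a fixed random_state gives identical predictions across two fits, then compares a handful of seeds by 5-fold cross-validated accuracy. Iris stands in for the real data here, and the seed range and cv=5 are arbitrary assumptions, not recommendations.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Same seed, same data: two separately fitted trees should agree exactly.
tree_a = DecisionTreeClassifier(random_state=0).fit(X, y)
tree_b = DecisionTreeClassifier(random_state=0).fit(X, y)
identical = np.array_equal(tree_a.predict_proba(X), tree_b.predict_proba(X))
print("identical predictions with a fixed seed:", identical)

# Different seeds: compare mean cross-validated accuracy instead of
# training-set output, since every fully grown tree fits its own
# training data almost perfectly.
seed_scores = {}
for seed in range(5):
    clf = DecisionTreeClassifier(random_state=seed)
    seed_scores[seed] = cross_val_score(clf, X, y, cv=5).mean()

for seed, score in seed_scores.items():
    print(f"seed={seed}: mean CV accuracy = {score:.3f}")
print("spread across seeds:", max(seed_scores.values()) - min(seed_scores.values()))
```

If the spread across seeds is small compared with the fold-to-fold variation, the seed is not worth tuning; picking the best-scoring seed on the same folds mostly fits noise.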

See also: Source code for the splitter
