Convert decision tree directly to png [duplicate]

2024/9/16 22:58:47

I am trying to generate a decision tree and visualize it using dot. The resulting dot file should then be converted to PNG.

While I can do the last conversion step from the DOS command line using something like

export_graphviz(dectree, out_file="graph.dot")

followed by a DOS command

dot -Tps graph.dot -o outfile.ps

doing all this directly in Python does not work and generates an error:

AttributeError: 'list' object has no attribute 'write_png'

This is the program code I have tried:

from sklearn import tree
import pydot
import StringIO
# Define training and target set for the classifier
train = [[1,2,3],[2,5,1],[2,1,7]]
target = [10,20,30]
# Initialize Classifier. Random values are initialized with always the same random seed of value 0
# (allows reproducible results)
dectree = tree.DecisionTreeClassifier(random_state=0)
dectree.fit(train, target)
# Test classifier with other, unknown feature vector
test = [2,2,3]
predicted = dectree.predict(test)
dotfile = StringIO.StringIO()
tree.export_graphviz(dectree, out_file=dotfile)
graph=pydot.graph_from_dot_data(dotfile.getvalue())
graph.write_png("dtree.png")

What am I missing?

Answer

I ended up using pydotplus:

from sklearn import tree
import pydotplus
import StringIO
# Define training and target set for the classifier
train = [[1,2,3],[2,5,1],[2,1,7]]
target = [10,20,30]
# Initialize Classifier. Random values are initialized with always the same random seed of value 0
# (allows reproducible results)
dectree = tree.DecisionTreeClassifier(random_state=0)
dectree.fit(train, target)
# Test classifier with other, unknown feature vector
test = [2,2,3]
predicted = dectree.predict(test)
dotfile = StringIO.StringIO()
tree.export_graphviz(dectree, out_file=dotfile)
graph=pydotplus.graph_from_dot_data(dotfile.getvalue())
graph.write_png("dtree.png")

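EDIT: Thanks for the comment. In pydot, graph_from_dot_data returns a list of graphs (hence the 'list' object error), so to get this running in pydot I'd have to unpack the result: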

(graph,)=pydot.graph_from_dot_data(dotfile.getvalue())
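
For anyone on Python 3, here is a minimal sketch of the same idea, not a definitive implementation. The helper names first_graph and tree_to_png are made up for illustration. It assumes io.StringIO in place of the Python 2 StringIO module, that pydot and scikit-learn are installed, and that the Graphviz dot binary is on the PATH. The small helper accepts either a list of graphs (what pydot returns, per the error above) or a single graph object (what pydotplus returns), so the rest of the code does not care which library produced it.

```python
import io


def first_graph(result):
    """Return one graph from whatever graph_from_dot_data returned.

    Hypothetical helper: accepts a list/tuple of graphs (pydot)
    or a single graph object (pydotplus).
    """
    if isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("no graph found in the DOT data")
        return result[0]
    return result


def tree_to_png(dectree, png_path):
    """Write a fitted scikit-learn decision tree to a PNG file."""
    # Imported here so first_graph stays usable without these packages.
    import pydot
    from sklearn import tree

    dotfile = io.StringIO()  # Python 3 replacement for StringIO.StringIO()
    tree.export_graphviz(dectree, out_file=dotfile)
    graph = first_graph(pydot.graph_from_dot_data(dotfile.getvalue()))
    graph.write_png(png_path)  # assumes the Graphviz "dot" binary is installed
```

With the classifier from above, this would be called as tree_to_png(dectree, "dtree.png").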