1、导入决策树模型,划分数据集
from sklearn import datasets
from sklearn.tree import decisiontreeclassifier
iris=datasets.load_iris()
iris_x=iris.data
iris_y=iris.target
indices = np.random.permutation(len(iris_x))
iris_x_train = iris_x[indices[:-10]]
iris_y_train = iris_y[indices[:-10]]
iris_x_test = iris_x[indices[-10:]]
iris_y_test = iris_y[indices[-10:]]
2、训练模型
clf = decisiontreeclassifier(max_depth=4)
clf.fit(iris_x_train, iris_y_train)
3、 可视化显示决策树模型
from ipython.display import image
from sklearn import tree
import pydotplus
dot_data = tree.export_graphviz(clf, out_file=none,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=true, rounded=true,
special_characters=true)
graph = pydotplus.graph_from_dot_data(dot_data)
#是 graphviz 库中的一个函数,用于从 dot 格式的数据创建一个图形对象。dot 是一种描述图形结构的简单文本格式,通常用于描述图形、网络和流程图。
image(graph.create_png())
如以上代码正常运行需安装graphviz,请参照链接
graphviz安装配置教程(图文详解)
配置完成后,如不生效的话请重启应用或则计算机。
4、为测试数据集分类
iris_y_predict = clf.predict(iris_x_test)
score=clf.score(iris_x_test,iris_y_test,sample_weight=none)
print('iris_y_predict = ')
print(iris_y_predict)
print('iris_y_test = ')
print(iris_y_test)
print('accuracy:',score)
iris_y_predict =
[1 2 0 1 0 1 1 1 1 0]
iris_y_test =
[1 2 0 1 0 1 1 1 1 0]
accuracy: 1.0
发表评论