当前位置: 代码网 > it编程>前端脚本>Python > 机器学习_决策树

机器学习_决策树

2024年08月02日 Python 我要评论
本文为机器学习中关于决策树的概念和构建决策树的python代码

目录

一、概念

1、概念

2、结束条件

3、熵

二、算法&具体例题计算

1、id3信息增益

2、c4.5 信息增益率

3、基尼指数

二-一 例题

(1)id3信息增益:

(2)gini

三、python实例

1、基尼指数

(1)构建gini代码计算acc

(2)决策树可视化

2、id3信息增益

(1)构建id3代码计算acc

(2)决策树可视化


一、概念

1、概念

 决策树分二叉树和非二叉树,通过分析某些特征对数据集进行划分,从而预测新的待测样本。简单来说就是想象一棵树,根节点代表我们对数据集特征选取,而树枝和叶子代表在特征属性下对数据集的不同判断和结果。

2、结束条件

(1)没有特征可以往下分了

(2)剩下的样本都属于同一个类别

(3)剩下的样本数少于设定的阈值

(4)达到设定的深度

3、熵

那么决策树如何构建?怎么决定哪个特征先判断?——由此,引出这个概念,即表示不确定性、混乱程度。熵越大越混乱。

熵的表达式为:

后续计算信息增益/增益率需要用到信息熵。

二、算法&具体例题计算

1、id3信息增益

1)计算集合d的信息熵h(d):

2)计算特征a条件下d的信息熵h(d|a),即a的影响程度大不大:

3)最后算差值计算信息增益:gain(d,a) = h(d) - h(a)

2、c4.5 信息增益率

其实就是上面的信息增益 / 特征a的信息熵:

3、基尼指数

基尼指数越小纯度越高。​​​​​​​

1)计算整个样本d的基尼指数:

 

2)计算在特征a的条件下集合d的基尼指数,即a用于判断集合d的影响:

3)最后选择基尼指数较小的特征进行划分。

二-一 例题

下面举个具体的例子来说明这三个算法的决策树构建:

样本属性分类
x1x2
1tt
2tf
3tf×
4ft
5ft×
6ft

(1)id3信息增益:

h(d) = -(4/6)log(4/6) - (2/6)log(2/6)   =0.918

h(t1) = -(2/3)log(2/3) - (1/3)log(1/3) = 0.918

h(f1) = -(2/3)log(2/3) - (1/3)log(1/3) = 0.918

gain(d,x1) = h(d) - [ (1/2)h(t1) + (1/2)h(f1) ] = 0

h(t2) = -(3/4)log(3/4)-(1/4)log(1/4) = 0.811

h(f2) = -(1/2)log(1/2) - (1/2)log(1/2) = 1

gain(d,x2) =h(d) -[ (4/6)h(t2) + (2/6)h(f2) ] = 0.918-0.874 = 0.044

0.044>0,选x2为第一个特征。

(2)gini

​​​​​​​

0.444>0.417,基尼指数越小纯度越高所以选x2作为划分特征。

三、python实例

1、基尼指数

(1)构建gini代码计算acc

import numpy as np
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from sklearn import tree

def load_date():
    transform = transforms.compose([transforms.totensor(),transforms.normalize((0.1307, ),(0.3081, ))])
    dataset_train = datasets.mnist(root='../data/minist',train=true,download=true,transform=transform)
    dataset_test = datasets.mnist(root='../data/minist',train=false,download=true,transform=transform)
    x_train = dataset_train.data.numpy()
    x_test = dataset_test.data.numpy()
    x_train = np.reshape(x_train,(60000,784))
    x_test  =np.reshape(x_test,(10000,784))
    y_train = dataset_train.targets.numpy()
    y_test = dataset_test.targets.numpy()

    return x_train,y_train,x_test,y_test


if __name__ == '__main__':
    train_x,train_y,test_x,test_y = load_date()
    cart = tree.decisiontreeclassifier(criterion='gini',max_depth=8,random_state=5)
    # cart = tree.decisiontreeclassifier(criterion='entropy',max_depth=8)
    cart = cart.fit(train_x,train_y)
    acc = cart.score(test_x,test_y)
    print("准确率:",acc)

(2)决策树可视化

    #可视化
    plt.figure(figsize=(12.8, 6.4))
    plt.subplot(121)
    tree.plot_tree(cart)
    plt.title("gini")
    plt.show()

2、id3信息增益

(1)构建id3代码计算acc

if __name__ == '__main__':
    train_x,train_y,test_x,test_y = load_data()
    # cart = tree.decisiontreeclassifier(criterion='gini',max_depth=3,random_state=5)
    cart = tree.decisiontreeclassifier(criterion='entropy',max_depth=8) 
    cart = cart.fit(train_x,train_y)
    acc = cart.score(test_x,test_y)
    print("准确率:",acc)

(2)决策树可视化

    #可视化
    plt.figure(figsize=(12.8, 6.4))
    plt.subplot(121)
    tree.plot_tree(cart)
    plt.title("id3")
    plt.show()

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com