机器学习实验二-----决策树构建

决策树是机器学习中一种基本的分类和回归算法,是依托于策略抉择而建立起来的树。本文学习的是决策树的分类

1. 构建决策树流程

  1. 选择算法:常用的算法包括ID3、C4.5、CART等。

  2. 划分节点:根据数据特征和算法选择,递归地划分节点,直到满足停止条件。

  3. 决策树剪枝:对决策树进行剪枝操作,减少决策树的复杂度,提高泛化能力。

  4. 决策树评估:使用测试数据集评估决策树模型的性能,通常使用准确率、召回率、F1值等。

2.常用的三个算法

2.1 ID3

D3采用信息增益来划分属性。

2.12 信息熵

用来衡量数据集的混乱程度,信息熵越大,表明数据集的混乱程度越大,不确定性越大。

公式:H(X=xi)=-\sum_{1}^{n}pi log pi

其中pi表示的是分类为xi这个样本在中的占比。

2.12信息增益

 划分数据集之前之后信息发生的变化      

公式:Gain(D,a)=Ent(D)-\sum_{v=1}^{V}\frac{\left | D^{v} \right |}{\left | D \right |}Ent(D^{v})

信息增益越大,则意味着采用该属性a划分节点获得的纯度提升更大。在每次划分中采用信息增益最大的划分。

信息增益实际上就是数据集整体的信息熵减去使用特征 a进行划分后各子集的加权平均信息熵,即子集的信息熵的期望值。当信息增益越大时,意味着子集的信息熵的减少量越大,即数据集的不确定性减少的程度更大,信息熵变小。

2.2   C4.5

C4.5算法在ID3算法上做了提升,使用信息增益比来构造决策树,且有剪枝功能防止过拟合。

信息增益比:特征a对训练集D的信息增益比定义为特征a的信息增益与训练集D对于a的信息熵之比, 同样是信息增益比越大越好。

公式:Gain_radio(D,a)=\frac{Gain(D,a)}{H(D,a)}

先剪枝:提前停止树的构建而对树”剪枝“,提前停止的策略有定义一个树的深度,到达指定深度自动停止构造;

后剪枝:先构造完整的子树,对于决策树中信息增益比较低的子树用叶子节点代替。

2.3 CART基尼指数

基尼指数是衡量数据集纯度或不确定性的一种指标,常用于决策树算法中的特征选择和节点划分。

公式:\text{Gini}(D) = 1 - \sum_{k=1}^{K} (p_k)^2

基尼指数越小越好。

3.划分节点

划分节点就是根据我们选择的算法来进行划分的,我们这边拿C4.5算法来举例一下。拿鸢尾花来举例一下,我们有花萼长度,花萼宽度,花瓣长度,花瓣宽度四个特征值。我们分别计算一下每一个的熵,根据公式计算出信息增益比,选择按照信息增益比大小排序的特征来当划分的依据。这边假设我们排序就是花萼长度,花萼宽度,花瓣长度,花瓣宽度,那我们先按照花萼长度把根节点划分为左右子树,子树再根据花萼宽度把子树再继续划分,一直用递归来划分每一个节点。最后就得到整棵树。

4. 决策树的剪枝

决策树生成算法递归的产生决策树,直到不能继续下去为止,这样产生的树往往对训练数据的分类很准确,但对未知测试数据的分类缺没有那么精确,即会出现过拟合现象。

过拟合产生的原因在于在学习时过多的考虑如何提高对训练数据的正确分类,从而构建出过于复杂的决策树,解决方法是考虑决策树的复杂度,对已经生成的树进行简化。

剪枝:从已经生成的树上裁掉一些子树或叶节点,并将其根节点或父节点作为新的叶子节点,从而简化分类树模型。

剪枝分为预剪枝与后剪枝。

预剪枝是指在决策树的生成过程中,对每个节点在划分前先进行评估,若当前的划分不能带来泛化性能的提升,则停止划分,并将当前节点标记为叶节点。

后剪枝是指先从训练集生成一颗完整的决策树,然后自底向上对非叶节点进行考察,若将该节点对应的子树替换为叶节点,能带来泛化性能的提升,则将该子树替换为叶节点。

5.决策树构建

这个函数直接依据我们所给的数据用C4.5创建了决策树

clf.fit(X_train, y_train)

6.决策树可视化

使用了export_graphviz函数从训练好的决策树模型中生成一个Graphviz格式的文本文件,然后使用graphviz.Source将这个文本文件转换为一个Graphviz对象,最后使用render方法将这个对象渲染为图形文件。

dot_data = export_graphviz(clf, out_file=None,
                     feature_names=["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"],
                     class_names=["setosa", "versicolor", "virginica"],
                     filled=True, rounded=True,
                     special_characters=True)
graph = graphviz.Source(dot_data)
graph.render("iris_decision_tree") 

graph.view()

 运行这个代码得到创建的决策树会包含我们数据集内的所有东西,我的鸢尾花数据集中每个特征的信息增益比在不同的条件下就得到下面这棵树的划分标准。

完整代码展现:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier


def load_data(train_file, test_file):
    train_data = pd.read_csv(train_file, sep='\s+')
    test_data = pd.read_csv(test_file, sep='\s+')

    X_train = train_data[["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"]].values
    y_train = train_data["Species"].values
    X_test = test_data[["Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width"]].values
    y_test = test_data["Species"].values

    return X_train, y_train, X_test, y_test


clf = DecisionTreeClassifier()

train_file = "C:\\Users\\李烨\\Desktop\\新建文件夹\\6\\iris.txt"
test_file = "C:\\Users\\李烨\\Desktop\\新建文件夹\\6\\iristest.txt"
X_train, y_train, X_test, y_test = load_data(train_file, test_file)

clf.fit(X_train, y_train)


def predict_flower(sepal_length, sepal_width, petal_length, petal_width):
    input_features = [[sepal_length, sepal_width, petal_length, petal_width]]
    prediction = clf.predict(input_features)
    if prediction[0] == 'setosa':
        print("预测类别:setosa")
    elif prediction[0] == 'versicolor':
        print("预测类别:versicolor")
    elif prediction[0] == 'virginica':
        print("预测类别:virginica")
    else:
        print("未知类别")
    return prediction[0]


def get_input():
    sepal_length = float(input("请输入花萼长度:"))
    sepal_width = float(input("请输入花萼宽度:"))
    petal_length = float(input("请输入花瓣长度:"))
    petal_width = float(input("请输入花瓣宽度:"))
    return sepal_length, sepal_width, petal_length, petal_width


print("实验二决策树分类")
while True:
    try:
        user_input = input("输入 'exit' 退出:")
        if user_input.lower() == 'exit':
            print("程序结束")
            break
            
            
        sepal_length, sepal_width, petal_length, petal_width = get_input()
        result = predict_flower(sepal_length, sepal_width, petal_length, petal_width)

    except ValueError:
            print("输入有误,请重新输入。")



本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/552847.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鉴源实验室丨智能网联汽车协议模糊测试技术概述

作者 | 乔琪 上海控安可信软件创新研究院工控网络安全组 来源 | 鉴源实验室 社群 | 添加微信号“TICPShanghai”加入“上海控安51fusa安全社区” 摘要:随着智能网联汽车的快速发展,其协议安全性和稳定性成为了关注焦点。智能网联汽车协议特点主要表现为…

javaweb http

1、http简介 HTTP 超文本传输协议(HTTP-Hyper Text transfer protocol),是一个属于应用层的面向对象的协议,由于其简捷、快速的方式,适用于分布式超媒体信息系统。它于1990年提出,经过十几年的使用与发展&…

【绘图软件】solidworks2022版本安装强制卸载(清理干净版!)全流程~超详细+报错避坑

安装包可以去【盒子部落】下载,也可以用下面网盘直接下载: 夸克网盘链接:https://pan.quark.cn/s/f2c110617bfa(无密码) 1.解压后进入文件夹 2.进入_SolidSQUAD_文件夹 3.复制SolidWorks_Flexnet_Server文件夹 4.打开C…

MySQL 的事务

事务概念 MySQL事务是一个或者多个的数据库操作,要么全部执行成功,要么全部失败回滚。 事务是通过事务日志来实现的,事务日志包括:redo log和undo log。 事务状态 事务有以下五种状态: 活动的部分提交的失败的中止的…

基于SpringBoot+Vue社区医院服务平台(源码+文档+包运行)

一.系统概述 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了社区医院信息平台的开发全过程。通过分析社区医院信息平台管理的不足,创建了一个计算机管理社区医院信息平台的方案。文章介绍了社区医院信息…

2024Guitar Pro 8.1 Mac 最新下载、安装、激活、换机图文教程

吉他爱好者必备神器:Guitar Pro v8.1.1 Build 17深度解析 随着数字音乐制作和学习的日益普及,越来越多的吉他爱好者开始寻找能够帮助他们提升技能、创作音乐的专业工具。在众多吉他制作软件中,Guitar Pro因其强大的功能和易用的界面备受推崇…

Qt 6子窗口全屏显示

一、全屏显示效果 二、全屏相关函数 1,全屏显示函数 QWidget::showFullScreen(); // 此方法只对顶级窗口有效,对子窗口无效 2,恢复显示函数 QWidget::showNormal(); // 此方法也只对顶级窗口有效,对子窗口无效 3,最小化显示函…

使用Flask和Flask-JWT-Extended保护API免受跨站请求攻击

在本文中,我们将探讨如何使用Flask和Flask-JWT-Extended库来保护您的API免受跨站请求攻击(CSRF)。我们将首先简要介绍CSRF攻击的概念,然后详细说明如何使用Flask-JWT-Extended库来保护您的API。 什么是跨站请求攻击(C…

【C++题解】1565. 成绩(score)

问题:1565. 成绩(score) 类型:基本运算、小数运算 题目描述: 牛牛最近学习了 C 入门课程,这门课程的总成绩计算方法是: 总成绩作业成绩 20% 小测成绩 30% 期末考试成绩 50%。 牛牛想知道&am…

聊聊应用商城评分4.9的Apipost IDEA插件

Apipost Helper,作为IDEA插件,可以快速生成和查询API文档,直观友好地在IDE中调试接口。它简化了开发流程并提升效率,即使新手也能够迅速掌握。Apipost Helper提供了诸多便捷功能,如通过代码查找接口或者通过接口查找代…

C++教你如何模拟实现string,如何实现string写时拷贝

文章目录 前言成员变量默认成员函数默认构造函数拷贝构造函数析构函数赋值运算符重载 容量相关函数(Capacity)reserve函数resize函数size函数capacity 函数clear函数 修改函数(Modifiers)swap函数insert函数字符插入字符串插入 ap…

论文阅读-Federated-Unlearning-With-Momentum-Degradation

论文阅读-Federated Unlearning With Momentum Degradation 联邦忘却与动量退化 Yian Zhao IEEE Internet of Things Journal 2023 年 10 月 2 日 CCF-C momentum degradation-MoDe 动量退化 memory guidance-记忆引导 knowledge erasure-知识擦除 Deep-learning neural n…

【笔记】vscode debug进入site-packages包源码

选择左侧栏第三个图标,点击创建 launch.json 文件 选择 Python Debugger 选择Python文件 这里可以看到launch.json 文件 在configurations中添加键值对 "justMyCode": false在文件中打上断点,点击"三角符"号开始调试 按F11或者红框…

ROS 2边学边练(27)-- 创建一个launch文件

前言 ROS 2中的启动系统负责帮助用户描述其系统的配置,然后按描述执行。系统的配置包括运行什么程序,在哪里运行,传递什么参数,以及ROS特定的约定,这些约定通过为每个组件提供不同的配置,使其易于在整个系统…

[stm32]DMA使用

自动重装和M2M(软件trig)不能一起使用,否则会停不下来 void MyDMA_Init(uint32_t AddrA,uint32_t AddrB,uint16_t Size){RCC_AHBPeriphClockCmd(RCC_AHBPeriph_DMA1,ENABLE);DMA_InitTypeDef DMA_InitStructure;DMA_InitStructure.DMA_PeripheralBaseAddrAddrA;//外…

go语言并发实战——日志收集系统(三) 利用sarama包连接KafKa实现消息的生产与消费

环境的搭建 Kafka以及相关组件的下载 我们要实现今天的内容,不可避免的要进行对开发环境的配置,Kafka环境的配置比较繁琐,需要配置JDK,Scala,ZoopKeeper和Kafka,这里我们不做赘述,如果大家不知道如何配置环境&#x…

STM32芯片flash被锁导致Error Flash Download failed Cortex-M4,解决办法(全)亲测有效

STM32芯片flash被锁导致Error: Flash Download failed - "Cortex-M4",解决办法(全)亲测有效🤩! 方法1:由于Keil 中debug的仿真器配置出错导致的下载失败(这种问题虽然是低级错误&…

友思特应用 | 红外视角的延伸:短波红外相机的机器视觉应用

导读 短波红外SWIR在不同波段针对不同材料的独特成像特征为各领域检测应用的拓宽提供了基础。本文将展现短波红外成像技术在水分检测、塑料检测、太阳能电池板检查和矿场开采等领域的丰富应用案例,讨论短波红外相机在未来的发展方向。 SWIR 背景简介 短波红外 &am…

基于SpringBoot+Vue的IT技术交流平台(源码+文档+包运行)

一.系统概述 我国科学技术的不断发展,计算机的应用日渐成熟,其强大的功能给人们留下深刻的印象,它已经应用到了人类社会的各个层次的领域,发挥着重要的不可替换的作用。信息管理作为计算机应用的一部分,使用计算机进行…

PHP货运搬家/拉货小程序二开源码搭建的功能

运搬家/拉货小程序的二次开发可以添加许多功能,以增强用户体验和提高业务效率。以下是一些可能的功能: 用户端功能: 注册登录:允许用户创建个人账户并登录以使用应用程序。货物发布:允许用户发布他们需要搬运的货物信息…