【Pytroch】基于决策树算法的数据分类预测(Excel可直接替换数据)

【Pytroch】基于决策树算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
  • 2.数学公式
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

决策树是一种常用的机器学习算法,用于分类和回归任务。它通过树状结构表示数据的决策过程,将数据集划分为不同的子集,并在每个子集上进行决策。决策树分类的原理可以概括如下:

  1. 特征选择: 决策树的分类过程始于根节点,根据某个特征将数据集划分为不同的子集。选择哪个特征用来进行划分是决策树算法的关键之一。常见的特征选择标准包括信息增益、基尼不纯度、方差等。信息增益表示划分前后熵的差异,基尼不纯度测量随机选择两个样本,其类别不一致的概率。

  2. 递归划分: 一旦选择了划分特征,数据集就会被分为多个子集,每个子集对应于划分特征的不同取值。对每个子集,都会重复进行相同的特征选择和划分过程,直到满足终止条件。终止条件可能包括以下情况:

    • 所有数据点属于同一类别。
    • 到达了事先设定的树的深度。
    • 某个节点上的数据点数目低于某个阈值。
  3. 叶节点分类: 当达到终止条件时,形成了决策树的叶节点。每个叶节点对应于一个分类标签,即决策树预测的输出。该标签可以是子集中样本最常见的标签,或者由其他方法确定。

  4. 预测: 对于新的未见样本,决策树通过从根节点开始,根据特征值依次沿着树的分支进行遍历,直到到达叶节点。在叶节点上的标签即为预测的类别。

  5. 剪枝(可选): 构建决策树时,可能会出现过拟合问题,即模型在训练数据上表现良好,但在新数据上表现较差。为了避免过拟合,可以进行剪枝操作,即通过移除某些分支来简化树的结构。

决策树的优点包括易于理解和解释、可处理数值型和分类型数据、能够处理缺失值等。然而,它也存在一些局限性,如容易过拟合、对数据的小变化敏感等。为了应对这些问题,通常会使用集成方法(如随机森林和梯度提升树)来进一步提升决策树的性能。

总的来说,决策树分类是一种基于树状结构的机器学习方法,通过特征选择和递归划分,将数据映射到不同的类别。其简单性和可解释性使其成为许多机器学习问题的有力工具。

2.数学公式

当涉及决策树算法的数学公式时,我们主要关注以下几个方面:信息增益、基尼不纯度以及决策树的构建和预测过程。

  1. 信息增益(Information Gain): 信息增益用于衡量特征选择的好坏,它表示在给定特征下,数据集的不确定性减少了多少。信息增益越大,意味着使用该特征进行划分可以获得更多的信息。

    信息增益 = 划分前的熵 − 加权划分后的熵 \text{信息增益} = \text{划分前的熵} - \text{加权划分后的熵} 信息增益=划分前的熵加权划分后的熵

    划分前的熵:

    H ( D ) = − ∑ i = 1 c p i log ⁡ 2 ( p i ) H(D) = - \sum_{i=1}^{c} p_i \log_2(p_i) H(D)=i=1cpilog2(pi)

    划分后的加权熵:

    H ( D ∣ A ) = ∑ i = 1 v ∣ D i ∣ ∣ D ∣ ⋅ H ( D i ) H(D|A) = \sum_{i=1}^{v} \frac{|D_i|}{|D|} \cdot H(D_i) H(DA)=i=1vDDiH(Di)

    其中, H ( D ) H(D) H(D) 是数据集的熵, p i p_i pi 是数据集中类别 i i i 的概率, H ( D ∣ A ) H(D|A) H(DA) 是在特征 A A A 下的条件熵, v v v 是特征 A A A 的取值数量, ∣ D i ∣ |D_i| Di 是属于取值 i i i 的样本数量。

  2. 基尼不纯度(Gini Impurity): 基尼不纯度用于衡量在数据集中随机选择两个样本,其类别不一致的概率。

    基尼不纯度 = 1 − ∑ i = 1 c p i 2 \text{基尼不纯度} = 1 - \sum_{i=1}^{c} p_i^2 基尼不纯度=1i=1cpi2

    其中, p i p_i pi 是数据集中类别 i i i 的概率。

  3. 决策树的构建过程: 决策树的构建过程基于递归,以下是一个简化的伪代码表示:

    function BuildDecisionTree(data, target, features):
        if all samples in target belong to the same class:
            return a leaf node with the class label
        if features is empty:
            return a leaf node with the most common class label in target
        choose the best feature based on information gain or Gini impurity
        create a decision node for the best feature
        for each unique value v of the best feature:
            create a branch from the decision node for value v
            recursively call BuildDecisionTree with subset data, target, and updated features
        return the decision node
    
    decision_tree = BuildDecisionTree(training_data, training_target, all_features)
    
  4. 决策树的预测过程: 决策树的预测过程通过从根节点开始沿着树的分支进行遍历,直到到达叶节点为止。

    function PredictSample(tree, sample):
        if tree is a leaf node:
            return the class label of the leaf node
        else:
            get the feature value from the sample
            if the value leads to a known branch:
                recursively call PredictSample on the branch
            else:
                return the most common class label among the leaf nodes' labels
    

这些数学公式和伪代码描述了决策树分类的核心原理,涉及到特征选择、信息增益、基尼不纯度以及决策树的构建和预测过程。实际实现中,可能还会涉及到剪枝等细节,但以上内容已经涵盖了决策树分类的基本原理。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import torch
import pandas as pd
from sklearn.model_selection import train_test_split  # Add this line
import numpy as np
import matplotlib.pyplot as plt

class DecisionTreeNode:
    def __init__(self, feature_index=None, label=None):
        self.feature_index = feature_index  # 特征索引
        self.label = label  # 叶节点标签
        self.children = {}  # 子节点字典

def entropy(labels):
    _, counts = torch.unique(labels, return_counts=True)
    probabilities = counts.float() / len(labels)
    return torch.sum(-probabilities * torch.log2(probabilities))

def information_gain(data, feature_index, target):
    feature_values = data[:, feature_index]
    unique_values = torch.unique(feature_values)
    total_entropy = entropy(target)

    weighted_entropy = 0.0
    for value in unique_values:
        mask = feature_values == value
        subset_target = target[mask]
        subset_entropy = entropy(subset_target)
        weighted_entropy += (len(subset_target) / len(target)) * subset_entropy

    return total_entropy - weighted_entropy

def build_decision_tree(data, target, features, max_depth=None, min_samples_split=2, current_depth=0):
    # 若样本中所有标签相同,则返回叶节点
    if torch.unique(target).size(0) == 1:
        return DecisionTreeNode(label=target[0].item())

    # 若达到最大深度或样本数不足以继续划分,则返回叶节点,将标签设为样本中最常见的标签
    if current_depth == max_depth or len(target) < min_samples_split:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    # 若没有特征可用,则返回叶节点,将标签设为样本中最常见的标签
    if features.size(0) == 0:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    best_feature_index = None
    best_info_gain = -1.0

    for feature_index in range(features.size(0)):
        info_gain = information_gain(data, feature_index, target)
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature_index = feature_index

    if best_info_gain == 0:
        label = torch.mode(target).values.item()
        return DecisionTreeNode(label=label)

    best_feature = features[best_feature_index]
    decision_tree = DecisionTreeNode(feature_index=best_feature_index)

    unique_values = torch.unique(data[:, best_feature_index])
    for value in unique_values:
        mask = data[:, best_feature_index] == value
        subset_data = data[mask]
        subset_target = target[mask]
        subset_features = features[features != best_feature]
        decision_tree.children[value.item()] = build_decision_tree(subset_data, subset_target, subset_features, max_depth, min_samples_split, current_depth + 1)

    return decision_tree

def predict_sample(tree, sample):
    if tree.label is not None:
        return tree.label

    feature_value = sample[tree.feature_index]
    if feature_value in tree.children:
        return predict_sample(tree.children[feature_value], sample)
    else:
        # 若在训练时未见过该特征值,则返回叶节点中最常见的标签
        labels = [child.label for child in tree.children.values()]
        return max(set(labels), key=labels.count)

def predict(tree, data):
    predictions = []
    for sample in data:
        prediction = predict_sample(tree, sample)
        predictions.append(prediction)
    return torch.tensor(predictions)

def accuracy(y_true, y_pred):
    return torch.sum(y_true == y_pred).item() / len(y_true)

def plot_confusion_matrix(conf_matrix, classes):
    plt.figure(figsize=(8, 6))
    plt.imshow(conf_matrix, cmap=plt.cm.Blues, interpolation='nearest')
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes)
    plt.yticks(tick_marks, classes)
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()

def plot_predictions_vs_true(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.plot(y_true, 'go', label='True Labels')
    plt.plot(y_pred, 'rx', label='Predicted Labels')
    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Class Label")
    plt.legend()
    plt.show()

def main():
    # 读取Data.xlsx文件并加载数据
    data = pd.read_excel("iris.xlsx")

    # 划分特征值和标签
    features = torch.tensor(data.iloc[:, :-1].values, dtype=torch.float32)
    labels = torch.tensor(data.iloc[:, -1].values, dtype=torch.long)

    # 将数据集拆分为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

    # 构建决策树
    # 构建决策树,限制最大深度为5,叶节点至少包含5个样本
    decision_tree = build_decision_tree(X_train, y_train, torch.arange(X_train.size(1)), max_depth=5,
                                        min_samples_split=5)

    # 进行预测
    y_pred = predict(decision_tree, X_test)

    # 计算准确率
    acc = accuracy(y_test, y_pred)
    print("模型的准确率:", acc)
    # 绘制混淆矩阵
    from sklearn.metrics import confusion_matrix
    conf_matrix = confusion_matrix(y_test, y_pred)
    classes = ['class_0', 'class_1', 'class_2']
    plot_confusion_matrix(conf_matrix, classes)

    # 绘制真实标签与预测标签对比图
    plot_predictions_vs_true(y_test, y_pred)

if __name__ == "__main__":
    main()

7.运行结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/72403.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Python常用函数】一文让你彻底掌握Python中的toad.quality函数

任何事情都是由量变到质变的过程&#xff0c;学习Python也不例外。只有把一个语言中的常用函数了如指掌了&#xff0c;才能在处理问题的过程中得心应手&#xff0c;快速地找到最优方案。本文和你一起来探索Python中的toad.quality函数&#xff0c;让你以最短的时间明白这个函数…

远程通信-RPC

项目场景&#xff1a; 在分布式微服务架构中&#xff0c;远程通信是最基本的需求。 常见的远程通信方式&#xff0c;有基于 REST 架构的 HTTP协议、RPC 框架。 下面&#xff0c;从三个维度了解一下 RPC。 1、什么是远程调用 2、什么是 RPC 3、RPC 的运用场景和优 什么是远程调用…

【Zabbix安装-5.5版本】

Zabbix安装&#xff08;rpm包安装&#xff09; Index of /zabbix/zabbix/5.5/rhel/8/x86_64/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror rpm包链接&#xff1a;https://mirrors.tuna.tsinghua.edu.cn/zabbix/zabbix/5.5/rhel/8/x86_64/zabbix-release-5.5-1.e…

【福建事业单位-数学运算】03经济利润-排列组合与概率

【福建事业单位-数学运算】03经济利润-排列组合与概率 一、经济利润1.1常规经济——考的多、难具体数值——方程无数值&#xff0c;给比例——赋值 1.2 分段计费1.3 函数最值&#xff08;销售总量 单价 * 销量 &#xff1b; 总利润 单利 * 销量&#xff09;总结 二、排列组合2…

f1tenth仿真设置

文章目录 一、安装依赖二、进入工作空间克隆三、编译四、运行 一、安装依赖 tf2_geometry_msgs ackermann_msgs joy map_server sudo apt-get install ros-noetic-tf2-geometry-msgs ros-noetic-ackermann-msgs ros-melodic-joy ros-noetic-map-server 二、进入工作空间克隆…

软件测试基础篇——LAMP环境搭建

LAMP 1、Linux系统的其他命令 find命令&#xff1a;在目录下查找文件 ​ 格式一&#xff1a;find 路径 参数 文件名 ​ 路径&#xff1a;如果没有指定路径&#xff0c;默认是在当前目录下 ​ 参数&#xff1a;-name 根据文件名来查找&#xff0c;区分大小写&#xff1b; -…

化工行业案例 | 甄知科技助力万华化学重构IT服务价值,打造信息中心ERP!

随着科技的发展&#xff0c;新材料的应用领域与日俱增&#xff0c;近年来&#xff0c;全球化工新材料产业发展整体步入高技术引领、产品迭代速度快、产业规模和需求不断扩大的阶段。一体化协同与数字化转型策略是实现化工新材料生产原料自给、节能降耗、降低排放和物料成本的重…

C语言必会题目(2)

W...Y的主页 &#x1f60a; 代码仓库分享&#x1f495; 今天继续分享C语言必会的题目&#xff0c;上一篇文章主要是一些选择题&#xff0c;而今天我们主要内容为编程题的推荐与讲解 准备好迎接下面的题了吗&#xff1f;开始发车了&#xff01;&#xff01;&#xff01; 输入…

数据结构【图的类型定义和存储结构】

数据结构之图 图的定义和概念图的定义图的术语 图的类型定义图的存储结构数组&#xff08;邻接矩阵&#xff09;表示法无向图的邻接矩阵表示法有向图的邻接矩阵表示法网&#xff08;即有权图&#xff09;的邻接矩阵表示法 邻接矩阵的ADT定义邻接表&#xff08;链式&#xff09;…

【C++】STL初识

1.STL的基本概念 2.vector存放内置数据类型 #include <iostream> using namespace std; #include <vector> #include <algorithm>void MyPrint(int val) {cout << val << endl; }void test01() {//创建vector容器对象&#xff0c;并且通过模板参…

Rx.NET in Action 中文介绍 前言及序言

Rx 处理器目录 (Catalog of Rx operators) 目标可选方式Rx 处理器(Operator)创建 Observable Creating Observables直接创建 By explicit logicCreate Defer根据范围创建 By specificationRangeRepeatGenerateTimerInterval Return使用预设 Predefined primitivesThrow …

wifi列表消失 后总结

故障现象&#xff1a; 管理源身份打开cmd &#xff0c;然后重启网络服务 Fn 加信号塔 开启二者为自动&#xff1a; 刷新网络&#xff1a; Fn 加信号塔 重启的时间可以放长一些 半个小时左右

小龟带你敲排序之冒泡排序

冒泡排序 一. 定义二.题目三. 思路分析&#xff08;图文结合&#xff09;四. 代码演示 一. 定义 冒泡排序&#xff08;Bubble Sort&#xff0c;台湾译为&#xff1a;泡沫排序或气泡排序&#xff09;是一种简单的排序算法。它重复地走访过要排序的数列&#xff0c;一次比较两个元…

自定义类型:结构体,枚举,联合

自定义类型&#xff1a;结构体&#xff0c;枚举&#xff0c;联合 结构体 1 结构体的声明 1.1 结构的基础知识 结构是一些值的集合&#xff0c;这些值称为成员变量。结构的每个成员可以是不同类型的变量。 1.2结构的声明 struct tag {member-listvariable-list }; 例如描述…

阿里云免费服务器领取入口_限制条件_申请指南

阿里云免费服务器领取&#xff0c;个人和企业用户均可以申请&#xff0c;个人免费服务器1核2GB 每月750小时&#xff0c;企业u1服务器2核8GB免费使用3个月&#xff0c;阿里云百科分享阿里云免费服务器申请入口、个人和企业免费配置、申请资格条件及云服务器免费使用时长&#x…

【数学建模】--时间序列分析

时间序列分析概念与时间序列分解模型 定义&#xff1a;时间序列也称动态序列&#xff0c;是指将某种现象的指标数值按照时间顺序排列而成的数值序列。时间序列分析大致可分成三大部分&#xff0c;分别是描述过去&#xff0c;分线规律和预测未来&#xff0c;本讲将主要介绍时间序…

双指针算法

文章目录 双指针算法leetcode题目 双指针算法 双指针算法可以实现对于时间复杂度降一维度&#xff0c;使得O(n2)的算法时间复杂度变为O(n) 指针类型 对撞指针快慢指针 对撞指针 一般是用于顺序结构中的&#xff0c;也可以称为左右指针&#xff0c;从两端向中间移动&#xff0c…

超导热催生meme,换汤不换药的投机轮回

文/章鱼哥 出品/陀螺财经 币圈对炒作meme概念的热情从未消亡过。 随着一种名为LK-99的物质被发现&#xff0c;围绕超导的兴奋不仅激发了科学界&#xff0c;加密货币相关概念也与之沸腾。不出所料&#xff0c;与此前围绕元宇宙、AI大肆炒作一样&#xff0c;许多meme代币已经出现…

Maven基础总结

前言 Maven 是一个项目管理工具&#xff0c;可以对 Java 项目进行构建、依赖管理。 基本要求掌握 配置Maven环境直接查。 得会在IDEA创建Maven的java项目吧、会创建Maven的web项目吧、会创建多模块项目吧。 得会配置插件pligin、依赖dependency吧 一、Maven四大特性 1、…

github版面混乱加载不出的解决办法

最近出现打开github 界面加载不成功&#xff0c;网页访问乱码&#xff0c;打开chrome的检查发现 github的github.githubassets.com 拒绝访问&#xff0c; 解法&#xff1a; 1.先打开hosts文件所在的目录C:\Windows\System32\drivers\etc 2.右键点击hosts文件-选择用记事本或者…