【Sklearn】基于随机森林算法的数据分类预测(Excel可直接替换数据)

【Sklearn】基于随机森林算法的数据分类预测(Excel可直接替换数据)

  • 1.模型原理
    • 1.1 模型原理
    • 1.2 数学模型
  • 2.模型参数
  • 3.文件结构
  • 4.Excel数据
  • 5.下载地址
  • 6.完整代码
  • 7.运行结果

1.模型原理

随机森林(Random Forest)是一种集成学习方法,通过组合多个决策树来构建强大的分类或回归模型。随机森林的模型原理和数学模型如下:

1.1 模型原理

随机森林是一种集成学习方法,它结合了多个决策树来改善预测的准确性和鲁棒性。每个决策树都是独立地训练,并且它们的预测结果综合起来形成最终的预测。随机森林的主要思想是构建一个“森林”,其中每棵树都是一个分类器,而每个分类器都在随机的数据子集上进行训练。在预测时,通过投票或平均来综合所有分类器的结果。

随机森林的主要步骤:

  1. 随机抽样(Bootstrap抽样): 从原始训练数据中随机抽取多个样本,允许同一个样本在一个抽样中出现多次,形成一个新的训练集。

  2. 随机特征选择: 对每个决策树的训练过程中,在节点分裂时,只考虑部分特征,而不是全部特征。这样有助于增加树之间的多样性,减少过拟合。

  3. 独立训练: 对于每个样本和每个决策树,使用随机抽样的训练数据和随机选择的特征进行训练,得到多棵独立的决策树。

  4. 预测聚合: 在预测时,将每棵树的预测结果进行投票(分类问题)或平均(回归问题),以决定最终的分类或预测值。

1.2 数学模型

随机森林的数学模型是由多个决策树组成的集合,每个决策树都是一个独立的分类器或回归器。随机森林的预测是通过对每个决策树的预测结果进行综合得到的。可以用以下形式表示:

F ( x ) = 1 T ∑ t = 1 T f t ( x ) F(x) = \frac{1}{T} \sum_{t=1}^{T} f_t(x) F(x)=T1t=1Tft(x)

其中, F ( x ) F(x) F(x)表示随机森林的预测结果, T T T表示决策树的数量, f t ( x ) f_t(x) ft(x)表示第 t t t棵决策树的预测结果。

在训练每棵决策树时,随机森林通过随机抽样和随机特征选择增加了每棵树之间的多样性,从而减少了过拟合的风险。在预测时,通过对多个决策树的预测结果进行综合,提高了模型的准确性和稳定性。

总之,随机森林通过构建多个独立的决策树,并对它们的预测结果进行综合,从而创建了一个强大的集成模型,适用于分类和回归任务。

2.模型参数

RandomForestClassifierscikit-learn中随机森林分类器的类,它具有多个参数可以调整。以下是你提到的参数以及它们的说明:

  1. n_estimators: 随机森林中决策树的数量。默认为100。

  2. criterion: 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认是"gini"。

  3. max_depth: 决策树的最大深度。默认为None,表示不限制深度。

  4. min_samples_split: 节点分裂所需的最小样本数。默认为2。

  5. min_samples_leaf: 叶节点所需的最小样本数。默认为1。

  6. min_weight_fraction_leaf: 叶节点所需的最小权重分数总和。默认为0。

  7. max_features: 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认是"auto",意味着"sqrt(n_features)"。

  8. max_leaf_nodes: 最大叶节点数。默认为None。

  9. min_impurity_decrease: 分割需要达到的最小不纯度减少量。默认为0。

  10. bootstrap: 是否对数据进行有放回抽样。默认为True。

  11. oob_score: 是否计算袋外(oob)准确率。默认为False。

  12. n_jobs: 并行处理的作业数。默认为None,表示使用1个作业。

  13. random_state: 随机数生成器的种子,用于重现随机结果。

  14. class_weight: 类别权重,用于处理不平衡数据集。

  15. verbose: 控制训练过程中的输出信息。默认为0,不显示输出。

这些参数可以根据你的数据集和问题进行调整,以获得最佳的模型性能。

3.文件结构

在这里插入图片描述

iris.xlsx						% 可替换数据集
Main.py							% 主函数

4.Excel数据

在这里插入图片描述

5.下载地址

- 资源下载地址

6.完整代码

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

def random_forest_classification(data_path, test_size=0.2, random_state=42):
    # 加载数据
    data = pd.read_excel(data_path)

    # 分割特征和标签
    X = data.iloc[:, :-1]  # 所有列除了最后一列
    y = data.iloc[:, -1]   # 最后一列

    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # 创建随机森林分类器
    # 1. ** n_estimators: ** 随机森林中决策树的数量。默认为100。
    # 2. ** criterion: ** 衡量分割质量的标准。可以是"gini"(基尼系数)或"entropy"(信息熵)。默认是"gini"。
    # 3. ** max_depth: ** 决策树的最大深度。默认为None,表示不限制深度。
    # 4. ** min_samples_split: ** 节点分裂所需的最小样本数。默认为2。
    # 5. ** min_samples_leaf: ** 叶节点所需的最小样本数。默认为1。
    # 6. ** min_weight_fraction_leaf: ** 叶节点所需的最小权重分数总和。默认为0。
    # 7. ** max_features: ** 寻找最佳分割时要考虑的特征数量。可以是整数、浮点数、字符串或None。默认是"auto",意味着"sqrt(n_features)"。
    # 8. ** max_leaf_nodes: ** 最大叶节点数。默认为None。
    # 9. ** min_impurity_decrease: ** 分割需要达到的最小不纯度减少量。默认为0。
    # 10. ** bootstrap: ** 是否对数据进行有放回抽样。默认为True。
    # 11. ** oob_score: ** 是否计算袋外(oob)准确率。默认为False。
    # 12. ** n_jobs: ** 并行处理的作业数。默认为None,表示使用1个作业。
    # 13. ** random_state: ** 随机数生成器的种子,用于重现随机结果。
    # 14. ** class_weight: ** 类别权重,用于处理不平衡数据集。
    # 15. ** verbose: ** 控制训练过程中的输出信息。默认为0,不显示输出。
    model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=random_state)

    # 在训练集上训练模型
    model.fit(X_train, y_train)

    # 在测试集上进行预测
    y_pred = model.predict(X_test)

    # 计算准确率
    accuracy = accuracy_score(y_test, y_pred)
    return confusion_matrix(y_test, y_pred), y_test.values, y_pred, accuracy

if __name__ == "__main__":
    # 使用函数进行分类任务
    data_path = "iris.xlsx"
    confusion_mat, true_labels, predicted_labels, accuracy = random_forest_classification(data_path)

    print("真实值:", true_labels)
    print("预测值:", predicted_labels)
    print("准确率:{:.2%}".format(accuracy))

    # 绘制混淆矩阵
    plt.figure(figsize=(8, 6))
    sns.heatmap(confusion_mat, annot=True, fmt="d", cmap="Blues")
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.show()

    # 用圆圈表示真实值,用叉叉表示预测值
    # 绘制真实值与预测值的对比结果
    plt.figure(figsize=(10, 6))
    plt.plot(true_labels, 'o', label="True Labels")
    plt.plot(predicted_labels, 'x', label="Predicted Labels")

    plt.title("True Labels vs Predicted Labels")
    plt.xlabel("Sample Index")
    plt.ylabel("Label")
    plt.legend()
    plt.show()


7.运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/73391.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】PyTorch快速入门

【深度学习】学习PyTorch基础 介绍PyTorch 深度学习框架是一种软件工具,旨在简化和加速构建、训练和部署深度学习模型的过程。深度学习框架提供了一系列的函数、类和工具,用于定义、优化和执行各种深度神经网络模型。这些框架帮助研究人员和开发人员专注…

RabbitMQ 安装教程

RabbitMQ 安装教程 特殊说明 因为RabbitMQ基于Erlang开发,所以安装时需要先安装Erlang RabbitMQ和Erlang版本对应关系 查看地址:www.rabbitmq.com/which-erlan… 环境选择 Erlang: 23.3及以上 RabbitMQ: 3.10.1Windows 安装 1. 安装Erlang 下载地…

Spark 学习记录

基础 SparkContext是什么?有什么作用? https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么? SparkContext 是通往 Spark 集群的唯一入口,可以用来在 Spark 集群中创建 RDDs 、累加和广播变量( Br…

【css】渐变

渐变是设置一种颜色或者多种颜色之间的过度变化。 两种渐变类型: 线性渐变(向下/向上/向左/向右/对角线) 径向渐变(由其中心定义) 1、线性渐变 语法:background-image: linear-gradient(direction, co…

黑客必备的操作系统——kali linux安装

大家经常会在电视里面看到各种炫酷的黑客操作,那么黑客一般用什么操作系统呢?今天小训带大家来安装黑客必备的kali linux-2022操作系统,有兴趣的一起来学习下吧! 1、安装前准备 1.1 VMware下载 VMware官网下载: ht…

开源力量再现,国产操作系统商业化的全新探索

文章目录 1. 开源运动的兴起2. 开源力量的推动3. 国产操作系统的崭露头角3.1 国产操作系统有哪些 4.国产操作系统的商业化探索5.开源力量对国产操作系统商业化的推动 操作系统作为连接硬件、中间件、数据库、应用软件的纽带,被认为是软件技术体系中最核心的基础软件…

linux4.0新增32位ARM的系统调用

修改内核源码 Linux系统为每一个系统调用赋予一个系统调用号。当应用程序执行一个系统调用时,应用程序就可以知道执行和调用到哪个系统调用了,从而不会造成混乱。系统调用号一旦分配之后就不会有任何变更,否则已经编译好的应用程序就不能运行…

runtime error: member access within misaligned address(力扣最常见错误之一)

runtime error: member access within misaligned address(力扣最常见错误之一) 前言原因和解决办法总结 前言 最近博主在刷力扣时,明明代码逻辑都没问题,但总是报下面这个错误: runtime error: member access within…

SQL-每日一题【1251. 平均售价】

题目 Table: Prices Table: UnitsSold 编写SQL查询以查找每种产品的平均售价。average_price 应该四舍五入到小数点后两位。 查询结果格式如下例所示: 解题思路 1.题目要求查询每种产品的平均售价。给出了两个表,我们用聚合查询来解决此问题。 2.首先我…

腾讯会议:云上协奏,远程韶华

腾讯会议的原理及历史 摘要 本论文介绍了腾讯会议的原理和历史。腾讯会议是一款基于云计算和通信技术的在线会议平台,由腾讯公司推出。通过分析腾讯会议的工作原理和演进历史,我们可以深入了解该平台是如何实现高效、便捷、安全的远程协作和沟通的。 1. 引言 近年来,随着…

ruby send call 的简单使用

refer: ruby on rails - What does .call do? - Stack Overflow Ruby使用call 可以调用方法或者proc m 12.method("") # > method gets the method defined in the Fixnum instance # m.class # > Methodm.call(3) #> 15 # 3 is passed inside the…

Mr. Cappuccino的第63杯咖啡——Spring之AnnotationConfigApplicationContext源码分析

Spring之AnnotationConfigApplicationContext源码分析 源码分析 源码分析 以上一篇文章《Spring之Bean的生命周期》的代码进行源码分析 AnnotationConfigApplicationContext applicationContext new AnnotationConfigApplicationContext(SpringConfig02.class); LifeCycleBe…

如何在Stream流中分组统计

上面是今天碰到需求,之前就做过类似的分组统计,这个相对来说比较简单,统计的也少,序号和总预约人数这两部分交给前端了,不需要由后端统计,后端统计一下预约日期和检查项目和预约人数就行; Overridepublic List<ItemStatisticsVo> statistics(ItemStatisticsModel itemSta…

智能与本体

世界的本体是一个复杂而广泛的话题&#xff0c;可以根据不同的学科、思想体系和信仰背景来进行不同的解释和理解。它涉及到人类对于现实和存在的思考&#xff0c;以及对于世界本质的追寻和探索。 在哲学上&#xff0c;世界的本体指的是存在的实质或基本特征。它探讨了世界的本源…

[保研/考研机试] KY85 二叉树 北京大学复试上机题 C++实现

题目链接&#xff1a; 二叉树https://www.nowcoder.com/share/jump/437195121692000296981 描述 如上所示&#xff0c;由正整数1&#xff0c;2&#xff0c;3……组成了一颗特殊二叉树。我们已知这个二叉树的最后一个结点是n。现在的问题是&#xff0c;结点m所在的子树中一共包…

线程记录(2)

1.线程状态 NEW : 分配内存地址&#xff0c;创建线程 RUNNABLE&#xff1a;&#xff08;就绪/运行&#xff09;调用start()之后&#xff08;/没有调度CPU调度&#xff09; BLOCKED&#xff1a;还未拿到锁&#xff0c;等待、被阻塞&#xff08;拿到synchronized失败状态&…

AI 绘画Stable Diffusion 研究(七) 一文读懂 Stable Diffusion 工作原理

大家好&#xff0c;我是风雨无阻。 本文适合人群&#xff1a; 想要了解AI绘图基本原理的朋友。 对Stable Diffusion AI绘图感兴趣的朋友。 本期内容&#xff1a; Stable Diffusion 能做什么 什么是扩散模型 扩散模型实现原理 Stable Diffusion 潜扩散模型 Stable Diffu…

TFRecords详解

内容目录 TFRecords 是什么序列化(Serialization)tf.data 图像序列化&#xff08;Serializing Images)tf.Example函数封装 小结 TFRecords 是什么 TPU拥有八个核心&#xff0c;充当八个独立的工作单元。我们可以通过将数据集分成多个文件或分片&#xff08;shards&#xff09;…

初始多线程

目录 认识线程 线程是什么&#xff1a; 线程与进程的区别 Java中的线程和操作系统线程的关系 创建线程 继承Thread类 实现Runnable接口 其他变形 Thread类及其常见方法 Thread的常见构造方法 Thread类的几个常见属性 Thread类常用的方法 启动一个线程-start() 中断…

[保研/考研机试] KY109 Zero-complexity Transposition 上海交通大学复试上机题 C++实现

描述&#xff1a; You are given a sequence of integer numbers. Zero-complexity transposition of the sequence is the reverse of this sequence. Your task is to write a program that prints zero-complexity transposition of the given sequence. 输入描述&#xf…