Xgboost分类模型的完整示例

在这里插入图片描述

往期精彩推荐

  • 数据科学知识库
  • 机器学习算法应用场景与评价指标
  • 机器学习算法—分类
  • 机器学习算法—回归
  • PySpark大数据处理详细教程

定义问题

UCI的蘑菇数据集的主要目的是为了分类任务,特别是区分蘑菇是可食用还是有毒。这个数据集包含了蘑菇的各种特征,如帽形、颜色、气味等,以及一个标签表示蘑菇是否有毒。通过对这些特征的分析,可以构建分类模型来预测任何一个蘑菇样本是否有毒。这种类型的任务对于练习数据科学和机器学习技能,尤其是分类算法的应用和理解,非常有帮助。

导入相关库

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report
from xgboost import XGBClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import shap

加载数据集

# 数据集的URL
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data"

# 定义列名
column_names = ["class", "cap-shape", "cap-surface", "cap-color", "bruises", "odor",
                "gill-attachment", "gill-spacing", "gill-size", "gill-color", "stalk-shape",
                "stalk-root", "stalk-surface-above-ring", "stalk-surface-below-ring",
                "stalk-color-above-ring", "stalk-color-below-ring", "veil-type", "veil-color",
                "ring-number", "ring-type", "spore-print-color", "population", "habitat"]

# 加载数据集
mushroom_data = pd.read_csv(url, names=column_names)

# 查看数据集的前几行
mushroom_data.head(5)

在这里插入图片描述

数据探索

统计数据可以帮助您快速了解数据的分布情况、中心趋势和离散程度。在数据分析和机器学习的前期阶段,这是一个常用的探索性数据分析(EDA)步骤。

mushroom_data.describe()

在这里插入图片描述

数据预处理

去除异常值

# 可以根据均值标准差来定义异常值
均值 ± 3倍标准差  之外的定义为异常值

缺失值填充

mushroom_data = mushroom_data.fillna(0)

数值类型转换

# 将分类数据转换为数值
label_encoder = LabelEncoder()
for column in mushroom_data.columns:
    mushroom_data[column] = label_encoder.fit_transform(mushroom_data[column])

划分训练集与测试集

# 划分训练集和测试集
X = mushroom_data.drop('class', axis=1)
y = mushroom_data['class']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

均衡采样与模型定义

# 计算正负样本比例,在XGBoost模型中设置scale_pos_weight
scale_pos_weight = sum(y_train == 0) / sum(y_train == 1)

# 定义模型
xgb_model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', scale_pos_weight=scale_pos_weight)

随机搜索选参

定义XGBoost的参数搜索范围,并使用RandomizedSearchCV进行随机搜索,以找到最佳的超参数。

param_distributions = {
    'n_estimators': [100, 300, 500, 800, 1000],  # 表示树的个数。增加树的数量可以提高模型的复杂度,但也可能导致过拟合。
    'learning_rate': [0.01, 0.05, 0.1, 0.15, 0.2],  # 学习率,用于控制每棵树对最终结果的影响。较低的学习率意味着模型需要更多的树来进行训练。
    'max_depth': [3, 4, 5, 6, 7, 8],  # 树的最大深度。更深的树会增加模型的复杂度,但也可能导致过拟合。
    'min_child_weight': [1, 2, 3, 4],  # 决定最小叶子节点样本权重和。较大的值可以防止模型过于复杂,从而避免过拟合。
    'subsample': [0.6, 0.7, 0.8, 0.9, 1.0],  # 用于控制每棵树随机采样的比例,减少这个参数的值可以使模型更加保守,防止过拟合。
    'colsample_bytree': [0.6, 0.7, 0.8, 0.9, 1.0],  # 用于每棵树的训练时,随机采样的特征的比例。减少这个参数的值同样可以防止模型过于复杂。
    'gamma': [0, 0.1, 0.2, 0.3, 0.4]  # 后剪枝时,作为节点分裂所需的最小损失函数下降值。该参数值越大,算法越保守。
}

random_search = RandomizedSearchCV(
    xgb_model, param_distributions, n_iter=50, cv=5, random_state=42
)
random_search.fit(X_train, y_train)

best_params = random_search.best_params_
xgb_model.set_params(**best_params)

模型评估

使用找到的最佳参数训练XGBoost模型,然后在测试集上进行评估,计算性能指标如准确率、精确率、召回率和F1分数。

xgb_model.fit(X_train, y_train)
y_pred = xgb_model.predict(X_test)
# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
tn, fp, fn, tp = cm.ravel()
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

performance_df = pd.DataFrame({
        'TN': [tn], 
        'FP': [fp], 
        'FN': [fn], 
        'TP': [tp],
        'Accuracy': [accuracy_score(y_test, y_pred)],
        'Precision': [precision_score(y_test, y_pred)],
        'Recall': [recall_score(y_test, y_pred)],
        'F1 Score': [f1_score(y_test, y_pred)]
    })
# 展示模型评估结果
performance_df.head()

标准数据集,结果太完美了,实际数据集就差强人意了

在这里插入图片描述

模型特征重要性展示

# 获取特征重要性
feature_importances = xgb_model.feature_importances_

# 可视化特征重要性
plt.style.use("ggplot")
plt.barh(range(len(feature_importances)), feature_importances)
plt.yticks(range(len(X.columns)), X.columns)
plt.xlabel('Feature Importance')
plt.title('Feature Importance in XGBoost Model')
plt.show()

在这里插入图片描述

SHAP负例分析

SHAP库提供了多种可视化工具,可以帮助您更深入地了解模型的行为。使用SHAP库计算XGBoost模型的SHAP值,分析被模型错误分类为负例的情况,并通过可视化来理解影响这些预测的关键特征在进行负例分析时,您可以专注于那些被模型错误分类的样本,并使用SHAP值来探究背后的原因。

# 计算SHAP值
explainer = shap.Explainer(xgb_model, X_train)
shap_values = explainer(X_test)

# 可视化:展示单个预测的SHAP值
shap.initjs()
# shap.force_plot(explainer.expected_value, shap_values[0,:], X_test[0,:])

# 可视化:展示所有测试数据的SHAP值
shap.summary_plot(shap_values, X_test)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/283703.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3基础知识一,安装及使用

一、安装vue3 需要安装node,然后在项目所在目录命令行执行以下代码。 npm create vuelatest 回车后需要配置以下内容。 二、安装所需的依赖包并运行 cd到项目目录,执行以下代码安装依赖包 npm i 运行项目 npm run dev 打开浏览器查看结果 ok&#…

Linux系统安装DockerDocker-Compose

1、Docker安装 下载Docker依赖的组件 yum -y install yum-utils device-mapper-persistent-data lvm2 设置下载Docker服务的镜像源,设置为阿里云 yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo 安装Docker服务 …

LTPI协议的理解——4、LTPI链路初始化以及运行

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 LTPI协议的理解——4、LTPI链路初始化以及运行 前言状态图Link TrainingLink DetectLink SpeedLink Training Example Link ConfigurationAdvertiseConfigure & AcceptLi…

Git:常用命令(二)

查看提交历史 1 git log 撤消操作 任何时候,你都有可能需要撤消刚才所做的某些操作。接下来,我们会介绍一些基本的撤消操作相关的命令。请注意,有些操作并不总是可以撤消的,所以请务必谨慎小心,一旦失误&#xff0c…

Django 后台与便签

1. 什么是后台管理 后台管理是网页管理员利用网页的后台程序管理和更新网站上网页的内容。各网站里网页内容更新就是通过网站管理员通过后台管理更新的。 2. 创建超级用户 1. python .\manage.py createsuperuser 2. 输入账号密码等信息 Username (leave blank to use syl…

jsp结合servlet

servlet配置 环境配置2023.12.31 idea配置搭建 创建一个普通的java项目 由于新版idea去除了add framework support的ui显示,可以在左边项目栏中使用快捷键shiftk或者setting中搜索add framework support在修改对应的快捷键 点击ok然后应该就是下面这样的结果 这里…

大数据背景下基于联邦学习的小微企业信用风险评估研究

摘要: 小微企业信用风险评估难是制约其融资和发展的一个主要障碍。基于大数据的小微企业信用风险评估依然面临着单机构数据片面、跨机构数据共享难、模型不稳定等诸多挑战。针对相关问题和挑战,本项目拟在多主体所有权数据隐私保护与安全共享的背景下&am…

AI4Green开源ELN(电子实验记录本)

AI4Green是一个开源的电子实验记录本,官网:https://github.com/AI4Green/ai4green 国内镜像: skywalk163/AI4Green - AI4Green - OpenI - 启智AI开源社区提供普惠算力! 论文地址:https://pubs.acs.org/doi/10.1021/…

Python开发环境[PycharmEclipseAnaconda]

Pycharm配置Python开发环境 每种语言的开发工具都有很多,如果写一些小的脚本或者小的工具,建议直接使用命令行或者Python自带的IDLE,如果进行大型的开发工作建议使用Pycharm,当然这属于个人喜好。 虽然Pycharm给了我们一个美观的…

安全加固指南:如何更改 SSH 服务器的默认端口号

在 Linux 系统中修改 SSH 服务的默认端口号是一项重要的安全措施,它可以帮助增强系统的安全性。这个过程相对简单,但必须由具有管理员权限的用户来执行。下面,我将向大家介绍如何安全地更改 SSH 端口的具体步骤。 1 备份 SSH 配置文件 在修改…

【MyBatis】操作数据库——入门

文章目录 为什么要学习MyBatis什么是MyBatisMyBatis 入门创建带有MyBatis框架的SpringBoot项目数据准备在配置文件中配置数据库相关信息实现持久层代码单元测试 为什么要学习MyBatis 前面我们肯定多多少少学过 sql 语言,sql 语言是一种操作数据库的一类语言&#x…

【AIGC风格prompt】风格类绘画风格的提示词技巧

风格类绘画风格的提示词展示 主题:首先需要确定绘画的主题,例如动物、自然景观、人物等。 描述:根据主题提供详细的描述,包括颜色、情感、场景等。 绘画细节:描述绘画中的细节,例如表情、纹理、光影等。 场…

认识Linux指令之 “mv” 指令

01.mv指令(重要) mv命令是move的缩写,可以用来移动文件或者将文件改名(move (rename) files),是Linux系统下常用的命令,经常用来备份文件或者目录。 语法: mv [选项] 源文件或目录 目标文件或…

影视后期:Pr 调色处理之风格调色

写在前面 整理一些影视后期相关学习笔记博文为 Pr 调色处理中风格调色,涉及下面几个Demo 好莱坞电影电影感调色复古港风调色赛博朋克风格调色日系小清晰调色 理解不足小伙伴帮忙指正 简单地说就是害怕向前迈进或者是不想真正地努力。不愿意为了改变自我而牺牲目前所…

jsp介绍

JSP 一种编写动态网页的语言&#xff0c;可以嵌入java代码和html代码&#xff0c;其底层本质上为servlet,html部分为输出流&#xff0c;编译为java文件 例如 源jsp文件 <% page contentType"text/html; charsetutf-8" language"java" pageEncoding&…

数据库存储引擎

一、什么是存储引擎 存储引擎是MySQL数据库中的一个【组件】&#xff0c;【负责执行实际的数据I/O操作】&#xff0c;工作在文件系统之上&#xff0c;数据库的数据会先传到存储引擎&#xff0c;在按照存储引擎的格式&#xff0c;保存到文件系统。 常用的存储引擎&#xff1a;…

在多Module项目中,给IDEA底部选项卡区域添加Services选项卡

一般一个spring cloud项目中大大小小存在几个十几个module编写具体的微服务项目。此时&#xff0c;如果要调试测需要依次启动各个项目比较麻烦。 idea其实提供了各module的启动管理工具了&#xff0c;可以快速启动和关闭各个服务&#xff0c;也能批量操作&#xff0c;比如一次…

[每周一更]-(第49期):一名成熟Go开发需储备的知识点(答案篇)- 2

答案篇 1、Go语言基础知识 什么是Go语言&#xff1f;它有哪些特点&#xff1f; Go语言&#xff08;也称为Golang&#xff09;是一种由Google开发的开源编程语言。它于2007年首次公开发布&#xff0c;并在2012年正式推出了稳定版本。Go语言旨在提供简单、高效、可靠的编程解决…

Windows 10 安装和开启VNCServer 服务

Windows 10 安装和开启VNCServer 服务 登录云服务器 使用本地RDP登录到配置VNCServer服务的Windows10系统的云服务器。 下载VNC Server安装包 打开官网下载VNCServer安装包 URL&#xff1a;https://www.realvnc.com/en/connect/download/vnc/windows/ 安装VNC Server 双击…

基于蝗虫算法优化的Elman神经网络数据预测 - 附代码

基于蝗虫算法优化的Elman神经网络数据预测 - 附代码 文章目录 基于蝗虫算法优化的Elman神经网络数据预测 - 附代码1.Elman 神经网络结构2.Elman 神经用络学习过程3.电力负荷预测概述3.1 模型建立 4.基于蝗虫优化的Elman网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针…