【机器学习】CatBoost 模型实践:回归与分类的全流程解析

一. 引言

本篇博客首发于掘金 https://juejin.cn/post/7441027173430018067。
PS:转载自己的文章也算原创吧。

在机器学习领域,CatBoost 是一款强大的梯度提升框架,特别适合处理带有类别特征的数据。本篇博客以脱敏后的保险数据集为例,展示如何利用 CatBoost 完成分类和回归任务,并以可视化的方式解析特征重要性与结果。

我们将完成以下任务:

  1. 回归任务:预测保险索赔金额。
  2. 分类任务:判断保险案件是否需要调查。
  3. 可视化分析:利用散点图与分割线展示结果。

二. CatBoost 模型简介

CatBoost 是由俄罗斯搜索巨头 Yandex 于 2017 年开源的机器学习库,其名称来源于 “Category” 和 “Boosting” 的组合,旨在高效处理类别特征的梯度提升算法。与其他模型(如 XGBoost 和 LightGBM)相比,CatBoost 具有以下优势:

  • 支持类别特征:无需对类别特征进行独热编码,直接处理类别数据,避免数据膨胀。
  • 对缺失值的鲁棒性:无需特殊预处理即可直接处理缺失值。
  • 防止过拟合:内置多种正则化手段,减少梯度偏差和预测偏移,提高模型的准确性和泛化能力。
  • 对称树结构:采用对称决策树(Oblivious Trees),在每个层级使用相同的特征和分割点,提升训练和预测效率。

三. 实战项目环境与数据准备

本项目使用了脱敏后的保险数据集,包含以下特征:

  • 类别特征:险种代码、出险原因、医疗责任类别等。
  • 数值特征:基本保额、索赔金额等。
  • 标签:是否需要调查(分类任务)。

所有数据均已脱敏,支持迁移至其他表格数据集。

因为不好分享,所以后续第七节补充了一个基于sklearn "California Housing"数据集的流程代码与说明。


四. 回归任务:预测保险索赔金额

数据预处理

在回归任务中,我们根据特征预测索赔金额。以下是数据清洗与预处理的关键步骤:

  1. 过滤无效数据:移除缺失或非法值的记录。
  2. 特征转换:将类别特征转为字符串类型。
  3. 分割数据集:按 80% 和 20% 的比例划分训练集与测试集。

4.1 模型训练与评估

我们使用 CatBoost 进行回归建模,模型参数包括:

  • 学习率:0.02
  • 深度:8
  • 迭代次数:10,000(支持提前停止)

以下是模型的关键代码:

from catboost import CatBoostRegressor

# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(
    iterations=10000,
    learning_rate=0.02,
    depth=8,
    eval_metric='RMSE',
    early_stopping_rounds=1500,
    random_seed=42
)

# 训练模型
cat_regressor.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=100
)

4.2 特征重要性分析

特征重要性是衡量特征对模型预测贡献程度的指标,可以帮助我们更好地理解模型。

# 获取特征重要性
feature_importances = cat_regressor.get_feature_importance()
feature_names = X_train.columns

# 可视化特征重要性
import matplotlib.pyplot as plt
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importances
}).sort_values(by='Importance', ascending=True)

plt.figure(figsize=(10, 6))
plt.barh(importance_df['Feature'], importance_df['Importance'], color='salmon')
plt.xlabel('特征重要性')
plt.ylabel('特征名称')
plt.title('CatBoost 特征重要性分析')
plt.show()

结果展示
在这里插入图片描述


4.3 模型评估

我们可以均方误差 (MSE) 以及 平均绝对误差 (MAE) 来评估模型在测试集上的回归性能,同时展示模型的学习曲线:

# 获取训练和测试集的 RMSE
evals_result = cat_regressor.get_evals_result()
train_rmse = evals_result['learn']['RMSE']
test_rmse = evals_result['validation']['RMSE']

# 绘制 RMSE 曲线
plt.figure(figsize=(10, 6))
plt.plot(train_rmse, label='训练集 RMSE')
plt.plot(test_rmse, label='测试集 RMSE')
plt.title('训练与测试集的 RMSE 学习曲线')
plt.xlabel('迭代次数')
plt.ylabel('RMSE')
plt.legend()
plt.show()

五. 分类任务:判别是否调查

5.1 数据标注与模型选择

分类任务以 是否调查 作为标签(1 表示需要调查,0 表示无需调查),特征包括所有数值和类别字段。

为了完成分类任务,我们选用 CatBoostClassifier。模型参数类似于回归模型,分类评估指标包括准确率、混淆矩阵和分类报告。


5.2 训练结果与模型评估

训练结果显示,分类准确率达 94.0%。以下是模型的分类报告:

分类报告 (训练集):
               precision    recall  f1-score   support

           0       0.96      0.98      0.97     13087
           1       0.74      0.57      0.64      1354

    accuracy                           0.94     14441
   macro avg       0.85      0.77      0.80     14441
weighted avg       0.94      0.94      0.94     14441
5.3 代码示例
from catboost import CatBoostClassifier

# 初始化分类器
cat_classifier = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.02,
    depth=8,
    eval_metric='Accuracy',
    early_stopping_rounds=150,
    random_seed=42
)

# 模型训练
cat_classifier.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=100
)

六. 可视化分析

为更直观地理解模型,我们利用散点图和分割线对预测结果进行展示:

  • 散点图:展示实际金额与预测金额的分布。
  • 分割线:通过 KMeans 聚类划分四个金额档次。

以下代码生成散点图与分割线:

# 使用 KMeans 聚类生成分割线
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=4, random_state=42)
df['cluster'] = kmeans.fit_predict(df[['预测金额']])

# 绘制散点图
plt.figure(figsize=(12, 8))
plt.scatter(df['预测金额'], df['是否调查'], c=df['cluster'], cmap='tab10')
plt.title("预测金额与是否调查的散点图")
plt.xlabel("预测金额")
plt.ylabel("是否调查")
plt.colorbar(label='Cluster')
plt.show()

散点图展示

在这里插入图片描述


七. 补充学习

7.1 基础数据集

California Housing 数据集包含加利福尼亚州 20,640 个街区的人口、住房和收入信息。目标是预测每个街区的房价中位数 MedHouseVal

数据特征

  1. MedInc:街区的收入中位数。
  2. HouseAge:街区住房的平均年龄。
  3. AveRooms:每个街区的平均房间数。
  4. AveBedrms:每个街区的平均卧室数。
  5. Population:街区的总人口。
  6. AveOccup:每户的平均人数。
  7. Latitude:街区的纬度。
  8. Longitude:街区的经度。

7.2 实践步骤

7.2.1 导入数据与预处理

我们使用 Scikit-learn 加载数据并进行预处理。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd

# 加载 California Housing 数据集
data = fetch_california_housing(as_frame=True)
df = data.frame

# 特征和目标变量
X = df.drop(columns="MedHouseVal")
y = df["MedHouseVal"]

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

训练集大小: (16512, 8), 测试集大小: (4128, 8)


7.2.2 训练 CatBoost 回归模型

使用 CatBoost 对房价进行预测。

from catboost import CatBoostRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error

# 初始化 CatBoost 回归模型
cat_regressor = CatBoostRegressor(
    iterations=1000,
    learning_rate=0.1,
    depth=6,
    eval_metric="RMSE",
    random_seed=42,
    verbose=100
)

# 模型训练
cat_regressor.fit(X_train, y_train, eval_set=(X_test, y_test), verbose=100, early_stopping_rounds=50)

# 模型预测
y_pred_train = cat_regressor.predict(X_train)
y_pred_test = cat_regressor.predict(X_test)

# 模型评估
mse_train = mean_squared_error(y_train, y_pred_train)
mse_test = mean_squared_error(y_test, y_pred_test)
mae_test = mean_absolute_error(y_test, y_pred_test)

print(f"训练集均方误差 (MSE): {mse_train}")
print(f"测试集均方误差 (MSE): {mse_test}")
print(f"测试集平均绝对误差 (MAE): {mae_test}")

输出如下:

0:	learn: 1.0934740	test: 1.0841841	best: 1.0841841 (0)	total: 1.24s	remaining: 20m 38s
100:	learn: 0.4867395	test: 0.5154868	best: 0.5154868 (100)	total: 1.54s	remaining: 13.7s
200:	learn: 0.4320149	test: 0.4798269	best: 0.4798269 (200)	total: 1.8s	remaining: 7.18s
300:	learn: 0.4020581	test: 0.4657293	best: 0.4657293 (300)	total: 2.07s	remaining: 4.8s
400:	learn: 0.3803801	test: 0.4582868	best: 0.4582868 (400)	total: 2.35s	remaining: 3.5s
500:	learn: 0.3633580	test: 0.4534430	best: 0.4534430 (500)	total: 2.61s	remaining: 2.6s
600:	learn: 0.3488402	test: 0.4491723	best: 0.4491723 (600)	total: 2.89s	remaining: 1.92s
700:	learn: 0.3358611	test: 0.4461323	best: 0.4461323 (700)	total: 3.17s	remaining: 1.35s
800:	learn: 0.3234759	test: 0.4431320	best: 0.4431320 (800)	total: 3.44s	remaining: 854ms
900:	learn: 0.3126821	test: 0.4403978	best: 0.4403978 (900)	total: 3.71s	remaining: 407ms
999:	learn: 0.3025414	test: 0.4386906	best: 0.4386902 (998)	total: 3.97s	remaining: 0us

bestTest = 0.438690174
bestIteration = 998

Shrink model to first 999 iterations.
训练集均方误差 (MSE): 0.09158491090576551
测试集均方误差 (MSE): 0.19244906768098075
测试集平均绝对误差 (MAE): 0.28701415230111493

7.2.3 可视化预测结果

展示预测值与实际值的对比,以及模型的特征重要性。

实际值与预测值对比
import matplotlib.pyplot as plt

# 对比测试集的预测值和实际值
plt.figure(figsize=(10, 6))
plt.scatter(range(len(y_test)), y_test, color="blue", label="真实值", alpha=0.6)
plt.scatter(range(len(y_pred_test)), y_pred_test, color="red", label="预测值", alpha=0.6)
plt.title("真实房价与预测房价对比")
plt.xlabel("样本索引")
plt.ylabel("房价中位数")
plt.legend()
plt.show()

特征重要性分析
# 特征重要性可视化
feature_importances = cat_regressor.get_feature_importance()
feature_names = data.feature_names

plt.figure(figsize=(10, 6))
plt.barh(feature_names, feature_importances, color="skyblue")
plt.title("CatBoost 特征重要性")
plt.xlabel("重要性得分")
plt.ylabel("特征名称")
plt.show()

在这里插入图片描述


7.3 数据结果

  • 模型评估结果:
    • 训练集均方误差 (MSE): 0.09158491090576551
    • 测试集均方误差 (MSE): 0.19244906768098075
    • 测试集平均绝对误差 (MAE): 0.28701415230111493
  • 特征重要性解读:
    根据特征重要性分析,MedInc(收入中位数)对预测房价的影响最大,而经纬度特征(Latitude 和 Longitude)也提供了显著的信息。

八. 总结

通过本项目,我们完成了基于 CatBoost 的回归与分类建模,并展示了预测结果的可视化。CatBoost 的强大功能和易用性使其在处理类别特征和缺失值的数据中表现优异。

希望本篇博客能为大家带来启发,助力实际项目的落地实现。如果对您有所帮助,也欢迎点赞与分享😊。

源码已上传到:https://github.com/YYForReal/ML-DL-RL-Learning/blob/main/ML-Learning/Catboost/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/930847.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTML Input 文件上传功能全解析:从基础到优化

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Unity在运行状态下,当物体Mesh网格发生变化时,如何让MeshCollider碰撞体也随之实时同步变化?

旧版源代码地址:https://download.csdn.net/download/qq_41603955/90087225?spm1001.2014.3001.5501 旧版效果展示: 新版加上MeshCollider后的效果: 注意:在Unity中,当你动态地更改物体的Mesh时,通常期望…

conda create -n name python=x.x 执行失败问题解决方法

今天想在anaconda环境下创建一个指定python版本为3.9的虚拟环境,执行命令 conda create -n DeepLearning python3.9 但是系统竟然报错 看报错信息是在镜像源里找不到下载包,于是对镜像源文件处理 首先删除之前的镜像通道 conda config --remove-key …

第一个 JSP 程序

一个简单的 JSP 程序&#xff1a; 使用 IDEA 开发工具新建一个 maven 项目&#xff0c;具体操作如图所示&#xff1a; 配置 Tomcat 服务器 项目结构如下图所示&#xff1a; 3. 修改 index.jsp 页面的代码&#xff1a; <% page language"java" contentType&q…

Altium Designer学习笔记 32 DRC检查_丝印调整

基于Altium Designer 23学习版&#xff0c;四层板智能小车PCB 更多AD学习笔记&#xff1a;Altium Designer学习笔记 1-5 工程创建_元件库创建Altium Designer学习笔记 6-10 异性元件库创建_原理图绘制Altium Designer学习笔记 11-15 原理图的封装 编译 检查 _PCB封装库的创建Al…

docker学习笔记(五)--docker-compose

文章目录 常用命令docker-compose是什么yml配置指令详解versionservicesimagebuildcommandportsvolumesdepends_on docker-compose.yml文件编写 常用命令 命令说明docker-compose up启动所有docker-compose服务&#xff0c;通常加上-d选项&#xff0c;让其运行在后台docker-co…

不同类型的集成技术——Bagging、Boosting、Stacking、Voting、Blending简述

目录 一、说明 二、堆叠 2.1 堆叠的工作原理&#xff1a; 2.2 例子&#xff1a; 2.3 堆叠的优点&#xff1a; 三、投票&#xff08;简单投票&#xff09; 3.1 例子&#xff1a; 3.2 投票的优点&#xff1a; 四、装袋和投票之间的区别 五、混合 6.1 混合的主要特征&#xff1a; …

ONES 功能上新|ONES Project 甘特图再度升级

ONES Project 甘特图支持展示工作项标题、进度百分比、依赖关系延迟时间等信息。 应用场景&#xff1a; 在使用甘特图规划项目任务、编排项目计划时&#xff0c;可以对甘特图区域进行配置&#xff0c;展示工作项的工作项标题、进度百分比以及依赖关系延迟时间等维度&#xff0c…

【目标检测】【反无人机目标检测】使用SEB-YOLOv8s实时检测未经授权的无人机

Real-Time Detection of Unauthorized Unmanned Aerial Vehicles Using SEB-YOLOv8s 使用SEB-YOLOv8s实时检测未经授权的无人机 论文链接 0.论文摘要 摘要&#xff1a;针对无人机的实时检测&#xff0c;复杂背景下无人机小目标容易漏检、难以检测的问题。为了在降低内存和计算…

Elasticsearch:使用 Elastic APM 监控 Android 应用程序

一、前言 人们通过私人和专业的移动应用程序在智能手机上处理越来越多的事情。 拥有成千上万甚至数百万的用户&#xff0c;确保出色的性能和可靠性是移动应用程序和相关后端服务的提供商和运营商面临的主要挑战。 了解移动应用程序的行为、崩溃的发生和类型、响应时间慢的根本…

DataSophon集成CMAK KafkaManager

本次集成基于DDP1.2.1 集成CMAK-3.0.0.6 设计的json和tar包我放网盘了. 通过网盘分享的文件&#xff1a;DDP集成CMAK 链接: https://pan.baidu.com/s/1BR70Ajj9FxvjBlsOX4Ivhw?pwdcpmc 提取码: cpmc CMAK github上提供了zip压缩包.将压缩包解压之后 在根目录下加入启动脚本…

Java——异常机制(上)

1 异常机制本质 (异常在Java里面是对象) (抛出异常&#xff1a;执行一个方法时&#xff0c;如果发生异常&#xff0c;则这个方法生成代表该异常的一个对象&#xff0c;停止当前执行路径&#xff0c;并把异常对象提交给JRE) 工作中&#xff0c;程序遇到的情况不可能完美。比如…

vue3 vite ts day2

虚拟dom diff 算法的了解 diff 算法 源码的了解 简单易懂的图 参考文章 学习Vue3 第五章&#xff08;Vue核心虚拟Dom和 diff 算法&#xff09;_学习vue3 第五章 (vue核心虚拟dom-CSDN博客 如需了解更多请去原作者下看&#xff0c;讲的真的很细。 ref reactive vue2 …

动态计算加载图片

学习啦 别名路径&#xff1a;①npm install path --save-dev②配置 // vite.config,js import { defineConfig } from vite import vue from vitejs/plugin-vueimport { viteStaticCopy } from vite-plugin-static-copy import path from path export default defineConfig({re…

Postgresql 格式转换笔记整理

1、数据类型有哪些 1.1 数值类型 DECIMAL/NUMERIC 使用方法 DECIMAL是PostgreSQL中的一种数值数据类型&#xff0c;用于存储固定精度和小数位数的数值。DECIMAL的精度是由用户指定的&#xff0c;可以存储任何位数的数值&#xff0c;而小数位数则由用户自行定义。DECIMAL类型的…

爬虫运行后数据如何存储?

爬虫运行后获取的数据可以存储在多种不同的存储系统中&#xff0c;具体选择取决于数据的规模、查询需求以及应用场景。以下是一些常见的数据存储方法&#xff1a; 1. 文件系统 对于小型项目或临时数据存储&#xff0c;可以直接将数据保存到本地文件中。常见的文件格式包括&…

吉林大学23级数据结构上机实验(第7周)

A 去火车站 寒假到了&#xff0c;小明准备坐火车回老家&#xff0c;现在他从学校出发去火车站&#xff0c;CC市去火车站有两种方式&#xff1a;轻轨和公交车。小明为了省钱&#xff0c;准备主要以乘坐公交为主。CC市还有一项优惠政策&#xff0c;持学生证可以免费乘坐一站轻轨&…

谈谈IPD在PLM的落地

关注作者 1 前言 全球化市场竞争形势下&#xff0c;越来越多企业不断提升自身的研发创新能力&#xff0c;加大产品的研发创新投入。从整个研发投入来看&#xff0c;2022年至2023年间&#xff0c;研发投入强度由1.54%提升至2.64%&#xff0c;其中中小民营企业增长为3.75%&#…

线程(二)——线程安全

如何理解线程安全&#xff1a; 多线程并发执行的时候&#xff0c;有时候会触发一些“bug”&#xff0c;虽然代码能够执行&#xff0c;线程也在工作&#xff0c;但是过程和结果都不符合我们的开发时的预期&#xff0c;所以我们将此类线程称之为“线程安全问题”。 例如&#xff…

思特奇政·企数智化产品服务平台正式发布,助力运营商政企数智能力跃迁

数字浪潮下,产业数字化进程加速发展,信息服务迎来更广阔的天地,同时也为运营商政企支撑系统提出了更高要求。12月4日,2024数字科技生态大会期间,思特奇正式发布政企数智化产品服务平台,融合应用大数据、AI等新质生产要素,构建集平台服务、精准营销、全周期运营支撑、智慧大脑于…