机器学习的模型校准

背景知识

之前一直没了解过模型校准是什么东西,最近上班业务需要看了一下:

模型校准是指对分类模型进行修正以提高其概率预测的准确性。在分类模型中,预测结果通常以类别标签形式呈现(例如,0或1),但有时我们更关注的是预测的概率。

当使用某些分类模型(例如支持向量机(SVM)或随机森林)时,其预测的概率并不一定与真实标签的概率分布相匹配。这意味着,即使预测概率较高的类别出现的频率更高,模型的预测概率也可能偏离真实情况。这可能导致对模型的概率输出有误解,或者在需要高度依赖概率预测的任务(例如风险评估或阈值选择)中出现问题。

通过校准分类模型,我们可以将模型的预测概率调整为更准确地反映真实情况。`CalibratedClassifierCV`是Scikit-learn库中提供的用于校准分类器的类。它根据指定的校准方法(`method`),通过拟合后的分类器(`model`)和交叉验证拟合(`cv='prefit'`)来创建一个经过校准的分类器(`calibrated_model`)。

在代码中,使用`calibrated_model.fit(X_train, y_train)`通过使用交叉验证拟合来训练、校准模型。之后,使用`calibrated_model.predict(X_test)`对测试集进行预测,并使用`classification_report`输出校准模型的分类性能报告。

通过校准分类模型,我们可以使得模型的概率预测更为准确,从而提高在概率判断和相关任务中的性能和可靠性。


代码实现

模型校准主要是针对分类模型的,我之前都是做回归,难怪没怎么接触过。也没空找真实数据了,直接模拟数据来实现一下。

导入包和制作数据集

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.calibration import calibration_curve
from sklearn.ensemble import RandomForestClassifier

import matplotlib.pyplot as plt

# 生成二分类数据集
X, y = make_classification(n_samples=10000, n_features=40, n_classes=2, 
                           weights=[0.9, 0.1], random_state=2, flip_y=0.3)

查看分布:
 

pd.Series(y).value_counts()

不平衡样本。

标准化,划分训练集测试集

# 标准化数据
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y,stratify=y, test_size=0.2, random_state=2)

训练,然后评价,这里就弄了个随机森林模型试试

# 模型训练
model =RandomForestClassifier()
model.fit(X_train, y_train)

# 模型评价
y_pred = model.predict(X_test)
print("Classification Report:")
print(classification_report(y_test, y_pred))


画校准曲线。

# 计算校准曲线
prob_true, prob_pred = calibration_curve(y_test, model.predict_proba(X_test)[:, 1], n_bins=10)

# 绘制校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Uncalibrated)')
plt.legend()
plt.show()

这玩意怎么看,,,我也不太懂,反正就是要单调,并且越靠近对角线越好。这个明显在0.1-0.2区间不单调,还有0.7-0.9也在下降。

来校准一下:


模型校准

模型校准很多方法,目前这个是用了 method='sigmoid',这个方法,好像叫做什么p系数校准。

from sklearn.calibration import CalibratedClassifierCV
calibrated_model = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')
calibrated_model.fit(X_train, y_train)

# 模型评价(校准后)
y_pred_calibrated = calibrated_model.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated))

emmm,效果好像没有明显提升。

method='isotonic',这个是什么保序回归方法校准。

calibrated_model2 = CalibratedClassifierCV(model, method='isotonic', cv='prefit')
calibrated_model2.fit(X_train, y_train)

# 模型评价(校准后)
y_pred_calibrated2 = calibrated_model2.predict(X_test)
print("Classification Report (Calibrated Model):")
print(classification_report(y_test, y_pred_calibrated2))

效果也差不多。

画出校准曲线的对比图:
 

# 计算校准后的校准曲线
prob_true_calibrated, prob_pred_calibrated = calibration_curve(y_pred_calibrated,calibrated_model.predict_proba(X_test)[:, 1], n_bins=10)
prob_true_calibrated2, prob_pred_calibrated2 = calibration_curve(y_pred_calibrated2,calibrated_model2.predict_proba(X_test)[:, 1], n_bins=10)
# 绘制校准后的校准曲线
plt.figure(figsize=(7, 4),dpi=128)
plt.plot(prob_pred, prob_true, marker='o', label='uncalibrated')
plt.plot(prob_pred_calibrated, prob_true_calibrated, marker='o', label='sigmoid calibrated')
plt.plot(prob_pred_calibrated2, prob_true_calibrated2, marker='o', label='isotonic calibrated')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='perfectly calibrated')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration Curve (Calibrated)')
plt.legend()
plt.show()

 

可以看到模型校准之后这个线都是单调上升的了。但是都很奇怪,而且预测效果也没太多改善,可能是我这个数据集是随便造的原因。

校准曲线的单调性在模型校准中确实非常重要。校准曲线的单调性指的是在横轴表示预测概率的均值,纵轴表示实际观测到的正例比例时,曲线应该是单调递增的,即预测概率越高,观测到的正例比例也应该越高。

校准曲线的单调性反映了模型输出的概率与实际观测之间的一致性。如果校准曲线的单调性较差,意味着模型的输出概率与实际观测之间存在较大的偏差,可能会导致模型在实际应用中表现不稳定或不可靠。因此,单调的校准曲线通常被认为是一个良好校准的指标之一。

在实际应用中,如果模型的校准曲线不单调,可能需要进一步考虑以下问题:

模型的输出概率是否准确反映了样本的真实概率:如果模型输出的概率存在系统性的偏差,可能需要对模型进行校准,使其输出更加准确地反映样本的真实概率。

模型是否过度自信或不足自信:校准曲线的不单调性可能反映了模型在某些概率范围内过度自信或不足自信的问题。对于过度自信的模型,可能需要降低其输出概率;对于不足自信的模型,可能需要提高其输出概率。

模型的可靠性:校准曲线的单调性也反映了模型的可靠性。单调递增的校准曲线意味着模型的输出概率与实际观测之间的一致性较好,通常更可靠。

因此,校准曲线的单调性对于评估模型的校准效果和可靠性具有重要意义,在模型校准过程中应该注意观察和优化校准曲线的单调性。

嗯,都是gpt的话,看看了解一下就行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/515027.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

注意力机制篇 | YOLOv8改进之添加LSKAttention大核卷积注意力机制 | 即插即用,实现有效涨点

前言:Hello大家好,我是小哥谈。LSKAttention是一种注意力机制,它在自然语言处理领域中被广泛应用。LSKAttention是基于Transformer模型中的Self-Attention机制进行改进的一种变体。在传统的Self-Attention中,每个输入序列中的元素都会与其他元素进行交互,以获取全局的上下…

Linux 命令 top 详解

1 top命令介绍 Linux系统中,Top命令主要用于实时运行系统的监控,包括Linux内核管理的进程或者线程的资源占用情况。这个命令对所有正在运行的进程和系统负荷提供不断更新的概览信息,包括系统负载、CPU利用分布情况、内存使用、每个进程的内容…

开源量化交易研究框架Hikyuu

Hikyuu Quant Framework 是一款基于 C/Python 的开源量化交易研究框架,用于策略分析及回测。其核心思想基于当前成熟的系统化交易方法,将整个系统化交易抽象为由市场环境判断策略、系统有效条件、信号指示器、止损 / 止盈策略、资金管理策略、盈利目标策…

分享three.js实现粒子背景

three.js中粒子效果的实现方式大概分为三种: 1、Javascript直接计算粒子的状态变化,即基于CPU实现; 2、Javascript通知顶点着色器粒子的生命周期,由顶点着色器运行,即基于GPU实现; 3、粒子生成与状态维护全…

QT实现NTP功能

一.NTP基础 1.NTP定义 NTP(Network Time Protocol,网络时间协议)是由RFC 1305定义的时间同步协议,用于分布式设备(比如电脑、手机、智能手表等)进行时间同步,避免人工校时的繁琐和由此引入的误…

【漏洞复现】极简云 download.php 接口处存在任意文件读取漏洞

免责声明:文章来源互联网收集整理,请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该…

什么是线程安全、怎么保证线程安全

目录 什么是线程安全 多线程编程中的三个核心概念 JMM内存模型 JMM内存模型怎么实现原子性、可见性 怎么保证线程安全 什么是线程安全 当多个线程访问一个对象时,如果不用考虑这些线程在运行时环境下的调度和交替执行,也不需要进行额外的同步&#x…

Rust---复合数据类型之结构体

目录 结构体的使用输出结果 结构体简化创建结构体更新语法元组结构体单元结构体(unit struct)结构体中的引用使用#[derive(Debug)]再次介绍 代码综合展示 与元组不同的是,结构体可以为内部的每个字段起一个富有含义的名称,因此无需…

基于SpringBoot+Vue的汽车租赁管理系统的设计和实现【附源码】

1、系统演示视频(演示视频) 2、需要交流和学习请联系

vivado适用于 UltraScale 和 UltraScale+ 器件的 eFUSE 寄存器访问和编程

FUSE_DNA : 唯一的器件 DNA 每个 UltraScale 器件都有唯一的器件 ID , 称为器件 DNA , 且赛灵思已将此 DNA 编程到器件中。用户无法对 FUSE_DNA 进行编程。 UltraScale 器件具有 96 位 DNA 。您可在 Vivado Design Suite Tcl 控制台中…

Matlab梁单元有限元编程:铁木辛柯梁VS欧拉梁

专栏导读 作者简介:工学博士,高级工程师,专注于工业软件算法研究本文已收录于专栏:《有限元编程从入门到精通》本专栏旨在提供 1.以案例的形式讲解各类有限元问题的程序实现,并提供所有案例完整源码;2.单元…

openplc Linux 地址映射io,读写驱动数据等使用记录

1. 上一篇记录 openplc使用C语言文件读写驱动实现基本流程。 openPLC_Editor C语言编程 在mp157 arm板上调用io等使用记录_openplc c 编程-CSDN博客 2. 下面通过映射地址的方式控制io和读写驱动数据。 在runtime 环境的 hardware 硬件配置中 选择 python on Linux(PSM)&#…

成为先进企业应该从用飞书开始,还是应该从裁员开始?

点击文末“阅读原文”即可参与节目互动 剪辑、音频 / 卷圈 运营 / SandLiu 卷圈 监制 / 姝琦 封面 / 姝琦Midjourney 产品统筹 / bobo 大家都觉得飞书不行了,我们反而不这么看。 众所周知,飞书最近裁员的消息在业界引起了巨大的反响,大…

基于SSM+Jsp+Mysql的快递管理系统

开发语言:Java框架:ssm技术:JSPJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包…

62、服务攻防——框架安全CVE复现SpringStrutsLaravelThinkphp

文章目录 常见语言开发框架(主流): PHP:ThinkPHP、Laravel、YIIJAVA:Spring、SpringbootPython:FlaskJavaScript:Vue.js、Node.js 框架判断(): 通过插件Wappalyzer(可能…

练手项目层初阶1—《详解静态版本——通讯录管理系统》

文章目录 🚩前言🔊 项目需求📚 项目知识点包含🧩项目框架📝框架拆解分析✨Struct_book.h 头文件解析✨Struct_book.c文件解析✨test_book.c文件解析 📺演示效果🚀完整代码 🚩前言 俗…

联想 Y9000P 连接网线速度慢 的 问题解决

参考帖子:求助,拯救者Y9000P 2022 i73060版本 有线网非常慢 无线网正常【笔记本吧】_百度贴吧 问题原因: 网卡驱动版本不对。不能用Win11版,要用Win10版。 问题解决: 1、卸载原驱动 2、下载Win10 驱动 并安装 下载…

C语言——常用库函数的使用及模拟实现

C语言编译系统为使用者提供了很多的预定义库函数。用户在编写程序时,可以直接调用这些库函数。这里选择了一些常用的库函数,简单介绍各个函数的使用方法: 字符串函数 字符串函数函数分为长度不受限制的字符串函数,和长度受限制的…

手搓链表(java)(不完整)

手搓链表(java)(不完整) 文章目录 手搓链表(java)(不完整)前言一、代码1.MyLinkedList类:2.测试类: 总结 前言 提示:以下是本篇文章正文内容&…

Git 术语及中英文对照

完毕!!感谢您的收看 ----------★★历史博文集合★★---------- 我的零基础Python教程,Python入门篇 进阶篇 视频教程 Py安装py项目 Python模块 Python爬虫 Json Xpath 正则表达式 Selenium Etree CssGui程序开发 Tkinter Pyqt5 列表元组字…