数据集类不平衡的处理方法

最近在处理一个类不平均的数据集,这里记录一些注意事项,方便以后查阅。

数据集类不平衡的处理方法

  • 数据集类不平衡的处理方法
    • 1. 数据处理方法
    • 2. 模型改进方法
    • 3. 测试与评估方法
    • 4. 综合策略
    • 5. 示例代码
    • 6. 注意事项
  • 模型评估指标
    • 1. 混淆矩阵(Confusion Matrix)
    • 2. 准确率(Accuracy)
    • 3. 精确率(Precision)
    • 4. 召回率(Recall)
    • 5. F1-Score
    • 6. Kappa值(Cohen's Kappa)
    • 7. 特异度(Specificity)
    • 8. 假正率(FPR)
    • 9. ROC曲线与AUC
    • 10. PR曲线与AUC

数据集类不平衡的处理方法

对于类别不平衡的数据集,模型训练与测试的效果可能受到数据分布的影响,因此需要采取一些方法来缓解类别不平衡问题,从而提高模型的性能和泛化能力。以下是常见的解决策略:

1. 数据处理方法

(1) 过采样(Oversampling)

  • 定义:在训练集中增加少数类别的样本数量。
  • 实现
    • 随机复制少数类别样本(Random Oversampling)。
    • 使用合成数据生成技术(如SMOTE)。
  • 优点:缓解数据不平衡,增加模型对少数类别的学习能力。
  • 注意:过度过采样可能导致过拟合。

(2) 欠采样(Undersampling)

  • 定义:减少多数类别的样本数量。
  • 实现:随机移除多数类别样本。
  • 优点:加速训练并平衡数据。
  • 注意:可能丢失重要信息,导致模型性能下降。

(3) 数据增强

  • 定义:通过旋转、缩放等操作对少数类别样本进行扩增。
  • 适用场景:图像、音频等数据。

2. 模型改进方法

(1) 使用加权损失函数

  • 定义:为不同类别分配不同的损失权重,让模型更关注少数类别。
  • 实现
    • 权重比例通常与类别分布反比。
    • 常用损失函数:加权交叉熵损失(Weighted Cross-Entropy Loss)、Focal Loss。
  • 优点:无需改变数据分布。

(2) 平衡采样器(Balanced Sampler)

  • 定义:在每个mini-batch中,按类别比例采样数据。
  • 实现:调整DataLoader或批量生成策略。

(3) 使用特殊模型架构

  • 定义:采用适合处理不平衡数据的模型。
  • 例如:集成学习(如随机森林、XGBoost等)可以较好地处理类别不平衡问题。

3. 测试与评估方法

(1) 选择合适的评估指标

  • 问题:传统准确率指标在类别不平衡情况下可能误导结果。
  • 替代指标
    • 精确率(Precision)、召回率(Recall)、F1-Score。
    • ROC曲线与AUC。
    • PR曲线与AUC。

(2) 分层采样(Stratified Sampling)

  • 定义:在训练集和测试集中保持类别分布一致。
  • 实现:划分数据集时,按照类别比例分层。

(3) 混淆矩阵分析

  • 定义:观察模型对不同类别的预测表现。
  • 作用:确定哪些类别需要改进。

4. 综合策略

(1) 混合采样

  • 定义:结合过采样与欠采样。
  • 适用场景:同时提高训练效率和模型性能。

(2) 使用非对称阈值

  • 定义:针对少数类别设置较低的决策阈值,增加召回率。
  • 实现:通过调整predict_proba的概率阈值。

(3) 分阶段训练

  • 定义:先用平衡数据训练模型,再用真实分布数据进行微调。
  • 优点:提高模型的实际适用性。

5. 示例代码

(1) 使用加权损失函数(以PyTorch为例)

import torch
import torch.nn as nn

# 假设类别0的样本数为1000,类别1的样本数为100
weights = torch.tensor([1/1000, 1/100], dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights)

(2) 使用SMOTE进行过采样

from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split

# 分割数据
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)

# 使用SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

6. 注意事项

  1. 测试集保持真实分布:测试数据应反映真实世界的数据分布,不应进行平衡处理。
  2. 防止数据泄漏:避免数据增强或采样方法导致数据泄漏(如同一条样本的不同增强版本出现在训练集和验证集中)。
  3. 动态调整:根据数据特性选择适合的策略,灵活调整。

模型评估指标

1. 混淆矩阵(Confusion Matrix)

定义
混淆矩阵是一种评估分类模型性能的工具,显示模型在测试数据上的分类结果与实际结果的对比情况。它将分类结果划分为四种类型:TP(真正例)、TN(真负例)、FP(假正例)和 FN(假负例)。混淆矩阵通常用于二分类问题,但也可扩展到多分类问题。

二分类混淆矩阵

实际类别 / 预测类别预测为正例预测为负例
实际为正例TPFN
实际为负例FPTN

元素解释

  1. TP(真正例,True Positive)
    • 实际类别为正例,且模型预测为正例。
    • 示例:病人被正确诊断为患病。
  2. TN(真负例,True Negative)
    • 实际类别为负例,且模型预测为负例。
    • 示例:健康人被正确诊断为健康。
  3. FP(假正例,False Positive)
    • 实际类别为负例,但模型预测为正例(“误报”)。
    • 示例:健康人被错误诊断为患病。
  4. FN(假负例,False Negative)
    • 实际类别为正例,但模型预测为负例(“漏报”)。
    • 示例:病人被错误诊断为健康。

扩展到多分类问题

对于多分类问题,混淆矩阵是一个 n × n 的方阵,n 是类别数。矩阵中的第 i i i 行表示实际类别为 i i i 的样本,第 j j j 列表示模型预测类别为 j j j 的样本数量。对角线上的值表示正确分类的样本数,非对角线表示错误分类的情况。

2. 准确率(Accuracy)

定义

  • 表示模型预测正确的样本占总样本的比例。

公式
Accuracy = TP + TN TP + TN + FP + FN \text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}} Accuracy=TP+TN+FP+FNTP+TN

应用场景

  • 类别平衡的数据集,如正常分类任务。

适用性

  • 不适用于类不平衡问题,容易被多数类样本主导,无法反映模型对少数类的性能。

优缺点

  • 优点:直观、简单,适合初步评估。
  • 缺点:类别不平衡时误导性强。

物理意义

  • 反映了模型总体上的正确预测比例。

实现代码

from sklearn.metrics import accuracy_score

# y_true: 实际标签, y_pred: 预测标签
accuracy = accuracy_score(y_true, y_pred)

3. 精确率(Precision)

定义

  • 模型预测为正例的样本中,实际为正例的比例。

公式
Precision = TP TP + FP \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Precision=TP+FPTP

应用场景

  • 误报代价高的场景,如垃圾邮件过滤、金融欺诈检测。

适用性

  • 适用于类不平衡问题,更关注正例的预测质量。

优缺点

  • 优点:有效避免过多误报。
  • 缺点:可能忽略对召回率的考虑。

物理意义

  • 衡量模型对正例预测的可信度。

实现代码

from sklearn.metrics import precision_score

precision = precision_score(y_true, y_pred)

4. 召回率(Recall)

定义

  • 实际为正例的样本中,模型正确预测为正例的比例。

公式
Recall = TP TP + FN \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Recall=TP+FNTP

应用场景

  • 漏报代价高的场景,如医疗诊断、安防监控。

适用性

  • 适用于类不平衡问题,更关注正例的覆盖能力。

优缺点

  • 优点:减少漏报。
  • 缺点:可能导致更多误报。

物理意义

  • 衡量模型对正例的捕获能力。

实现代码

from sklearn.metrics import recall_score

recall = recall_score(y_true, y_pred)

5. F1-Score

定义

  • 精确率和召回率的调和平均值,综合考虑两者的性能。

公式
F1-Score = 2 × Precision × Recall Precision + Recall \text{F1-Score} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} F1-Score=2×Precision+RecallPrecision×Recall

应用场景

  • 需要平衡精确率与召回率的场景。

适用性

  • 适用于类不平衡问题,可作为主要指标。

优缺点

  • 优点:综合性强。
  • 缺点:对高精确率或高召回率的情况敏感。

物理意义

  • 衡量模型整体性能的均衡性。

实现代码

from sklearn.metrics import f1_score

f1 = f1_score(y_true, y_pred)

6. Kappa值(Cohen’s Kappa)

定义

  • 评估模型分类与随机分类之间的一致性。

公式
κ = p o − p e 1 − p e \kappa = \frac{p_o - p_e}{1 - p_e} κ=1pepope

  • p o p_o po:观察到的准确率。
  • p e p_e pe:随机一致的概率。

应用场景

  • 强调分类一致性的场景。

适用性

  • 适用于类不平衡问题

优缺点

  • 优点:剔除随机预测的影响。
  • 缺点:对概率分布的敏感性较高。

物理意义

  • 衡量模型分类的一致性。

实现代码

from sklearn.metrics import cohen_kappa_score

kappa = cohen_kappa_score(y_true, y_pred)

7. 特异度(Specificity)

定义

  • 实际为负例的样本中,模型正确预测为负例的比例。

公式
Specificity = TN TN + FP \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} Specificity=TN+FPTN

应用场景

  • 关注负例分类准确性的场景。

适用性

  • 适用于类不平衡问题

优缺点

  • 优点:关注负例性能。
  • 缺点:容易被少数类忽略。

物理意义

  • 衡量负例预测的能力。

实现代码

from sklearn.metrics import confusion_matrix

tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp)

8. 假正率(FPR)

定义

  • 实际为负例的样本中,模型错误预测为正例的比例。

公式
FPR = FP FP + TN \text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}} FPR=FP+TNFP

应用场景

  • 分析误报对结果的影响。

适用性

  • 可辅助其他指标分析。

优缺点

  • 优点:有效评估误报。
  • 缺点:无法独立衡量模型性能。

9. ROC曲线与AUC

定义

  • ROC曲线:以假正率(FPR)为横轴,真正率(TPR)为纵轴,反映分类性能。
  • AUC:ROC曲线下面积。

应用场景

  • 评估分类器整体性能。

适用性

  • 类不平衡时稳定

优缺点

  • 优点:评估全面。
  • 缺点:受类别分布影响。

实现代码

from sklearn.metrics import roc_auc_score

auc = roc_auc_score(y_true, y_pred_proba)

10. PR曲线与AUC

定义

  • PR曲线:以召回率为横轴,精确率为纵轴。
  • PR-AUC:PR曲线下面积。

应用场景

  • 类不平衡问题,如少数类检测。

优缺点

  • 优点:对正例的性能更敏感。
  • 缺点:忽略负例表现。

实现代码

from sklearn.metrics import precision_recall_curve, auc

precision, recall, _ = precision_recall_curve(y_true, y_pred_proba)
pr_auc = auc(recall, precision)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/953563.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

目标检测中的Bounding Box(边界框)介绍:定义以及不同表示方式

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

openEuler22.03系统使用Kolla-ansible搭建OpenStack

Kolla-ansible 是一个利用 Ansible 自动化工具来搭建 OpenStack 云平台的开源项目,它通过容器化的方式部署 OpenStack 服务,能够简化安装过程、提高部署效率并增强系统的可维护性。 前置环境准备: 系统:openEuler-22.03-LTS-SP4 配置&…

Leecode刷题C语言之统计重新排列后包含另一个字符串的子字符串数目②

执行结果:通过 执行用时和内存消耗如下: void update(int *diff, int c, int add, int *cnt) {diff[c] add;if (add 1 && diff[c] 0) {// 表明 diff[c] 由 -1 变为 0(*cnt)--;} else if (add -1 && diff[c] -1) {// 表明 diff[c] 由 0 变为 -…

uniapp 微信小程序webview与h5双向实时通信交互

描述: 小程序webview内嵌的h5需要向小程序实时发送消息,有人说postMessage可以实现,所以试验一下,结果是实现不了实时,只能在特定时机后退、组件销毁、分享时小程序才能接收到信息(小程序为了安全等考虑做了…

pycharm-pyspark 环境安装

1、环境准备:java、scala、pyspark、python-anaconda、pycharm vi ~/.bash_profile export SCALA_HOME/Users/xunyongsun/Documents/scala-2.13.0 export PATH P A T H : PATH: PATH:SCALA_HOME/bin export SPARK_HOME/Users/xunyongsun/Documents/spark-3.5.4-bin…

fast-crud select下拉框 实现多选功能及下拉框数据动态获取(通过接口获取)

教程 fast-crud select示例配置需求:需求比较复杂 1. 下拉框选项需要通过后端接口获取 2. 实现多选功能 由于这个前端框架使用逻辑比较复杂我也是第一次使用,所以只记录核心问题 环境:vue3,typescript,fast-crud ,elementPlus 效果 代码 // crud.tsx文件(/.ts也行 js应…

高性能现代PHP全栈框架 Spiral

概述 Spiral Framework 诞生于现实世界的软件开发项目是一个现代 PHP 框架,旨在为更快、更清洁、更卓越的软件开发提供动力。 特性 高性能 由于其设计以及复杂精密的应用服务器,Spiral Framework框架在不影响代码质量以及与常用库的兼容性的情况下&a…

天机学堂笔记1

FeignClient(contextId "course", value "course-service") public interface CourseClient {/*** 根据老师id列表获取老师出题数据和讲课数据* param teacherIds 老师id列表* return 老师id和老师对应的出题数和教课数*/GetMapping("/course/infoB…

lobechat搭建本地知识库

本文中,我们提供了完全基于开源自建服务的 Docker Compose 配置,你可以直接使用这份配置文件来启动 LobeChat 数据库版本,也可以对之进行修改以适应你的需求。 我们默认使用 MinIO 作为本地 S3 对象存储服务,使用 Casdoor 作为本…

沸点 | 聚焦嬴图Cloud V2.1:具备水平可扩展性+深度计算的云原生嬴图动力站!

近日,嬴图正式推出嬴图Cloud V2.1,此次发布专注于提供无与伦比的用户体验,包括具有水平可扩展性的嬴图Powerhouse的一键部署、具有灵活定制功能的管理控制台、VPC / 专用链接等,旨在满足用户不断变化需求的各项前沿功能&#xff0…

Linux---shell脚本练习

要求: 1、shell 脚本写出检测 /tmp/size.log 文件如果存在显示它的内容,不存在则创建一个文件将创建时间写入。 2、写一个 shel1 脚本,实现批量添加 20个用户,用户名为user01-20,密码为user 后面跟5个随机字符。 3、编写个shel 脚本将/usr/local 日录下…

LiveNVR监控流媒体Onvif/RTSP常见问题-二次开发接口jquery调用示例如何解决JS|axios调用接口时遇到的跨域问题

LiveNVR二次开发接口jquery调用示例如何解决JS|axios调用接口时遇到的跨域问题 1、接口调用示例2、JS调用遇到跨域解决示例3、axios请求接口遇到跨域问题3.1、post请求3.2、get请求 4、RTSP/HLS/FLV/RTMP拉流Onvif流媒体服务 1、接口调用示例 下面是完整的 jquery 调用示例 $.a…

Canvas简历编辑器-选中绘制与拖拽多选交互方案

Canvas简历编辑器-选中绘制与拖拽多选交互方案 在之前我们聊了聊如何基于Canvas与基本事件组合实现了轻量级DOM,并且在此基础上实现了如何进行管理事件以及多层级渲染的能力设计。那么此时我们就依然在轻量级DOM的基础上,关注于实现选中绘制与拖拽多选交…

服务器数据恢复—raid5故障导致上层ORACLE无法启动的数据恢复案例

服务器数据恢复环境&故障: 一台服务器上的8块硬盘组建了一组raid5磁盘阵列。上层安装windows server操作系统,部署了oracle数据库。 raid5阵列中有2块硬盘的硬盘指示灯显示异常报警。服务器操作系统无法启动,ORACLE数据库也无法启动。 服…

LabVIEW光流算法的应用

该VI展示了如何使用NI Vision Development Module中的光流算法来计算图像序列中像素的运动矢量。通过该方法,可以实现目标跟踪、运动检测等功能,适用于视频处理、机器人视觉和监控领域。程序采用模块化设计,包含图像输入、算法处理、结果展示…

Redis十大数据类型详解

Redis(一) 十大数据类型 redis字符串(String) string是redis最基本的类型,一个key对应一个value string类型是二进制安全的,意思是redis的string可以包含任何数据。例如说是jpg图片或者序列化对象 一个re…

Navicat Premium 16.0.90 for Mac 安装与free使用

步骤 0.下载 通过网盘分享的文件:Navicat Premium 16.0.90 链接: https://pan.baidu.com/s/12O22rXa9MiBPKKTGMELNIg 提取码: yyds 1.打开下好的 dmg 文件 (这个界面不要关闭) 2.将Navicat Premium 拖动至 Applications 这时出现 点击取消。 3.点开…

基于Springboot + vue实现的购物推荐网站

🥂(❁◡❁)您的点赞👍➕评论📝➕收藏⭐是作者创作的最大动力🤞 💖📕🎉🔥 支持我:点赞👍收藏⭐️留言📝欢迎留言讨论 🔥🔥&…

【大数据】机器学习-----最开始的引路

以下是关于机器学习的一些基本信息,包括基本术语、假设空间、归纳偏好、发展历程、应用现状和代码示例: 一、基本术语 样本(Sample): 也称为实例(Instance)或数据点(Data Point&…

【WPS】【WORDEXCEL】【VB】实现微软WORD自动更正的效果

1. 代码规范方面 添加 Option Explicit:强制要求显式声明所有变量,这样可以避免因变量名拼写错误等情况而出现难以排查的逻辑错误,提高代码的健壮性。使用 On Error GoTo 进行错误处理:通过设置错误处理机制,当代码执行…