【目标检测】模型验证:K-Fold 交叉验证

K-Fold 交叉验证

  • 1、引言
    • 1.1 K 折交叉验证概述
  • 2、配置
    • 2.1 数据集
    • 2.2 安装包
  • 3、 实战
    • 3.1 生成物体检测数据集的特征向量
    • 3.2 K 折数据集拆分
    • 3.3 保存记录
    • 3.4 使用 K 折数据分割训练YOLO
  • 4、总结

1、引言

我们将利用YOLO 检测格式和关键的Python 库(如 sklearn、pandas 和 PyYaml),完成必要的设置、生成特征向量的过程以及 K-Fold 数据集拆分的执行。

1.1 K 折交叉验证概述

无论你的项目涉及水果检测数据集还是自定义数据源,都可以使用 K 折交叉验证,
以提高项目的可靠性和稳健性。

书说简短,闲言少叙,咱进入正题
在这里插入图片描述

2、配置

2.1 数据集

该数据集共包含 8479 幅图像。
它包括 6 个类别标签,每个标签的实例总数如下:

类别计数
苹果7049
葡萄7202
菠萝1613
橙色15549
香蕉3536
西瓜1976

2.2 安装包

必要的Python 软件包包括

  • ultralytics
  • sklearn
  • pandas
  • pyyaml

这次实例中,我们使用 k=5 折叠次数

3、 实战

3.1 生成物体检测数据集的特征向量

具体步骤如下:

  • 1、首先创建一个新的 demo.py Python 文件来执行下面的步骤。

  • 2、继续检索数据集的所有标签文件。

from pathlib import Path

dataset_path = Path("./Fruit-detection")  # replace with 'path/to/dataset' for your custom data
labels = sorted(dataset_path.rglob("*labels/*.txt"))  # all data in 'labels'
  • 3、现在,读取数据集 YAML 文件的内容并提取类标签的索引。
yaml_file = "path/to/data.yaml"  # your data YAML with data directories and names dictionary
with open(yaml_file, "r", encoding="utf8") as y:
    classes = yaml.safe_load(y)["names"]
cls_idx = sorted(classes.keys())
  • 4、初始化一个空的 pandas DataFrame.
import pandas as pd

index = [label.stem for label in labels]  # uses base filename as ID (no extension)
labels_df = pd.DataFrame([], columns=cls_idx, index=index)
  • 5、计算注释文件中每个类别标签的实例数。
from collections import Counter

for label in labels:
    lbl_counter = Counter()

    with open(label, "r") as lf:
        lines = lf.readlines()

    for line in lines:
        # classes for YOLO label uses integer at first position of each line
        lbl_counter[int(line.split(" ")[0])] += 1

    labels_df.loc[label.stem] = lbl_counter

labels_df = labels_df.fillna(0.0)  # replace `nan` values with `0.0`
  • 6、以下是已填充 DataFrame 的示例视图:
                                                       0    1    2    3    4    5
'0000a16e4b057580_jpg.rf.00ab48988370f64f5ca8ea4...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.7e6dce029fb67f01eb19aa7...'  0.0  0.0  0.0  0.0  0.0  7.0
'0000a16e4b057580_jpg.rf.bc4d31cdcbe229dd022957a...'  0.0  0.0  0.0  0.0  0.0  7.0
'00020ebf74c4881c_jpg.rf.508192a0a97aa6c4a3b6882...'  0.0  0.0  0.0  1.0  0.0  0.0
'00020ebf74c4881c_jpg.rf.5af192a2254c8ecc4188a25...'  0.0  0.0  0.0  1.0  0.0  0.0
 ...                                                  ...  ...  ...  ...  ...  ...
'ff4cd45896de38be_jpg.rf.c4b5e967ca10c7ced3b9e97...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff4cd45896de38be_jpg.rf.ea4c1d37d2884b3e3cbce08...'  0.0  0.0  0.0  0.0  0.0  2.0
'ff5fd9c3c624b7dc_jpg.rf.bb519feaa36fc4bf630a033...'  1.0  0.0  0.0  0.0  0.0  0.0
'ff5fd9c3c624b7dc_jpg.rf.f0751c9c3aa4519ea3c9d6a...'  1.0  0.0  0.0  0.0  0.0  0.0
'fffe28b31f2a70d4_jpg.rf.7ea16bd637ba0711c53b540...'  0.0  6.0  0.0  0.0  0.0  0.0

解析

  • 行是标签文件的索引,每个标签文件对应数据集中的一幅图像,列则对应类标签索引。
  • 每一行代表一个伪特征向量,其中包含数据集中每个类标签的计数。
  • 这种数据结构可以将 K 折交叉验证应用于对象检测数据集。

3.2 K 折数据集拆分

  • 1、使用 KFold 从 sklearn.model_selection 以产生 k 对数据集进行分割。

    • 敲黑板:
      • 设置 shuffle=True 确保了分班中班级的随机分布。
      • 通过设置 random_state=M 其中 M 是一个选定的整数,这样就可以得到可重复的结果。
from sklearn.model_selection import KFold

ksplit = 5
kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # setting random_state for repeatable results

kfolds = list(kf.split(labels_df))
  • 2、数据集现已分为 k 折叠,每个折叠都有一个 train 和 val 指数。我们将构建一个 DataFrame 来更清晰地显示这些结果。
folds = [f"split_{n}" for n in range(1, ksplit + 1)]
folds_df = pd.DataFrame(index=index, columns=folds)

for i, (train, val) in enumerate(kfolds, start=1):
    folds_df[f"split_{i}"].loc[labels_df.iloc[train].index] = "train"
    folds_df[f"split_{i}"].loc[labels_df.iloc[val].index] = "val"
  • 3、将计算每个褶皱的类别标签分布,并将其作为褶皱中出现的类别的比率。
fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)

for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
    train_totals = labels_df.iloc[train_indices].sum()
    val_totals = labels_df.iloc[val_indices].sum()

    # To avoid division by zero, we add a small value (1E-7) to the denominator
    ratio = val_totals / (train_totals + 1e-7)
    fold_lbl_distrb.loc[f"split_{n}"] = ratio
最理想的情况是,每次分割和不同类别的所有类别比率都相当相似。不过,这取决于数据集的具体情况。
  • 4、为每个分割创建目录和数据集 YAML 文件。
import datetime

supported_extensions = [".jpg", ".jpeg", ".png"]

# Initialize an empty list to store image file paths
images = []

# Loop through supported extensions and gather image files
for ext in supported_extensions:
    images.extend(sorted((dataset_path / "images").rglob(f"*{ext}")))

# Create the necessary directories and dataset YAML files (unchanged)
save_path = Path(dataset_path / f"{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-val")
save_path.mkdir(parents=True, exist_ok=True)
ds_yamls = []

for split in folds_df.columns:
    # Create directories
    split_dir = save_path / split
    split_dir.mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "train" / "labels").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "images").mkdir(parents=True, exist_ok=True)
    (split_dir / "val" / "labels").mkdir(parents=True, exist_ok=True)

    # Create dataset YAML files
    dataset_yaml = split_dir / f"{split}_dataset.yaml"
    ds_yamls.append(dataset_yaml)

    with open(dataset_yaml, "w") as ds_y:
        yaml.safe_dump(
            {
                "path": split_dir.as_posix(),
                "train": "train",
                "val": "val",
                "names": classes,
            },
            ds_y,
        )
  • 5、最后,将图像和标签复制到每个分割的相应目录("train "或 “val”)中。
import shutil

for image, label in zip(images, labels):
    for split, k_split in folds_df.loc[image.stem].items():
        # Destination directory
        img_to_path = save_path / split / k_split / "images"
        lbl_to_path = save_path / split / k_split / "labels"

        # Copy image and label files to new directory (SamefileError if file already exists)
        shutil.copy(image, img_to_path / image.name)
        shutil.copy(label, lbl_to_path / label.name)

3.3 保存记录

将 K 折分割和标签分布数据框的记录保存为 CSV 文件。

folds_df.to_csv(save_path / "kfold_datasplit.csv")
fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")

3.4 使用 K 折数据分割训练YOLO

  • 首先,加载YOLO 模型。
from ultralytics import YOLO

weights_path = "path/to/weights.pt"
model = YOLO(weights_path, task="detect")
  • 其次,遍历数据集 YAML 文件以运行训练。结果将保存到由 project 和 name 参数。默认情况下,该目录为 “exp/runs#”,其中 # 为整数索引。
results = {}

# Define your additional arguments here
batch = 16
project = "kfold_demo"
epochs = 100

for k in range(ksplit):
    dataset_yaml = ds_yamls[k]
    model = YOLO(weights_path, task="detect")
    model.train(data=dataset_yaml, epochs=epochs, batch=batch, project=project)  # include any train arguments
    results[k] = model.metrics  # save output metrics for further analysis

4、总结

这篇小鱼使用了 K 折交叉验证来训练YOLO 物体检测模型的过程。

还创建报告 DataFrames 的程序,以可视化数据拆分和标签在这些拆分中的分布,清楚地了解训练集和验证集的结构。

此外,还保存了记录,这在大型项目或排除模型性能故障时尤为有用。

最后,在一个循环中使用每个拆分来执行实际的模型训练,保存训练结果,以便进一步分析和比较。

这种 K 折交叉验证技术是充分利用可用数据的一种稳健方法,有助于确保模型在不同数据子集中的性能是可靠和一致的。这将产生一个更具通用性和可靠性的模型,从而减少对特定数据模式的过度拟合。

我是小鱼

  • CSDN 博客专家
  • 阿里云 专家博主
  • 51CTO博客专家
  • 企业认证金牌面试官
  • 多个名企认证&特邀讲师等
  • 名企签约职场面试培训、职场规划师
  • 多个国内主流技术社区的认证专家博主
  • 多款主流产品(阿里云等)评测一等奖获得者

关注小鱼,学习【人工智能&大模型】/【深度学习&机器学习】领域最新最全的知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/967196.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习系列--04.梯度下降以及其他优化器

目录 一.梯度概念 1.一元函数 2.二元函数 3.几何意义上的区别 二.梯度下降 1.原理 2.步骤 3.示例代码(Python) 4.不同类型的梯度下降 5.优缺点 三.动量优化器(Momentum) 适用场景 1.复杂地形的优化问题 2.数据具有噪声的问…

π0开源了且推出自回归版π0-FAST——打造机器人动作专用的高效Tokenizer:比扩散π0的训练速度快5倍但效果相当

前言 过去的半个多月 对于大模型 deepseek火爆全球,我对其的解读也写成了整整一个系列 详见《火爆全球的DeepSeek系列模型》,涉及对GRPO、MLA、V3、R1的详尽细致深入的解读 某种意义来讲,deepseek 相当于把大模型的热度 又直接拉起来了——…

导航守卫router.beforeEach

router.beforeEach 是一个全局前置守卫,在每次路由跳转之前都会触发。 //index.jsrouter.beforeEach((to, from, next) > {// 打印即将要进入的目标路由信息console.log(即将要进入的目标路由信息:, to)// 打印当前正要离开的路由信息console.log(当前正要离开的…

[ESP32:Vscode+PlatformIO]添加第三方库 开源库 与Arduino导入第三方库的区别

前言 PlatformIO与Arduino在添加第三方库方面的原理存在显著差异 在PlatformIO中,第三方库的使用是基于项目(工程)的。具体而言,只有当你为一个特定的项目添加了某个第三方库后,该项目才能使用该库。这些第三方库的文…

了解AI绘图,Stable Diffusion的使用

AI绘图对GPU算力要求较高。 个人电脑配置可参考: CPU:14600kf 盒装 显卡:RTX 4080金属大师 OC,16G显存 主板:z790吹雪d4 内存:芝奇皇家戟4000c18,162G 硬盘:宏基gm7000 1T 散热:追风…

linux环境自动化golang项目启动脚本解析

一.场景介绍 当在本地创建了golang项目,修改了代码功能,怎么在远程测试服务器上更新该功能呢,可以使用下面的步骤来解决该问题(这只是其中一种方法): (1).推送最新代码到远程仓库 (2).在测试服务器上创建该项目并拉取最新代码 (3).创建deploy.sh脚本 (4).运行deploy.sh脚本 二.…

归一化与伪彩:LabVIEW图像处理的区别

在LabVIEW的图像处理领域,归一化(Normalization)和伪彩(Pseudo-coloring)是两个不同的概念,虽然它们都涉及图像像素值的调整,但目的和实现方式截然不同。归一化用于调整像素值的范围&#xff0c…

基于DeepSeek API和VSCode的自动化网页生成流程

1.创建API key 访问官网DeepSeek ,点击API开放平台。 在开放平台界面左侧点击API keys,进入API keys管理界面,点击创建API key按钮创建API key,名称自定义。 2.下载并安装配置编辑器VSCode 官网Visual Studio Code - Code Editing…

Open WebUI项目源码学习记录(从0开始基于纯CPU环境部署一个网页Chat服务)

感谢您点开这篇文章:D,鼠鼠我是一个代码小白,下文是学习开源项目Open WebUI过程中的一点笔记记录,希望能帮助到你~ 本人菜鸟,持续成长,能力不足有疏漏的地方欢迎一起探讨指正,比心心&#xff5e…

SSM仓库物品管理系统 附带详细运行指导视频

文章目录 一、项目演示二、项目介绍三、运行截图四、主要代码1.用户登录代码:2.保存物品信息代码:3.删除仓库信息代码: 一、项目演示 项目演示地址: 视频地址 二、项目介绍 项目描述:这是一个基于SSM框架开发的仓库…

Python微博动态爬虫

本文是刘金路的《语言数据获取与分析基础》第十章的扩展,详细解释了如何利用Python进行微博爬虫,爬虫内容包括微博指定帖子的一级评论、评论时间、用户名、id、地区、点赞数。 整个过程十分明了,就是用户利用代码模拟Ajax请求,发…

时序数据库:Influxdb详解

文章目录 一、简介1、简介2、官网 二、部署1、安装2、配置(1)用户初始化 三、入门(Web UI)1、加载数据(1)上传数据文件(2)代码接入模板 2、管理存储桶(1)创建…

unity学习32:角色相关1,基础移动控制

目录 0 应用商店 1 角色上新增CharacterController 组件 1.1 角色上新增CharacterController 组件 1.2 如果没有这个则会报错 2 速度 2.1 默认速度,按帧率计算 2.2 修改速度为按时间计算 2.3 movespeed,基础是1米/秒,这个就是每 move…

Centos Ollama + Deepseek-r1+Chatbox运行环境搭建

Centos Ollama Deepseek-r1Chatbox运行环境搭建 内容介绍下载ollama在Ollama运行DeepSeek-r1模型使用chatbox连接ollama api 内容介绍 你好! 这篇文章简单讲述一下如何在linux环境搭建 Ollama Deepseek-r1。并在本地安装的Chatbox中进行远程调用 下载ollama 登…

mysql8.0使用pxc实现高可用

环境准备 准备三台虚拟机,其对应的主机名和IP地址为 pxc-1192.168.190.129pxc-2192.168.190.133pxc-3192.168.190.134 解析,都要做解析 测试 下载pxc的安装包, 官网:https://www.percona.com/downloads 选择8.0的版本并下载,…

LabVIEW污水生化处理在线监测

污水处理是环保领域的重要工作,传统污水处理方法在监测方面存在实时性差、操作不便等问题。为解决这些问题,本项目设计并实现了一套基于LabVIEW的污水生化处理在线监测平台,能够实时监测污水处理过程中的关键参数,如温度、pH值、溶…

【AI学习】关于 DeepSeek-R1的几个流程图

遇见关于DeepSeek-R1的几个流程图,清晰易懂形象直观,记录于此。 流程图一 来自文章《Understanding Reasoning LLMs》, 文章链接:https://magazine.sebastianraschka.com/p/understanding-reasoning-llms?continueFlagaf07b1a0…

vs封装dll 给C#使用

一,vs创建控制台应用 创建控制台应用得好处时,我们可以自己测试接口,如果接口没有问题,改成dll重新编译一遍就可以。 二, 创建一个c 类,将所需提供得功能 封装到类中。 这样可以将 所有功能,进…

ubuntu20使用tigervnc远程桌面配置记录

一、安装tigervnc sudo apt install tigervnc-common sudo apt install tigervnc-standalone-server二、增加配置文件 安装完后新增配置文件:vim ~/.vnc/xstartup #!/bin/sh #Uncomment the following two lines for normal desktop: #unset SESSION_MANAGER #ex…

DeepSeek使用技巧大全(含本地部署教程)

在人工智能技术日新月异的今天,DeepSeek 作为一款极具创新性和实用性的 AI,在众多同类产品中崭露头角,凭借其卓越的性能和丰富的功能,吸引了大量用户的关注。 DeepSeek 是一款由国内顶尖团队研发的人工智能,它基于先进…