【进阶篇】YOLOv8实现K折交叉验证——解决数据集样本稀少和类别不平衡的难题,让你的模型评估更加稳健


在这里插入图片描述


在这里插入图片描述
YOLOv8专栏导航:点击此处跳转


K折交叉验证

K折交叉验证(K-Fold Cross-Validation)是一种常用的机器学习模型评估方法,可以帮助我们评估模型的性能,特别适用于数据集相对较小的情况。

在K折交叉验证中,将原始数据集分成K个子集,然后依次将其中一个子集作为验证集,其余K-1个子集作为训练集进行模型训练和评估。这样可以得到K个模型,每个模型都在不同的验证集上进行评估,最后将K个模型的评估结果,求平均或取最优结果作为最终评估。

优点

  • 充分利用数据集: 在K折交叉验证中,整个数据集被划分为K个互斥的折叠(Folds)。每次训练模型时,都有K-1个折叠用于训练,而剩下的一个用于验证。这样,每个样本都有机会作为验证集的一部分,从而充分利用了数据集中的所有样本。
  • 减轻过拟合风险: 由于每个样本都会在训练集和验证集中出现,模型在验证集上的性能评估更具有代表性。这有助于减轻由于数据稀疏或类别不平衡导致的过拟合问题。
  • 更稳健的模型评估: K折交叉验证计算多次模型性能的平均值,提供了更稳健的性能评估。这对于对抗数据稀疏和类别不平衡等挑战的模型评估尤为重要,因为它减少了单次评估可能引入的随机性。
  • 处理类别不平衡: 如果数据集中某些类别的样本数量较少,K折交叉验证可以确保每个折叠中都包含这些少见类别的样本。这有助于确保模型在少见类别上的性能得到充分评估,并提高对类别不平衡的鲁棒性。

YOLOv8

YOLOv8 是由 YOLOv5 的发布者 Ultralytics 发布的最新版本的 YOLO。它可用于对象检测、分割、分类任务以及大型数据集的学习,并且可以在包括 CPU 和 GPU 在内的各种硬件上执行。

YOLOv8是一种尖端的、最先进的 (SOTA) 模型,它建立在以前成功的 YOLO 版本的基础上,并引入了新的功能和改进,以进一步提高性能和灵活性。YOLOv8 旨在快速、准确且易于使用,这也使其成为对象检测、图像分割和图像分类任务的绝佳选择。具体创新包括一个新的骨干网络、一个新的 Ancher-Free 检测头和一个新的损失函数,还支持YOLO以往版本,方便不同版本切换和性能对比。

YOLOv8的K折交叉验证实现步骤

  1. 首先,将数据集划分成K个子集,可以使用现有的数据集划分函数或手动划分。
  2. 然后,使用一个循环迭代K次,每次将其中一个子集作为验证集,其余K-1个子集作为训练集。
  3. 在每次迭代中,使用训练集进行模型训练,并使用验证集进行模型评估。
  4. 可以根据需要调整模型的超参数或其他设置。例如,可以尝试不同的学习率、迭代次数等。
  5. 最后,将K次迭代中的评估结果进行平均或取最优结果作为最终的模型评估结果。

官方项目下载

git clone https://github.com/ultralytics/ultralytics.git

🚀 YOLOv8-K折交叉验证代码实现

在上述下载好的项目文件夹下新建k-fold-train.py文件,并添加下述代码:

import argparse
import datetime
from itertools import chain
import os
from pathlib import Path
import shutil
import yaml
import pandas as pd
from collections import Counter
from sklearn.model_selection import KFold
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from ultralytics import YOLO

NUM_THREADS = min(8, max(1, os.cpu_count() - 1))

def parse_opt():
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', default=r'./data')  # 数据集路径
    parser.add_argument('--ksplit', default=5, type=int)  # K-Fold交叉验证拆分数据集
    parser.add_argument('--im_suffixes', default=['jpg', 'png', 'jpeg'], help='images suffix')  # 图片后缀名

    return parser.parse_args()

def run(func, this_iter, desc="Processing"):
    with ThreadPoolExecutor(max_workers=NUM_THREADS, thread_name_prefix='MyThread') as executor:
        results = list(
            tqdm(executor.map(func, this_iter), total=len(this_iter), desc=desc)
        )
    return results

def main(opt):
    dataset_path, ksplit, im_suffixes = Path(opt.data), opt.ksplit, opt.im_suffixes

    save_path = Path(dataset_path / f'{datetime.date.today().isoformat()}_{ksplit}-Fold_Cross-Valid')
    save_path.mkdir(parents=True, exist_ok=True)

    # 获取所有图像和标签文件的列表
    images = sorted(list(chain(*[(dataset_path / "images").rglob(f'*.{ext}') for ext in im_suffixes])))
    # images = sorted(image_files)
    labels = sorted((dataset_path / "labels").rglob("*.txt"))

    root_directory = Path.cwd()
    print("当前文件运行根目录:", root_directory)
    if len(images) != len(labels):
        print('*' * 20)
        print('当前数据集和标签数量不一致!!!')
        print('*' * 20)

    # 从YAML文件加载类名
    classes_file = sorted(dataset_path.rglob('classes.yaml'))[0]
    assert classes_file.exists(), "请创建classes.yaml类别文件"
    if classes_file.suffix == ".txt":
        pass
    elif classes_file.suffix == ".yaml":
        with open(classes_file, 'r', encoding="utf8") as f:
            classes = yaml.safe_load(f)['names']
    cls_idx = sorted(classes.keys())

    # 创建DataFrame来存储每张图像的标签计数
    indx = [l.stem for l in labels]  # 使用基本文件名作为ID(无扩展名)
    labels_df = pd.DataFrame([], columns=cls_idx, index=indx)

    # 计算每张图像的标签计数
    for label in labels:
        lbl_counter = Counter()
        with open(label, 'r') as lf:
            lines = lf.readlines()
        for l in lines:
            # YOLO标签使用每行的第一个位置的整数作为类别
            lbl_counter[int(l.split(' ')[0])] += 1
        labels_df.loc[label.stem] = lbl_counter

    # 用0.0替换NaN值
    labels_df = labels_df.fillna(0.0)

    kf = KFold(n_splits=ksplit, shuffle=True, random_state=20)  # 设置random_state以获得可重复的结果
    kfolds = list(kf.split(labels_df))
    folds = [f'split_{n}' for n in range(1, ksplit + 1)]
    folds_df = pd.DataFrame(index=indx, columns=folds)

    # 为每个折叠分配图像到训练集或验证集
    for idx, (train, val) in enumerate(kfolds, start=1):
        folds_df[f'split_{idx}'].loc[labels_df.iloc[train].index] = 'train'
        folds_df[f'split_{idx}'].loc[labels_df.iloc[val].index] = 'val'

    # 计算每个折叠的标签分布比例
    fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)
    for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
        train_totals = labels_df.iloc[train_indices].sum()
        val_totals = labels_df.iloc[val_indices].sum()

        # 为避免分母为零,向分母添加一个小值(1E-7)
        ratio = val_totals / (train_totals + 1E-7)
        fold_lbl_distrb.loc[f'split_{n}'] = ratio

    ds_yamls = []

    for split in folds_df.columns:
        split_dir = save_path / split
        split_dir.mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)

        dataset_yaml = split_dir / f'{split}_dataset.yaml'
        ds_yamls.append(dataset_yaml.as_posix())
        split_dir = (root_directory / split_dir).as_posix()

        with open(dataset_yaml, 'w') as ds_y:
            yaml.safe_dump({
                'train': split_dir + '/train/images',
                'val': split_dir + '/val/images',
                'names': classes
            }, ds_y)
    # print(ds_yamls)
    with open(dataset_path / 'yaml_paths.txt', 'w') as f:
        for path in ds_yamls:
            f.write(path + '\n')

    args_list = [(image, save_path, folds_df) for image in images]

    run(split_images_labels, args_list, desc=f"Creating dataset")

def split_images_labels(args):
    image, save_path, folds_df = args
    label = image.parents[1] / 'labels' / f'{image.stem}.txt'
    if label.exists():
        for split, k_split in folds_df.loc[image.stem].items():
            # 目标目录
            img_to_path = save_path / split / k_split / 'images'
            lbl_to_path = save_path / split / k_split / 'labels'
            shutil.copy(image, img_to_path / image.name)
            shutil.copy(label, lbl_to_path / label.name)


if __name__ == "__main__":
    opt = parse_opt()
    main(opt)

    model = YOLO('yolov8n.pt', task='train')
    # 从文本文件中加载内容并存储到一个列表中
    ds_yamls = []
    with open(Path(opt.data) / 'yaml_paths.txt', 'r') as f:
        for line in f:
            # 去除每行末尾的换行符
            line = line.strip()
            ds_yamls.append(line)

    # 打印加载的文件路径列表
    print(ds_yamls)

    for k in range(opt.ksplit):
        dataset_yaml = ds_yamls[k]
        name = Path(dataset_yaml).stem
        model.train(
            data=dataset_yaml,
            batch=16,
            epochs=100,
            imgsz=640,
            device=0,
            workers=8,
            project="runs/train",
            name=name,)

    print("*"*40)
    print("K-Fold Cross Validation Completed.")
    print("*"*40)

命令行运行

python k-fold-train.py --data ./data --ksplit 5
  • data:数据集所在目录,数据集的分布如下:
    - data
    	- images
    	- labels
    
    images存放图片文件,labels存放txt标注文件。
  • ksplit:数据集拆分为ksplit个子集,该值通常取5或者10,某些情况下可取其他数值。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/259498.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CUMT--Java--JDBC编程

目录 一、JDBC简介 二、数据库访问 1、加载数据库驱动 2、建立数据连接 3、创建Statement对象 4、执行SQL语句 5、访问结果集 三、MetaData接口 1、DatabaseMetaData接口 2、ResultSetMetaData接口 四、事务 1、JDBC中的事务 2、保存点 3、批量更新 一、JDBC简…

若依(ruoyi)管理系统标题和logo修改

1、网页上的logo 2、页面中的logo 进入ruoyi-ui --> src --> assets --> logo --> logo.png,把这个图片换成你自己的logo 3、网页标题 进入ruoyi-ui --> src --> layout --> components --> Sidebar --> Logo.vue,将里面的…

第二周:AI产品经理全局学习

一、AI产品架构全景 二、 AI产品岗位分析和了解 三、 AI产品能力模型 四、AI产品经理工作流 五、AI产品经理学习路径和规划 六、本周市场动态

VM安装Sonoma【笔记】

VMware Workstation安装MacOS Sonoma 1、配置虚拟机,根据系统性能调整参数; 2、先不焦急启动虚拟机,打开虚拟机存储目录,以文本方式打开.vmx文件(这里以Sonoma.vmx为例); 这里只针对Inter CP…

select和poll

slelect实现客户端 #include <myhead.h> #define PORT 8888 #define PORT1 9999 #define IP "192.168.125.147" int main(int argc, const char *argv[]) {int cfd-1;if((cfdsocket(AF_INET,SOCK_STREAM,0))-1){perror("socket error");return -1…

Go语言字符串综合指南:函数、方法和最佳实践

Go语言字符串综合指南&#xff1a;函数、方法和最佳实践 引言Go语言字符串基础声明和初始化不可变性字符串长度 字符串操作函数常用字符串操作转换与解析示例连接分割包含关系替换大小写转换整数与字符串的转换字符串到整数的转换格式化与解析 字符串与字符切片字符串和字符切片…

WINDOWS(WIN11)通过IP添加网络打印机

点击添加设备 点击手动添加 使用IP地址或主机名添加打印机 选择TCP/IP设备&#xff0c;输入打印机地址 如果有正确驱动就安装&#xff0c;没有就取消。 通过手动设置添加本地打印机或网络打印机 使用现有的端口 根据打印机IP&#xff0c;选择标准端口。 成功&#xff01; 到…

【LeetCode刷题笔记(12-1)】【Python】【有效的字母异位词】【排序/字符统计】【简单】

文章目录 引言有效的字母异位词题目描述提示 解决方案1&#xff1a;【排序】解决方案2&#xff1a;【字符统计】结束语 有效的字母异位词 引言 编写通过所有测试案例的代码并不简单&#xff0c;通常需要深思熟虑和理性分析。虽然这些代码能够通过所有的测试案例&#xff0c;但…

贪吃蛇(三)绘制蛇身

绘制蛇身的逻辑不难&#xff0c;存储上面使用结构体。 第一行和第十九行绘制--其它行&#xff0c;绘制|&#xff0c;分别在头尾处。 (1) 扫描蛇身&#xff0c;如果扫描到则绘制[]。 (2) 扫描蛇身&#xff0c;如果扫描不到则绘制空白。 #include"curses.h"struct Sn…

cpulimit设计理念及其思考

背景 前几天&#xff0c;同事咨询了我一个问题&#xff1a;IO占用能和cpu使用率那样&#xff0c;有方法来控制吗&#xff1f;这个问题的背景是因为客户提了两个需求&#xff0c;如下&#xff1a; 说实话&#xff0c;针对这两点需求&#xff0c;我的第一反应是有一点思路&#…

第二节TypeScript 基础语法

1、typescript程序由以下几个部分组成&#xff1a; 模块函数变量语句和表达式注释 2、开始第一个typescript程序 创建一个typescript程序&#xff0c;使之输出“hello typescript”&#xff1a; 代码&#xff1a; var message:string "hello typescript" cons…

MaBatis使用`ResultMap`标签手动映射详解使用

文章目录 MaBatis使用ResultMap标签手动映射详解使用1、MyBatis只能自动维护库表”列名“与”属性名“相同时的对应关系&#xff0c;二者不同时无法自动ORM&#xff0c;如下&#xff1a;2、在SQL中使用 as 为查询字段添加列别名&#xff0c;以匹配属性名&#xff1a;但是如果我…

(7)Linux GDB以及gcc和g++

&#x1f4ad; 前言 本章我们将带着大家高雅的学一学令众多习惯图形化页面的朋友难受的 gdb 调试&#xff0c;这部分知识可以选择性学习学习&#xff0c;以后倘若遇到一些问题时能在 Linux 内简单调试&#xff0c;还是很香的。然后在讲讲 gcc 和 g&#xff0c;系统讲解程序运行…

基于STC89C51单片机实现的森林防火系统源码+仿真+原理图+设计报告,含视频讲解

森林防火 摘要 森林防火是非常必要的,火灾对森林的破坏是具有毁灭性的,有着很大的危害,在春秋季节森林火灾高发期,若发生火灾,对人民生活带来极大危害,不仅危害人们生产生活,而且对地球环境产生影响.本课题研究的内容是以单片机STC89C51为控制核心&#xff0c;以MQ-2型半导体电…

YZ系列工具之YZ02:字典的多功能应用

我给VBA下的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的工作效率&#xff0c;而且可以提高数据的准确度。我的教程一共九套一部VBA手册&#xff0c;教程分为初级、中级、高级三大部分。是对VBA的系统讲解&#xff0c;从简单的…

C : DS二叉排序树之删除

Description 给出一个数据序列&#xff0c;建立二叉排序树&#xff0c;并实现删除功能 对二叉排序树进行中序遍历&#xff0c;可以得到有序的数据序列 Input 第一行输入t&#xff0c;表示有t个数据序列 第二行输入n&#xff0c;表示首个序列包含n个数据 第三行输入n个数据…

【TCP服务器的演变过程】编写第一个TCP服务器:实现一对一的连接通信

编写第一个TCP服务器&#xff1a;实现一对一的连接通信 一、前言二、需要使用到的API2.1、socket()函数2.2、bind()函数2.3、listen()函数2.4、accept()函数2.5、recv()函数2.6、send()函数2.7、strerror()函数 三、实现步骤四、完整代码五、TCP客户端5.1、自己实现一个TCP客户…

OpenHarmony南向之TP触摸屏

概述 Touchscreen驱动用于驱动触摸屏使其正常工作&#xff0c;该驱动主要完成如下工作&#xff1a;对触摸屏驱动IC进行上电、配置硬件管脚并初始化其状态、注册中断、配置通信接口&#xff08;I2C或SPI&#xff09;、设定Input相关配置、下载及更新固件等操作。 Touchscreen驱…

代码随想录算法训练营第四十一天|198.打家劫舍 ,213.打家劫舍II ,337.打家劫舍III

198. 打家劫舍 - 力扣&#xff08;LeetCode&#xff09; 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#…

建行驻江门市分行纪检组党支部开展“以廉养人,以案警人”清廉文化现场教学活动

近日&#xff0c;建行驻江门市分行纪检组党支部联合建设支行党支部到江门市党群服务中心开展“以廉养人&#xff0c;以案警人”清廉文化现场教学活动。 名言语句亮初心。一楼展馆入口处竖立着“拔烂树、治病树、正歪树”“以猛药去疴刮骨疗毒的勇气反腐”“理想信念是共产党人的…