用自己的数据集训练YOLO-NAS目标检测器

YOLO-NAS 是 Deci 开发的一种新的最先进的目标检测模型。 在本指南中,我们将讨论什么是 YOLO-NAS 以及如何在自定义数据集上训练 YOLO-NAS 模型。

在这里插入图片描述

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

为了训练我们的自定义模型,我们将:

  • 加载预训练的YOLO-NAS模型;
  • 从 Roboflow 加载自定义数据集,或者使用UnrealSynth制作合成数据集
  • 设置超参数值;
  • 使用超级梯度 Python 包根据我们的数据训练模型;
  • 评估模型以了解结果。

话不多说,让我们开始吧!

1、什么是 YOLO-NAS?

You Only Look Once  神经架构搜索(YOLO-NAS)是最新最先进的(SOTA)实时目标检测模型。 在 COCO 数据集上进行评估并与其前身 YOLOv6 和 YOLOv8  相比,YOLO-NAS 以更低的延迟实现了更高的 mAP 值。

YOLO-NAS 作为 Deci 维护的 super-gradient包的一部分提供。

下图展示了Deci在YOLO-NAS上的基准测试结果:
在这里插入图片描述

YOLO-NAS 与其他顶级实时检测器在 COCO 数据集上的性能对比图

YOLO-NAS 在 Roboflow 100 数据集基准测试中也是最好的,这表明它可以轻松地在自定义数据集上进行微调。

在这里插入图片描述

YOLO-NAS 和其他顶级实时检测器在 RF100 数据集上的性能对比图

2、Python环境设置

在开始训练之前,我们需要准备好Python环境。 让我们从安装三个 pip 包开始。 YOLO-NAS 模型本身是使用 super-gradient 包进行分发的。 请记住,该模型仍在积极开发中。 为了保持环境的稳定性,最好固定特定版本的包。 此外,我们将安装 roboflow 和监督,这将使我们能够从 Roboflow Universe 下载数据集并分别可视化我们的训练结果。

pip install super-gradients==3.1.1
pip install roboflow
pip install supervision

如果你在 Jupyter Notebook 中运行 YOLO-NAS,请不要忘记在安装完成后重新启动环境。

3、使用预训练模型进行推理

在开始培训之前,最好确保安装按计划进行。 最简单的方法是使用预先训练的模型之一进行测试推理。 同时,这也能让我们熟悉YOLO-NAS API。

3.1 加载YOLO-NAS模型

为了使用预训练的 COCO 模型进行推理,我们首先需要选择模型的大小。 YOLO-NAS提供三种不同的模型大小:yolo_nas_s、yolo_nas_m和yolo_nas_l。

yolo_nas_s 模型是最小且最快的,但它可能不会像较大的模型那么准确。 相反,yolo_nas_l 模型最大、最准确、最慢。 yolo_nas_m 模型提供了两者之间的中间立场。

import torch
from super_gradients.training import models

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
MODEL_ARCH = 'yolo_nas_l'
#            'yolo_nas_m'
#            'yolo_nas_s'

model = models.get(MODEL_ARCH, pretrained_weights="coco").to(DEVICE)

3.2 YOLO-NAS模型推理

推理过程包括设置置信度阈值和调用预测方法。 预测方法将返回预测列表,其中每个预测对应于图像中检测到的对象。

CONFIDENCE_TRESHOLD = 0.35

result = list(model.predict(image, conf=CONFIDENCE_TRESHOLD))[0]

在这里插入图片描述

YOLO-NAS推理结果图示

3.3 YOLO-NAS 推理输出格式

YOLO-NAS 推理的输出是一个 ImageDetectionPrediction 对象,它封装了图像中检测到的对象的详细信息。 该对象包含三个字段:

  • image - 表示用于推理的图像的 NumPy 数组。
  • class_names - 模型训练期间使用的类别名称的 Python 列表。
  • Prediction -DetectionPrediction 类的实例,其中包含有关模型检测的详细信息。

DetectionPrediction对象具有三个字段:

  • bboxes_xyxy - 形状 (N, 4) 的 NumPy 数组,以 xyxy 格式表示检测到的对象的边界框。
  • confidence - 形状 (N,) 的 NumPy 数组,表示检测的置信度值。 每个值都在 0 和 1 之间。
  • labels - 形状 (N,) 的 NumPy 数组,表示检测到的对象的类 ID。 每个类 ID 对应于 class_names 列表中的一个索引。

4、使用开源数据集微调 YOLO-NAS

为了微调模型,我们需要数据。 我们将使用足球运动员检测图像数据集。

如果你已经有 YOLO 格式的数据集,请随意使用它。 如果没有,请看看 Roboflow Universe,那里拥有超过 200,000 个开源项目,并且所有项目都可以以任何格式导出。

另外一种获取数据集的方法是使用UnrealSynth,一个基于虚幻引擎开发的YOLO合成数据生成器,可以自动生成包括标注的训练数据集,非常方便:
在这里插入图片描述

https://tools.nsdt.cloud/UnrealSynth

import roboflow
from roboflow import Roboflow

roboflow.login()

rf = Roboflow()
project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
dataset = project.version(PROJECT_VERSION).download("yolov5")

要训练 YOLO-NAS 模型,你需要设置几个关键参数。

首先,你需要选择模型尺寸。 有三个选项可供选择:小型、中型和大型。 请记住,较大的模型可能需要更长的时间来训练并需要更多的内存,因此如果使用的资源有限,你可能需要考虑使用较小的模型。

接下来,你需要设置批量大小。 该参数指示在训练过程的每次迭代期间将有多少图像通过神经网络。 较大的批量大小将加快训练过程,但也需要更多的内存。

MODEL_ARCH = 'yolo_nas_l'
BATCH_SIZE = 8
MAX_EPOCHS = 25
CHECKPOINT_DIR = f'{HOME}/checkpoints'
EXPERIMENT_NAME = project.name.lower().replace(" ", "_")
LOCATION = dataset.location
CLASSES = sorted(project.classes.keys())

dataset_params = {
    'data_dir': LOCATION,
    'train_images_dir':'train/images',
    'train_labels_dir':'train/labels',
    'val_images_dir':'valid/images',
    'val_labels_dir':'valid/labels',
    'test_images_dir':'test/images',
    'test_labels_dir':'test/labels',
    'classes': CLASSES
}

from super_gradients.training.dataloaders.dataloaders import (
    coco_detection_yolo_format_train, coco_detection_yolo_format_val)

train_data = coco_detection_yolo_format_train(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['train_images_dir'],
        'labels_dir': dataset_params['train_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

val_data = coco_detection_yolo_format_val(
    dataset_params={
        'data_dir': dataset_params['data_dir'],
        'images_dir': dataset_params['val_images_dir'],
        'labels_dir': dataset_params['val_labels_dir'],
        'classes': dataset_params['classes']
    },
    dataloader_params={
        'batch_size': BATCH_SIZE,
        'num_workers': 2
    }
)

最后,你需要设置训练过程的纪元数。 这本质上是整个数据集通过神经网络的次数。

5、训练自定义 YOLO-NAS 模型

你可能已经注意到,训练模型的过程比 YOLOv8 更加冗长。 Ultralytics 模型中的许多功能需要在 CLI 中传递参数,而对于 YOLO-NAS,则需要编写自定义逻辑。

最后,我们准备开始训练。 在调用 train 方法之前,值得运行 TensorBoard。 这将使我们能够实时跟踪培训的关键指标。 值得一提的是,YOLO-NAS还支持W&B等最流行的实验记录仪。
在这里插入图片描述

YOLO-NAS 训练期间获得的指标图

trainer.train(
    model=model, 
    training_params=train_params, 
    train_loader=train_data, 
    valid_loader=val_data
)

6、评估自定义 YOLO-NAS 模型

训练结束后,你可以使用Trainer提供的测试方法评估模型的性能。 你需要传入测试集数据加载器,训练器将返回一个指标列表,包括通常用于评估对象检测模型的平均精度(mAP)。

trainer.test(
    model=best_model,
    test_loader=test_data,
    test_metrics_list=DetectionMetrics_050(
        score_thres=0.1, 
        top_k_predictions=300, 
        num_cls=len(dataset_params['classes']), 
        normalize_targets=True, 
        post_prediction_callback=PPYoloEPostPredictionCallback(
            score_threshold=0.01, 
            nms_top_k=1000, 
            max_predictions=300,                                                                              
            nms_threshold=0.7
        )
    )
)

在这里插入图片描述

模型评估期间获得的预测与手动标注的比较

此外,你可以对测试集图像进行推理并可视化结果,以更好地了解模型在各个示例上的表现。 你还可以计算混淆矩阵,以更详细地了解每个类别的模型性能:

在这里插入图片描述

模型评估过程中创建的混淆矩阵

7、结束语

一夜之间,YOLO-NAS 成为实时物体检测器的新选择。 在为你的项目微调模型时,请记住要考虑所有方面——从模型准确性到推理速度,再到易于训练和许可限制。


原文链接:训练自己的YOLO-NAS — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/115255.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

浮动模块布局

基本思路 若宽度和浏览器一样宽,则不需要设置width 一般父盒子使用标准流,然后标准流内使用浮动 一般父盒子需要居中显示,使用 margin: 0 auto; 注意浮动盒子之间的margin值 与 父盒子width、height值之间的相等关系,一定要计算…

Goland 对容器中的 Go 程序断点远程调试

1,针对 golang 程序打断点有哪几种情况 临时进程:针对临时运行一次的 Golang 脚本,比如定时统计脚本,定时推送脚本。常驻进程:针对一直在后台运行的 Golang 程序,比如 HTTP 或者 GRPC 服务。 我们现在假设…

Observability:使用 OpenTelemetry 手动检测 .NET 应用程序

作者:David Hope 在快节奏的软件开发领域,尤其是在云原生领域,DevOps 和 SRE 团队日益成为应用程序稳定性和增长的重要合作伙伴。 DevOps 工程师不断优化软件交付,而 SRE 团队则充当应用程序可靠性、可扩展性和顶级性能的管理者。…

免费记课时小程序-全优学堂

1. 教师使用小程序记上课 使用步骤 创建了员工账号,员工需设置为教师为班级进行排课使用系统账号绑定小程序,记上课 #1.1 创建员工账号 通过系统菜单’机构设置->员工管理‘,添加本机构教师及其他员工。 添加过程中,可设置…

Webpack搭建本地服务器

一、搭建webpack本地服务 1.为什么要搭建本地服务器? 目前我们开发的代码,为了运行需要有两个操作: 操作一:npm run build,编译相关的代码;操作二:通过live server或者直接通过浏览器&#x…

Path with “WEB-INF“ or “META-INF“: [webapp/WEB-INF/NewFile.html]

2023-11-04 01:03:14.523 WARN 10896 --- [nio-8072-exec-6] o.s.w.s.r.ResourceHttpRequestHandler : Path with "WEB-INF" or "META-INF": [webapp/WEB-INFNewFile.html] spring.mvc.view.prefix:/webapp/WEB-INF/

forward和完美转发

std::move(value)是独立于值的右值引用&#xff0c;一个右值引用参数作为函数的形参&#xff0c;在函数内部再转发该参数的时候已经变成了一个左值&#xff0c;并不是它原来的类型了。 template<typename T> void forwardValue(T& val) {processValue(value); //…

Leetcode刷题详解——二叉树的所有路径

1. 题目链接&#xff1a;257. 二叉树的所有路径 2. 题目描述&#xff1a; 给你一个二叉树的根节点 root &#xff0c;按 任意顺序 &#xff0c;返回所有从根节点到叶子节点的路径。 叶子节点 是指没有子节点的节点。 示例 1&#xff1a; 输入&#xff1a;root [1,2,3,null,5]…

AITO问界崛起的“临门一脚”,落在了赛力斯汽车的智慧工厂里

文 | 智能相对论 作者 | 沈浪 AITO问界新M7的销量爆了&#xff0c;口碑也紧接着“爆”了。 AITO问界新M7系列上市以来50天&#xff0c;累计大定突破8万辆。AITO问界M9预计今年12月上市&#xff0c;预订超过了1.5万辆。根据最新公布的产销数据&#xff0c;在过去的10月份&…

【蓝桥杯选拔赛真题48】python最小矩阵 青少年组蓝桥杯python 选拔赛STEMA比赛真题解析

目录 python最小矩阵 一、题目要求 1、编程实现 2、输入输出 二、算法分析

Python某建筑平台数据, 实现网站JS逆向解密

嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 环境使用: 首先我们先来安装一下写代码的软件&#xff08;对没安装的小白说&#xff09; Python 3.8 / 编译器 Pycharm 2021.2版本 / 编辑器 专业版是付费的 <文章下方名片可获取魔法永久用~> 社区版是免费的 模块…

自动化测试入门知识 —— 数据驱动测试

一、什么是数据驱动测试&#xff1f; 数据驱动测试是一种测试方法&#xff0c;它的核心思想是通过不同的测试数据来验证同一个测试逻辑。通常情况下&#xff0c;测试用例中的输入数据和预期结果会被提取出来&#xff0c;以便可以通过不同的测试数据进行重复执行。 数据驱动测…

R语言用jsonlite库写的一个图片爬虫

目录 一、引言 二、jsonlite库介绍 三、图片爬虫实现步骤 1、发送HTTP请求获取图片列表 2、解析JSON数据提取图片链接 3、下载图片 四、实践与评估 五、注意事项 总结与展望 一、引言 随着互联网的发展&#xff0c;图片已经成为人们获取信息的重要途径之一。图片爬虫…

招聘小程序源码 招聘网源码 人才网源码 招聘求职小程序源码

招聘小程序源码 招聘网源码 人才网源码 招聘求职小程序源码 功能介绍&#xff1a; 1、发布招聘&#xff0c;建立企业人才库 支持企业入驻发布招聘职位&#xff0c;建立人才库&#xff1b; 2、发布简历&#xff0c;在线投简历 支持用户发布简历&#xff0c;向意向职位在线投简…

stm32 ADC

目录 简介 stm32的adc 框图 ①电压输入范围 ②输入通道 ​编辑③ADC通道 ④ADC触发 ⑤ADC中断 ⑥ADC数据 ⑦ADC时钟 ADC的四种转换模式 hal库代码 标准库代码 简介 自然界的信号几乎都是模拟信号&#xff0c;比如光亮、温度、压力、声音&#xff0c;而为了方便存储、…

OpenCV官方教程中文版 —— 图像修复

OpenCV官方教程中文版 —— 图像修复 前言一、基础二、代码三、更多资源 前言 本节我们将要学习&#xff1a; • 使用修补技术去除老照片中小的噪音和划痕 • 使用 OpenCV 中与修补技术相关的函数 一、基础 在我们每个人的家中可能都会几张退化的老照片&#xff0c;有时候…

领星ERP如何无需API开发轻松连接OA、电商、营销、CRM、用户运营、推广、客服等近千款系统

领星ERP&#xff08;LINGXING&#xff09;是一款专业的一站式亚马逊管理系统&#xff0c;帮助卖家构建完整的数据化运营闭环。&#xff0c;致力于为跨境电商卖家提供精细化运营和业财一体化的解决方案。 官网&#xff1a;https://erp.lingxing.com 集简云无代码集成平台&…

Spring Boot 使用断言抛出自定义异常,优化异常处理机制

文章目录 什么是断言&#xff1f;什么是异常&#xff1f;基于断言实现的异常处理机制创建自定义异常类创建全局异常处理器创建自定义断言类创建响应码类创建工具类测试效果 什么是断言&#xff1f; 实际上&#xff0c;断言&#xff08;Assertion&#xff09;是在Java 1.4 版本…

UE5数字孪生制作(一) - QGIS 学习笔记

1.下载 QGIS是免费的GIS工具&#xff0c;下载地址&#xff1a; https://www.qgis.org/en/site/ 2.安装 - 转中文 按照步骤安装&#xff0c;完成后&#xff0c;在菜单 设置settings里&#xff0c;选择options&#xff0c;修改语言 确定后&#xff0c;需要重启下软件 3.学习视…

Pycharm 对容器中的 Python 程序断点远程调试

pycharm如何连接远程服务器的docker容器有两种方法&#xff1a; 第一种&#xff1a;pycharm通过ssh连接已在运行中的docker容器 第二种&#xff1a;pycharm连接docker镜像&#xff0c;pycharm运行代码再自动创建容器 本文是第一种方法的教程&#xff0c;第二种请点击以上的链接…