从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型

从零开始的目标检测和关键点检测(二):训练一个Glue的RTMDet模型

  • 一、config文件解读
  • 二、开始训练
  • 三、数据集分析
  • 四、ncnn部署

从零开始的目标检测和关键点检测(一):用labelme标注数据集

从零开始的目标检测和关键点检测(三):训练一个Glue的RTMPose模型

[1]用labelme标注自己的数据集中已经标注好数据集(关键点和检测框),通过labelme2coco脚本将所有的labelme json文件集成为两个coco格式的json文件,即train_coco.json和val_coco.json。训练一个RTMDet模型,需要重写config文件。

一、config文件解读

1、数据集类型即coco格式的数据集,metainfo是指框的类别,因为这里只有一个glue的类,因此NUM_CLASSES为1,注意metainfo类别名后的逗号,

# 数据集类型及路径
dataset_type = 'CocoDataset'
data_root = 'data/glue_134_Keypoint/'
metainfo = {'classes': ('glue',)}
NUM_CLASSES = len(metainfo['classes'])

2、加载backnbone预训练权重和RTMDet-tiny预训练权重

# RTMDet-tiny
load_from = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'
backbone_pretrain = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth'
deepen_factor = 0.167
widen_factor = 0.375
in_channels = [96, 192, 384]
neck_out_channels = 96
num_csp_blocks = 1
exp_on_reg = False

3、训练参数设置,如epoch、batchsize…

MAX_EPOCHS = 200
TRAIN_BATCH_SIZE = 8
VAL_BATCH_SIZE = 4
stage2_num_epochs = 20
base_lr = 0.004
VAL_INTERVAL = 5  # 每隔多少轮评估保存一次模型权重

4、default_runtime,即默认设置,在config文件夹的default_runtime.py可看到。不同的MM-框架的默认设置不一样(如default_scope = 'mmdet'),可以包含这个.py也可以直接复制过来。

default_scope = 'mmdet'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=1),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=2, save_best='coco/bbox_mAP'),
    # auto coco/bbox_mAP_50 coco/bbox_mAP_75 coco/bbox_mAP_s
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='DetVisualizationHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='DetLocalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
log_level = 'INFO'
load_from = None
resume = False

5、训练超参数配置

train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=MAX_EPOCHS,
    val_interval=VAL_INTERVAL,
    dynamic_intervals=[(MAX_EPOCHS - stage2_num_epochs, 1)])

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# 学习率
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-05, by_epoch=False, begin=0,
        end=1000),
    dict(
        type='CosineAnnealingLR',
        eta_min=0.0002,
        begin=150,
        end=300,
        T_max=150,
        by_epoch=True,
        convert_to_iter_based=True)
]

# 优化器
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
auto_scale_lr = dict(enable=False, base_batch_size=16)

6、数据处理pipeline,做数据预处理(数据增强)

# DataLoader
backend_args = None
train_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='CachedMosaic',
        img_scale=(640, 640),
        pad_val=114.0,
        max_cached_images=20,
        random_pop=False),
    dict(
        type='RandomResize',
        scale=(1280, 1280),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(640, 640)),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(
        type='CachedMixUp',
        img_scale=(640, 640),
        ratio_range=(1.0, 1.0),
        max_cached_images=10,
        random_pop=False,
        pad_val=(114, 114, 114),
        prob=0.5),
    dict(type='PackDetInputs')
]
test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='Resize', scale=(640, 640), keep_ratio=True),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

7、加载数据和标注并用对应pipeliane做预处理

train_dataloader = dict(
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='train_coco.json',
        data_prefix=dict(img='images/'),
        filter_cfg=dict(filter_empty_gt=True, min_size=32),
        pipeline=train_pipeline,
        backend_args=None),
    pin_memory=True)
val_dataloader = dict(
    batch_size=VAL_BATCH_SIZE,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='CocoDataset',
        data_root=data_root,
        metainfo=metainfo,
        ann_file='val_coco.json',
        data_prefix=dict(img='images/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=None))
test_dataloader = val_dataloader

8、定义模型结构backbone + neck + head

# 模型结构
model = dict(
    type='RTMDet',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        batch_augments=None),
    backbone=dict(
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint=backbone_pretrain

        )),
    neck=dict(
        type='CSPNeXtPAFPN',
        in_channels=in_channels,
        out_channels=neck_out_channels,
        num_csp_blocks=num_csp_blocks,
        expand_ratio=0.5,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    bbox_head=dict(
        type='RTMDetSepBNHead',
        num_classes=NUM_CLASSES,
        in_channels=neck_out_channels,
        stacked_convs=2,
        feat_channels=neck_out_channels,
        anchor_generator=dict(
            type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
        bbox_coder=dict(type='DistancePointBBoxCoder'),
        loss_cls=dict(
            type='QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
        with_objectness=False,
        exp_on_reg=exp_on_reg,
        share_conv=True,
        pred_kernel_size=1,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU', inplace=True)),
    train_cfg=dict(
        assigner=dict(type='DynamicSoftLabelAssigner', topk=13),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=30000,
        min_bbox_size=0,
        score_thr=0.001,
        nms=dict(type='nms', iou_threshold=0.65),
        max_per_img=300))

二、开始训练

1、开始训练

python tools/train.py data/glue_134_Keypoint/rtmdet_tiny_glue.py

训练结果

 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.719
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.483
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.766
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.766

2、测试一下训练结果

python demo/image_demo.py data/glue_134_Keypoint/test_image/test.png data/glue_134_Keypoint/rtmdet_tiny_glue.py --weights work_dirs/rtmdet_tiny_glue/best_coco_bbox_mAP_epoch_180.pth --device cpu

在这里插入图片描述

3、可视化训练过程

在这里插入图片描述
在这里插入图片描述

4、由于标注数据集的glue都是小目标的,因此大目标无法识别,如下:
在这里插入图片描述

三、数据集分析

1、可视化部分图像

在这里插入图片描述

框标注-框中心点位置分布

在这里插入图片描述

框标注-框宽高分布

显然都是小目标的检测

在这里插入图片描述

四、ncnn部署

在线模型转换:Deploee

上传文件完成在线转换

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/115070.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring、SpringMVC、Mybatis

一.Spring基础 1.Spring 框架是什么 Spring 是一款开源的轻量级 Java 开发框架,我们一般说 Spring 框架指的都是 Spring Framework,它是很多模块的集合,例如,Spring core、Spring JDBC、Spring MVC 等,使用这些模块可…

python 数据挖掘库orange3 介绍

orange3 是一个非常适合初学者的data mining library. 它让使用者通过拖拽内置的组件来形成工作流。让你不需要写任何代码就可以体验到数据挖掘和可视化的魅力。 它的桌面如下,这里我创建了 3 个节点,分别是数据集、小提琴图,散点图 其中 …

误删的文件恢复了成乱码 误删的文件恢复了成乱码怎么调整

电脑系统:Windows11 电脑型号:惠普 软件版本:EasyRcovery14 关于电脑,我们可以说是非常熟悉,并熟练掌握了对电脑的最基本操作,比如复制、粘贴、新建、删除文件。但我们真的很懂它吗?比如误删…

SAP SD 定价 删除不满足条件的的条件类型

项目上的需求:当销售订单行项目类别满足条件时,根据配置表,删除不满足条件的的条件类型。 直接上增强点,bapi也能跑到这个位置。

vite安装Tailwind CSS

安装 - Tailwind CSS 中文网 (nodejs.cn) 这是官网,平常我练习一般会用vite脚手架 我们选择这个vite模块 可选择React和Vue版本的,这里选择react的按照操作,没问题的话就要出问题了 1、在npm run dev的时候我是出现了这么个问题&#xff0c…

XML External Entity-XXE-XML实体注入

XML 实体? XML 实体允许定义标签,在解析 XML 文档时这些标签将被内容替换。一般来说,实体分为三种类型: 内部实体 外部实体 参数实体。 必须在文档类型定义(DTD)中创建实体 一旦 XML 文档被解析器处理,它将js用定义的常量“Jo Smith”替换定义的实体。正如您所看到…

React Hooks的使用

目录 1.React Hooks使用注意事项 1.useState Hook: 2.useEffect Hook: 3.其他常用Hooks: 2.使用React Hooks需要遵循 1.安装React: 2.导入所需的Hooks: 3.使用Hooks创建组件: 4.在应用中使用组件&…

Pytorch从零开始实战08

Pytorch从零开始实战——YOLOv5-C3模块实现 本系列来源于365天深度学习训练营 原作者K同学 文章目录 Pytorch从零开始实战——YOLOv5-C3模块实现环境准备数据集模型选择开始训练可视化模型预测总结 环境准备 本文基于Jupyter notebook,使用Python3.8&#xff0c…

在Node.js中,什么是中间件(middleware)?它们的作用是什么?

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

NodeJS 安装及环境配置

下载地址:https://nodejs.org/zh-cn/download/ 安装 NodeJS 根据自己电脑系统及位数选择,一般都选择 windows 64位 .msi 格式安装包。 所用命令: node -v npm -v PS:如果以上两条命令都能执行成功,表示安装完成&#…

虹科荣誉 | 喜讯!虹科成功入选“广州首届百家新锐企业”!!

文章来源:虹科品牌部 阅读原文:虹科荣誉 | 喜讯!虹科成功入选“广州首届百家新锐企业”!! 近日,由中共广州市委统战部、广州市工商业联合会、广州市工业和信息化局、广州市人民政府国有资产监督管理委员会…

linux系统SQL server数据库定时收缩

问题现象 出现下图问题,导致连接该数据库的程序不能正常启动 解决办法 定时收缩数据库 数据库定时收缩脚本 需要三个脚本文件 linux_sqlcmd_timing_task_shrink.sh:主脚本文件 # 设置数据库名称、用户名、密码等信息 # db_name"volador"…

Python:字符串格式化

文章目录 %用法使用format方法进行格式化 %用法 格式字符说明%s字符串%c单个字符%d十进制整数%o八进制整数%x十六进制整数%e指数(基底写为e)%E指数(基底写为E) x 1235 print(%o % x) print(%d % x) print(%x % x) print(%e % x) print(%s % 65) print(%c % a)使用format方法…

单链表的建立(头插法、尾插法)(数据结构与算法)

如果要把很多个数据元素存到一个单链表中,如何操作? 1.初始化一个单链表 2. 每次取一个数据元素,插入到表尾/表头 1. 尾插法建立单链表 尾插法建立的单链表元素顺序与输入数据集合的顺序相同,即按照输入数据的顺序排列。 使用尾插…

心理咨询预约小程序

随着微信小程序的日益普及,越来越多的人开始关注如何利用小程序来提供便捷的服务。对于心理咨询行业来说,搭建一个心理咨询预约小程序可以大大提高服务的效率和用户体验。本文以乔拓云平台为例,详细介绍如何轻松搭建一个心理咨询预约小程序。…

甘特图组件DHTMLX Gantt用例 - 如何拆分任务和里程碑项目路线图

创建一致且引人注意的视觉样式是任何项目管理应用程序的重要要求,这就是为什么我们会在这个系列中继续探索DHTMLX Gantt图库的自定义。在本文中我们将考虑一个新的甘特图定制场景,DHTMLX Gantt组件如何创建一个项目路线图。 DHTMLX Gantt正式版下载 用…

R语言的DICE模型实践技术

随着温室气体排放量的增大和温室效应的增强,全球气候变化问题受到日益的关注。我国政府庄严承诺在2030和2060年分别达到“碳达峰”和“碳中和”,因此气候变化和碳排放已经成为科研人员重点关心的问题之一。气候变化问题不仅仅是科学的问题,同…

百度百科怎么创建?百科创建需要注意哪些(一文看懂品牌/企业/人物百科创建)

随着互联网的不断发展,许多企业或品牌都选择创建百度百科作为一种很好的展示方式。百度百科可以被视为一张网络名片,拥有它能够提高人物、企业、品牌的知名度和影响力。那么人物百科、企业百科、品牌百科到底怎么创建呢? 大家创建百科前建议先…

【Python】

列表 列表中元素的类型可以不同,列表内部存储方式是元素值存储在不连续的空间,但是把他们的指针存在一块连续的空间 列表的创建 1.list1[] 创建一个空列表 2.用list函数 3.split函数截取 列表的更新 1.通过索引[]改变 2.切片修改 3.列表方法更新 列表…