RT-DETR代码详解(官方pytorch版)——参数配置(1)

前言

RT-DETR虽然是DETR系列,但是它的代码结构和之前的DETR系列代码不一样。

它是通过很多的yaml文件进行参数配置,和之前在train.py的parser = argparse.ArgumentParser()去配置所有参数不同,所以刚开始不熟悉代码的时候可能不知道在哪儿修改参数。

RT-DETR有官方版和ultralytics版两个版本代码,可以参考以下链接,分别使用两种方法对代码进行复现:
详解RT-DETR网络结构/数据集获取/环境搭建/训练/推理/验证/导出/部署_rt-dert-CSDN博客

下述内容主要是针对参数配置的代码实现进行解读,因为刚开始我拿着代码都不知道是怎么运行的,模型在哪儿加载参数都找不到

一、train.py文件

在RT-DETR中,train.py文件需要配置的内容很少,因为需要的参数配置全都放在了rtdetr_rxxvd_6x_coco.yml(骨干网络可选)文件中。在这个文件中又包含了其他所有的文件,可以依需修改:

左边是可以选择的backbone骨干网络,后续以ResNet18为例。

二、rtdetr_r18vd_6x_coco.yaml文件


__include__: [
  '../dataset/coco_detection.yml',  # 数据集
  '../runtime.yml', # 运行参数配置
  './include/dataloader.yml', # 定义数据加载器参数
  './include/optimizer.yml', # 定义优化器通用设置
  './include/rtdetr_r50vd.yml', # 定义 RT-DETR 模型的结构参数(如 backbone 和解码器层数等
]


output_dir: ./output/rtdetr_r18vd_6x_coco  # 输出的文件地址

PResNet:
  depth: 18
  freeze_at: -1 # 不冻结任何层(如果设置为正数,则冻结 ResNet 的前几层)
  freeze_norm: False # 不冻结归一化层(如 BatchNorm)
  pretrained: True # 加载预训练权重(通常是基于 ImageNet 数据集的权重)

HybridEncoder:
  in_channels: [128, 256, 512] # 编码器的输入特征通道数,分别对应 ResNet-18 不同尺度的特征图输出
  hidden_dim: 256
  expansion: 0.5 # 特征通道扩展比例


RTDETRTransformer:
  eval_idx: -1 # 指定在哪一层解码器输出进行评估(-1 表示最后一层)
  num_decoder_layers: 3 # 解码器的层数
  num_denoising: 100  # 去噪查询的数量



optimizer:
  type: AdamW # 该优化器改进了 Adam,支持权重衰减以减轻过拟合
  params:  # 参数分组,针对不同模块的参数设置不同的学习率和权重衰减
    - 
      params: '^(?=.*backbone)(?=.*norm).*$'      # 匹配骨干网络中的归一化层参数,设置较低学习率和无权重衰减
      lr: 0.00001
      weight_decay: 0.
    - 
      params: '^(?=.*backbone)(?!.*norm).*$'      # 匹配骨干网络中非归一化参数
      lr: 0.00001
    - 
      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bias)).*$'   # 匹配 Transformer 中归一化层或偏置参数
      weight_decay: 0.

  lr: 0.0001
  betas: [0.9, 0.999] # Adam 优化器的 beta 参数
  weight_decay: 0.0001 # 权重衰减值

上面的注释只是为了解释各行代码意思,但是运行代码过程中,yaml文件不能有注释,否则会报错:

三、yaml_config.py文件

 在train.py文件中,实际是通过YAMLConfig()这个类读取rtdetr_r18vd_6x_coco.yaml中的配置信息。通过加载 YAML 配置文件,将不同的模型、优化器、数据加载器等组件以模块化的方式创建

 主要功能

1. 动态加载 YAML 配置文件

  • 使用 load_config 函数加载 YAML 文件,读取其中的配置数据。
  • 支持通过 merge_dict 将命令行或其他来源的参数覆盖 YAML 文件中的默认配置。

2. 组件动态创建

  • 根据 YAML 文件的配置,动态创建模型(model)、损失函数(criterion)、优化器(optimizer)、学习率调度器(lr_scheduler)和数据加载器(dataloader)等。

3. 参数分组和正则匹配

  • 支持为优化器指定不同模块的参数组,并通过正则表达式选择分组的参数。

4. 支持扩展功能

  • 支持 EMA(Exponential Moving Average,指数滑动平均) 和 AMP(Automatic Mixed Precision,自动混合精度)
  • 自动处理模型参数的冻结、梯度裁剪等功能。

5. 模块化设计

  • 配置组件通过 create 函数动态实例化,便于扩展和自定义。

3.1 类初始化与加载配置

class YAMLConfig(BaseConfig):
    def __init__(self, cfg_path: str, **kwargs) -> None:
        super().__init__()
        cfg = load_config(cfg_path)  # 加载 YAML 配置文件
        merge_dict(cfg, kwargs)  # 合并外部输入的参数(高优先级)

        self.yaml_cfg = cfg  # 保存解析后的 YAML 配置

        # 一些常见配置的提取
        self.log_step = cfg.get('log_step', 100)
        self.checkpoint_step = cfg.get('checkpoint_step', 1)
        self.epoches = cfg.get('epoches', -1)
        self.resume = cfg.get('resume', '')
        self.tuning = cfg.get('tuning', '')
        self.sync_bn = cfg.get('sync_bn', False)
        self.output_dir = cfg.get('output_dir', None)
        self.use_ema = cfg.get('use_ema', False)
        self.use_amp = cfg.get('use_amp', False)
        self.autocast = cfg.get('autocast', dict())
        self.find_unused_parameters = cfg.get('find_unused_parameters', None)
        self.clip_max_norm = cfg.get('clip_max_norm', 0.0)
  • 功能
    • 从 YAML 配置文件中加载配置,初始化训练流程中常用的参数。
    • cfg_path:YAML 配置文件路径。
    • kwargs:支持通过外部传入参数(如命令行参数)覆盖 YAML 中的默认配置
    • 使用 get 方法设置默认值,避免配置文件缺失某些字段时程序报错。

 3.1.1 yaml_config.py文件

  通过cfg = load_config(cfg_path)已经将所有的配置信息传递给cfg了

尽管传入的只有一个rtdetr_r18vd_6x_coco.yaml文件,但它里面包含了其他的配置文件地址:

load_config()函数在yaml_utils.py文件中


def load_config(file_path, cfg=dict()):
    """
    加载 YAML 配置文件,并支持递归加载包含的其他 YAML 文件。
    Args:
        file_path (str): 要加载的 YAML 文件路径。
        cfg (dict): 全局配置字典,默认为空字典。
    Returns:
        dict: 加载并合并后的配置字典。
    """
    # 获取文件扩展名并确保是 YAML 文件
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "仅支持 YAML 文件(.yml 或 .yaml)"

    # 打开并加载 YAML 文件
    with open(file_path, 'r') as f:
        file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            return {}  # 如果文件为空,则返回空字典

    # 检查是否需要加载包含的 YAML 配置(递归加载)
    if INCLUDE_KEY in file_cfg:
        # 提取 'include' 键的值,通常是其他 YAML 文件路径的列表
        base_yamls = list(file_cfg[INCLUDE_KEY])
        for base_yaml in base_yamls:
            # 将路径展开为完整路径(支持用户目录 ~ 和相对路径)
            if base_yaml.startswith('~'):
                base_yaml = os.path.expanduser(base_yaml)
            if not base_yaml.startswith('/'):  # 如果是相对路径
                base_yaml = os.path.join(os.path.dirname(file_path), base_yaml)

            # 递归加载被包含的 YAML 文件
            base_cfg = load_config(base_yaml, cfg)
            # 合并当前加载的配置到全局配置中
            merge_config(base_cfg, cfg)

    # 最终合并当前文件的配置到全局配置中
    return merge_config(file_cfg, cfg)

  • 通过 include 字段,可以将配置拆分成多个 YAML 文件,便于管理和维护。
  • 支持递归加载多个 YAML 文件,并通过 merge_config 实现配置合并,确保最终配置完整。

  

 3.2 动态加载组件(如模型、优化器等)

 通 @property 装饰器延迟加载组件,仅在实际使用时创建对象

@property装饰器

是 Python 的一个内置装饰器,常用于定义一个类的方法,并将其伪装成“属性”。

  1. 保护类的封装特性
  2. 让开发者可以使用“对象.属性”的方式操作操作类属性

通过 @property 装饰器,可以直接通过方法名来访问方法,不需要在方法名后添加一对“()”小括号。

语法格式:

@property
def 方法名(self)
    代码块

更多@property装饰器内容可看,其中包含延时加载的应用:@property装饰器-CSDN博客

 3.2.1 模型加载

@property
def model(self) -> torch.nn.Module:
    if self._model is None and 'model' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        self._model = create(self.yaml_cfg['model'])  # 动态创建模型
    return self._model
  • 检查 _model 是否已经创建,若未创建且配置中包含 model 字段,则动态创建模型。(self.yaml_cfg已经存储了所有的配置信息,见3.1.1 图,提取model键的值)
  • 使用 create 函数按照 yaml_cfg['model'] 中的定义实例化模型。

在rtdetr_r18vd_6x_coco.yml--->./include/rtdetr_r50vd.yml中 :

3.2.2 优化器延迟加载

@property
def optimizer(self):
    if self._optimizer is None and 'optimizer' in self.yaml_cfg:
        merge_config(self.yaml_cfg)  # 合并全局配置
        params = self.get_optim_params(self.yaml_cfg['optimizer'], self.model)  # 获取参数分组
        self._optimizer = create('optimizer', params=params)  # 动态创建优化器
    return self._optimizer
  • 获取优化器参数分组(get_optim_params),根据配置动态创建优化器实例。

3.2.3  学习率调度器加载

@property
def lr_scheduler(self):
    if self._lr_scheduler is None and 'lr_scheduler' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._lr_scheduler = create('lr_scheduler', optimizer=self.optimizer)
        print('Initial lr: ', self._lr_scheduler.get_last_lr())
    return self._lr_scheduler
  • 动态创建学习率调度器对象,并与优化器绑定

在rtdetr_r18vd_6x_coco.yml--->./include/optimizer.yml中 :

基于MultiStepLR生成对应的学习率调度器

  • MultiStepLR 是 PyTorch 中 torch.optim.lr_scheduler 提供的一种学习率调度器
  • 它会在指定的训练步骤(milestones)调整学习率

根据配置,初始学习率为 0.1在第 1000 步时,学习率会乘以 gamma=0.1,变为 0.01。输出如下:

Step 0: Learning Rate = 0.1
Step 500: Learning Rate = 0.1
Step 1000: Learning Rate = 0.01
Step 1500: Learning Rate = 0.01

3.3 数据加载器

@property
def train_dataloader(self):
    if self._train_dataloader is None and 'train_dataloader' in self.yaml_cfg:
        merge_config(self.yaml_cfg)
        self._train_dataloader = create('train_dataloader')
        self._train_dataloader.shuffle = self.yaml_cfg['train_dataloader'].get('shuffle', False)
    return self._train_dataloader
  • 动态加载训练数据加载器,并根据配置调整 shuffle 参数

3.4 参数分组(正则表达式匹配)

@staticmethod
def get_optim_params(cfg: dict, model: nn.Module):
    '''
    E.g.:
        ^(?=.*a)(?=.*b).*$         means including a and b
        ^((?!b.)*a((?!b).)*$       means including a but not b
        ^((?!b|c).)*a((?!b|c).)*$  means including a but not (b | c)
    '''
    assert 'type' in cfg, ''
    cfg = copy.deepcopy(cfg)

    if 'params' not in cfg:
        return model.parameters()  # 如果未定义参数分组,返回默认模型参数

    assert isinstance(cfg['params'], list), ''

    param_groups = []
    visited = []
    for pg in cfg['params']:
        pattern = pg['params']
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and len(re.findall(pattern, k)) > 0}
        pg['params'] = params.values()
        param_groups.append(pg)
        visited.extend(list(params.keys()))

    names = [k for k, v in model.named_parameters() if v.requires_grad]

    if len(visited) < len(names):
        unseen = set(names) - set(visited)
        params = {k: v for k, v in model.named_parameters() if v.requires_grad and k in unseen}
        param_groups.append({'params': params.values()})
        visited.extend(list(params.keys()))

    assert len(visited) == len(names), ''
    return param_groups
  • 根据正则表达式匹配模型中的参数(named_parameters 方法返回 <参数名, 参数> 的映射)。
  • 支持按模块或特定规则分组优化器参数(如设置不同学习率、权重衰减)。
  • 未匹配的参数会自动归为默认组。

  • ^(?=.*backbone)(?=.*norm).*$:匹配键名中包含 backbone 和 norm 的参数。
  • ^(?=.*encoder)(?!.*bias).*$:匹配键名中包含 encoder 且不包含 bias 的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/952132.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Halcon在linux及ARM上的安装及c++工程化

一、HALCON下载 建议到HALCON官方下载页选择linux版本下载,压缩包名为MVTec_HALCON_Progress-18.11.0.1-linux(x64-aarch64-armv7a)-FullVersion.tar.gz。下载前需要登录HALCON帐号,如果没有请自行注册,填写一些基本信息然后激活邮件,操作方便简易。 下载许可证文件 该许…

单通道串口服务器(三格电子)

一、产品介绍 1.1 功能简介 SG-TCP232-110 是一款用来进行串口数据和网口数据转换的设备。解决普通 串口设备在 Internet 上的联网问题。 设备的串口部分提供一个 232 接口和一个 485 接口&#xff0c;两个接口内部连接&#xff0c;同 时只能使用一个口工作。 设 备 的网 口…

Figma如何装中文字体-PingFang苹方字体、Alibaba PuHuiTi阿里普惠

**写在前面&#xff1a; 工具类软件更新迭代如此快的世界&#xff0c;不能靠历史知识来做操作反应。需要着眼于当下工具的形态来思考用法。另外&#xff0c;有人说&#xff0c;当前的用户越来越少发教程类的图文消息了&#xff08;转去了视频&#xff09;&#xff0c;现在很多…

JVM实战—13.OOM的生产案例

大纲 1.每秒仅上百请求的系统为何会OOM(RPC超时时间设置过长导致QPS翻几倍) 2.Jetty服务器的NIO机制如何导致堆外内存溢出(S区太小 禁NIO的显式GC) 3.一次微服务架构下的RPC调用引发的OOM故障排查实践(MAT案例) 4.一次没有WHERE条件的SQL语句引发的OOM问题排查实践(使用MA…

Photoshop PS批处理操作教程(批量修改图片尺寸、参数等)

前言 ‌Photoshop批处理的主要作用‌是通过自动化处理一系列相似的操作来同时应用于多张图片&#xff0c;从而节省时间和精力&#xff0c;提高工作效率。批处理功能特别适用于需要批量处理的任务&#xff0c;如图像尺寸调整、颜色校正、水印添加等‌。 操作步骤 1.创建动作 …

【物联网原理与运用】知识点总结(上)

目录 名词解释汇总 第一章 物联网概述 1.1物联网的基本概念及演进 1.2 物联网的内涵 1.3 物联网的特性——泛在性 1.4 物联网的基本特征与属性&#xff08;五大功能域&#xff09; 1.5 物联网的体系结构 1.6 物联网的关键技术 1.7 物联网的应用领域 第二章 感知与识别技术 2.1 …

新车月交付突破2万辆!小鹏汽车“激活”智驾之困待解

首次突破月交付2万辆规模的小鹏汽车&#xff0c;稳吗&#xff1f; 本周&#xff0c;高工智能汽车研究院发布的最新监测数据显示&#xff0c;2024年11月&#xff0c;小鹏汽车在国内市场&#xff08;不含出口&#xff09;交付量&#xff08;上险口径&#xff0c;下同&#xff09…

基于Springboot+Vue的仓库管理系统

开发一个基于Spring Boot和Vue的仓库管理系统涉及到前端和后端的开发。本文呢&#xff0c;给出一个简单的开发步骤指南&#xff0c;用于指导初入的新手小白如何开始构建这样一个系统&#xff0c;如果**你想直接学习全部内容&#xff0c;可以直接拉到文末哦。** 开始之前呢给小…

java项目之ONLY在线商城系统设计与实现源码(springboot+vue+mysql)

大家好我是风歌&#xff0c;曾担任某大厂java架构师&#xff0c;如今专注java毕设领域。今天要和大家聊的是一款基于springboot的ONLY在线商城系统设计与实现。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; ONLY在线商城系统设计与实…

java后端对接飞书登陆

java后端对接飞书登陆 项目要求对接第三方登陆&#xff0c;飞书登陆&#xff0c;次笔记仅针对java后端&#xff0c;在看本笔记前&#xff0c;默认已在飞书开发方已建立了应用&#xff0c;并获取到了appid和appsecret。后端要做的其实很简单&#xff0c;基本都是前端做的&…

【2025最新计算机毕业设计】基于SpringBoot+Vue奶茶点单系统(高质量源码,提供文档,免费部署到本地)

作者简介&#xff1a;✌CSDN新星计划导师、Java领域优质创作者、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和学生毕业项目实战,高校老师/讲师/同行前辈交流。✌ 主要内容&#xff1a;&#x1f31f;Java项目、Python项目、前端项目、PHP、ASP.NET、人工智能…

import语句详解

在 Java 中&#xff0c;import 语句用于引入其他包中的类、接口或静态成员&#xff0c;以便在当前源文件中直接使用它们&#xff0c;而不需要写完整的类名&#xff08;包括包名&#xff09;。以下是 import 语句的详细解释和使用方法&#xff1a; 一、import语句的基本概念 定…

android刷机

android ota和img包下载地址&#xff1a; https://developers.google.com/android/images?hlzh-cn android启动过程 线刷 格式&#xff1a;ota格式 模式&#xff1a;recovery 优点&#xff1a;方便、简单&#xff0c;刷机方法通用&#xff0c;不会破坏手机底层数据&#xff0…

Wi-Fi Direct (P2P)原理及功能介绍

目录 Wi-Fi Direct &#xff08;P2P&#xff09;介绍Wi-Fi Direct P2P 概述P2P-GO&#xff08;P2P Group Owner&#xff09;工作流程 wifi-Direct使用windows11 wifi-directOpenwrtwifi的concurrent mode Linux环境下的配置工具必联wifi芯片P2P支持REF Wi-Fi Direct &#xff…

scrapy爬取图片

scrapy 爬取图片 环境准备 python3.10scrapy pillowpycharm 简要介绍scrapy Scrapy 是一个开源的 Python 爬虫框架&#xff0c;专为爬取网页数据和进行 Web 抓取而设计。它的主要特点包括&#xff1a; 高效的抓取性能&#xff1a;Scrapy 采用了异步机制&#xff0c;能够高效…

python学opencv|读取图像(二十八)使用cv2.warpAffine()函数平移图像

【1】引言 前序已经对图像操作进行了广泛的学习&#xff0c;包括读取、放大缩小&#xff0c;改变BGR通道值等&#xff0c;相关链接包括且不限于&#xff1a; python学opencv|读取图像-CSDN博客 python学opencv|读取图像&#xff08;三&#xff09;放大和缩小图像_python(1)使…

【数据库】四、数据库管理与维护

文章目录 四、数据库管理与维护1 安全性管理2 事务概述3 并发控制4 备份与恢复管理 四、数据库管理与维护 1 安全性管理 安全性管理是指保护数据库&#xff0c;以避免非法用户进行窃取数据、篡改数据、删除数据和破坏数据库结构等操作 三个级别认证&#xff1a; 服务器级别…

如何定位导致 Django 错误的文件

在 Django 开发中&#xff0c;当发生错误时&#xff0c;定位问题所在的文件和代码行是调试的重要步骤。以下是一些常用的方法和技巧来定位导致 Django 错误的文件&#xff1a; 1、问题背景 在项目中使用了 shrink 工具尝试运行 collect static 时&#xff0c;出现 TemplateSyn…

JavaSE——网络编程

一、InetAddress类 InetAddress是Java中用于封装IP地址的类。 获取本机的InetAddress对象&#xff1a; InetAddress localHost InetAddress.getLocalHost();根据指定的主机名获取InetAddress对象&#xff08;比如说域名&#xff09; InetAddress host InetAddress.getByNa…

在Windows环境下搭建无人机模拟器

最近要开发无人机地面站&#xff0c;但是没有无人机&#xff0c;开发无人机对我来说也是大姑娘坐花轿——头一回。我们要用 MAVLink 和无人机之间通信&#xff0c;看了几天 MAVLink&#xff0c;还是不得劲儿&#xff0c;没有实物实在是不好弄&#xff0c;所以想先装一个无人机模…