【RT-DETR有效改进】遥感旋转网络 | LSKNet动态的空间感受野网络(轻量又提点)

前言

大家好,我是Snu77,这里是RT-DETR有效涨点专栏

本专栏的内容为根据ultralytics版本的RT-DETR进行改进,内容持续更新,每周更新文章数量3-10篇。

专栏以ResNet18、ResNet50为基础修改版本,同时修改内容也支持ResNet32、ResNet101和PPHGNet版本,其中ResNet为RT-DETR官方版本1:1移植过来的,参数量基本保持一致(误差很小很小),不同于ultralytics仓库版本的ResNet官方版本,同时ultralytics仓库的一些参数是和RT-DETR相冲的所以我也是会教大家调好一些参数和代码,真正意义上的跑ultralytics的和RT-DETR官方版本的无区别

一、本文介绍

本文给大家带来的改进内容是LSKNet(Large Kernel Selection, LK Selection),其是一种专为遥感目标检测设计的网络架构,其核心思想是动态调整其大的空间感受野,以更好地捕捉遥感场景中不同对象的范围上下文。实验部分我在一个包含三十多个类别的数据集上进行实验,其中包含大目标检测和小目标检测,mAP的平均涨点幅度在0.04-0.1之间(也有极个别的情况没有涨点),同时官方的版本只提供了一个大版本,我在其基础上提供一个轻量化版本给大家选择,本文会先给大家对比试验的结果,供大家参考,该网络的mAP提高大概在2个点,参数量也有一定程度上的下降可以起到轻量化的作用。

 专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR

目录

一、本文介绍

二、LSKNet原理

2.1  LSKNet的基本原理

2.2 大型核选择(LK Selection)子块

2.3 前馈网络(FFN)子块

三、LSKNet核心代码

 四、手把手教你添加LSKNet机制

4.1 修改一

4.2 修改二 

4.3 修改三 

4.4 修改四

4.5 修改五

4.6 修改六

4.7 修改七 

4.8 修改八

4.9 RT-DETR不能打印计算量问题的解决

4.10 可选修改

五、LSKNet的yaml文件

5.1 yaml文件

5.2 运行文件

5.3 成功训练截图

六、全文总结


二、LSKNet原理

论文地址: 官方论文地址

代码地址: 官方代码地址


2.1  LSKNet的基本原理

LSKNet(Large Selective Kernel Network)是一种专为遥感目标检测设计的网络架构,其核心优势在于能够动态调整其大的空间感受野,以更好地捕捉遥感场景中不同对象的范围上下文。这是第一次在遥感目标检测领域探索大型和选择性核机制。

LSKNet(大型选择性核网络)的基本原理包括以下关键组成部分:

1. 大型核选择(LK Selection)子块:这个子块能够动态地调整网络的感受野,以便根据需要捕获不同尺度的上下文信息。这使得网络能够根据遥感图像中对象的不同尺寸和复杂性调整其处理能力。

2. 前馈网络(FFN)子块:该子块用于通道混合和特征精炼。它由一个完全连接的层、一个深度卷积、一个GELU激活函数以及第二个完全连接的层组成。这些组件一起工作,提高了特征的质量并为分类和检测提供了必要的信息。

这两个子块共同构成LSKNet块,能够提供大范围的上下文信息,同时保持对细节的敏感度,这对于遥感目标检测尤其重要。

下面我将为大家展示四种不同的选择性机制模块的架构比较:

对于LSK模块:

1. 有一个分解步骤,似乎是用来处理大尺寸的卷积核(Large K)。
2. 接着是一个空间选择*步骤,可能用于选择或优化空间信息的特定部分。

这与其他三种模型的架构相比较,显示了LSK模块在处理空间信息方面可能有其独特的方法。具体来说,LSK模块似乎强调了在大尺寸卷积核上进行操作,这可能有助于捕获遥感图像中较大范围的上下文信息,这对于检测图像中的对象特别有用。空间选择步骤可能进一步增强了模型对于输入空间特征的选择能力,从而使其能够更加有效地聚焦于图像的重要部分。


2.2 大型核选择(LK Selection)子块

LSKNet的大型核选择(Large Kernel Selection, LK Selection)子块是其架构的核心组成部分之一。这个子块的功能是根据需要动态调整网络的感受野大小。通过这种方式,LSKNet能够根据遥感图像中不同对象的大小和上下文范围,调整处理这些对象所需的空间信息范围。

大型核选择子块与前馈网络(Feed-forward Network, FFN)子块一起工作。FFN子块用于通道混合和特征细化,它包括一个序列,这个序列由一个全连接层、一个深度卷积、一个GELU激活函数以及第二个全连接层组成。这种设计允许LSKNet块进行特征深度融合和增强,进一步提升了遥感目标检测的性能

下面我将通过LSK(Large Selective Kernel)模块的概念性插图,展示LSKNet如何通过大型核选择子块和空间选择机制来处理遥感数据,从而使网络能够适应不同对象的长范围上下文需求。

1. Large Kernel Decomposition:原始输入X经过大核分解,使用两种不同的大型卷积核(Large K)进行处理,以捕获不同尺度的空间信息。

2. Channel Concatenation:两个不同的卷积输出U_1U_2通过通道拼接组合在一起,这样可以在后续步骤中同时利用不同的空间特征。

3. Mixed Pooling:拼接后的特征图经过平均池化和最大池化的组合操作,然后与自注意力(SA)机制一起使用,以进一步强化特征图的关键区域。

4. Convolution and Spatial Selection:通过卷积操作和自注意力(SA)生成新的特征图,然后通过空间选择机制进一步增强对目标区域的关注。

5. Element Product and Sigmoid:使用Sigmoid函数生成一个掩码S,然后将这个掩码与特征图F进行元素乘积操作,得到最终的输出特征图Y。这一步骤用于加权特征图中更重要的区域,以增强网络对遥感图像中特定对象的检测能力。

整个LSK模块的设计强调了对遥感图像中不同空间尺度和上下文信息的有效捕获,这对于在复杂背景下准确检测小型或密集排布的目标至关重要。通过上述步骤的复合操作,LSK模块能够提升遥感目标检测的性能。


2.3 前馈网络(FFN)子块

LSKNet的前馈网络(Feed-forward Network, FFN)子块用于通道混合和特征精炼。该子块包含以下组成部分:

1. 全连接层:用于特征变换,提供网络额外的学习能力。
2. 深度卷积(depth-wise convolution):用于在通道间独立地应用空间滤波,减少参数量的同时保持效果。
3. GELU激活函数:一种高斯误差线性单元,用于引入非线性,提高模型的表达能力。
4. 第二个全连接层:进一步变换和精炼特征。

这个FFN子块紧随LK Selection子块之后,作用是在保持特征空间信息的同时,增强网络在特征通道上的表示能力。通过这种设计,FFN子块有效地对输入特征进行了深度加工,提升了最终特征的质量,从而有助于提高整个网络在遥感目标检测任务中的性能。


三、LSKNet核心代码

将此代码复制粘贴到''ultralytics/nn/modules''目录下新建一个py文件我这里起名字为LSKNet.py,然后把代码复制粘贴进去即可,使用教程看章节四。

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair as to_2tuple
from timm.models.layers import DropPath, to_2tuple
from functools import partial
import warnings



class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LSKblock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim // 2, 1)
        self.conv2 = nn.Conv2d(dim, dim // 2, 1)
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        self.conv = nn.Conv2d(dim // 2, dim, 1)

    def forward(self, x):
        attn1 = self.conv0(x)
        attn2 = self.conv_spatial(attn1)

        attn1 = self.conv1(attn1)
        attn2 = self.conv2(attn2)

        attn = torch.cat([attn1, attn2], dim=1)
        avg_attn = torch.mean(attn, dim=1, keepdim=True)
        max_attn, _ = torch.max(attn, dim=1, keepdim=True)
        agg = torch.cat([avg_attn, max_attn], dim=1)
        sig = self.conv_squeeze(agg).sigmoid()
        attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
        attn = self.conv(attn)
        return x * attn


class Attention(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LSKblock(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shorcut
        return x


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_cfg=None):
        super().__init__()
        if norm_cfg:
            self.norm1 = nn.BatchNorm2d(norm_cfg, dim)
            self.norm2 = nn.BatchNorm2d(norm_cfg, dim)
        else:
            self.norm1 = nn.BatchNorm2d(dim)
            self.norm2 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        if norm_cfg:
            self.norm = nn.BatchNorm2d(norm_cfg, embed_dim)
        else:
            self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)
        return x, H, W



class LSKNet(nn.Module):
    def __init__(self, img_size=224, in_chans=3, dim=None,  embed_dims=[64, 128, 256, 512],
                 mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=[3, 4, 6, 3], num_stages=4,
                 pretrained=None,
                 init_cfg=None,
                 norm_cfg=None):
        super().__init__()
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                            patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i], norm_cfg=norm_cfg)

            block = nn.ModuleList([Block(
                dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j], norm_cfg=norm_cfg)
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        self.width_list = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]



    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x)
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict

def LSKNET_Tiny():
    model = LSKNet(depths = [2, 2, 2, 2])
    return model


if __name__ == '__main__':
    model = LSKNet()
    inputs = torch.randn((1, 3, 640, 640))
    for i in model(inputs):
        print(i.size())


 四、手把手教你添加LSKNet机制

下面教大家如何修改该网络结构,主干网络结构的修改步骤比较复杂,我也会将task.py文件上传到CSDN的文件中,大家如果自己修改不正确,可以尝试用我的task.py文件替换你的,然后只需要修改其中的第1、2、3、5步即可。

⭐修改过程中大家一定要仔细⭐


4.1 修改一

首先我门中到如下“ultralytics/nn”的目录,我们在这个目录下在创建一个新的目录,名字为'Addmodules'(此文件之后就用于存放我们的所有改进机制),之后我们在创建的目录内创建一个新的py文件复制粘贴进去 ,可以根据文章改进机制来起,这里大家根据自己的习惯命名即可。


4.2 修改二 

第二步我们在我们创建的目录内创建一个新的py文件名字为'__init__.py'(只需要创建一个即可),然后在其内部导入我们本文的改进机制即可,其余代码均为未发大家没有不用理会!


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'然后在开头导入我们的所有改进机制(如果你用了我多个改进机制,这一步只需要修改一次即可)


4.4 修改四

添加如下两行代码!!!


4.5 修改五

找到七百多行大概把具体看图片,按照图片来修改就行,添加红框内的部分,注意没有()只是函数名(此处我的文件里已经添加很多了后期都会发出来,大家没有的不用理会即可)。

        elif m in {自行添加对应的模型即可,下面都是一样的}:
            m = m(*args)
            c2 = m.width_list  # 返回通道列表
            backbone = True


4.6 修改六

用下面的代码替换红框内的内容。 

if isinstance(c2, list):
    m_ = m
    m_.backbone = True
else:
    m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
    t = str(m)[8:-2].replace('__main__.', '')  # module type
m.np = sum(x.numel() for x in m_.parameters())  # number params
m_.i, m_.f, m_.type = i + 4 if backbone else i, f, t  # attach index, 'from' index, type
if verbose:
    LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}')  # print
save.extend(
    x % (i + 4 if backbone else i) for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
layers.append(m_)
if i == 0:
    ch = []
if isinstance(c2, list):
    ch.extend(c2)
    if len(c2) != 5:
        ch.insert(0, 0)
else:
    ch.append(c2)


4.7 修改七 

修改七这里非常要注意,不是文件开头YOLOv8的那predict,是400+行的RTDETR的predict!!!初始模型如下,用我给的代码替换即可!!!

代码如下->

 def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
        """
        Perform a forward pass through the model.

        Args:
            x (torch.Tensor): The input tensor.
            profile (bool, optional): If True, profile the computation time for each layer. Defaults to False.
            visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
            batch (dict, optional): Ground truth data for evaluation. Defaults to None.
            augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): Model's output tensor.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model[:-1]:  # except the head part
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, 'backbone'):
                x = m(x)
                if len(x) != 5:  # 0 - 5
                    x.insert(0, None)
                for index, i in enumerate(x):
                    if index in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                x = x[-1]  # 最后一个输出传给下一层
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1))  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        head = self.model[-1]
        x = head([y[j] for j in head.f], batch)  # head inference
        return x

4.8 修改八

我们将下面的s用640替换即可,这一步也是部分的主干可以不修改,但有的不修改就会报错,所以我们还是修改一下。


4.9 RT-DETR不能打印计算量问题的解决

计算的GFLOPs计算异常不打印,所以需要额外修改一处, 我们找到如下文件'ultralytics/utils/torch_utils.py'文件内有如下的代码按照如下的图片进行修改,大家看好函数就行,其中红框的640可能和你的不一样, 然后用我给的代码替换掉整个代码即可。

def get_flops(model, imgsz=640):
    """Return a YOLO model's FLOPs."""
    try:
        model = de_parallel(model)
        p = next(model.parameters())
        # stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        stride = 640
        im = torch.empty((1, 3, stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1E9 * 2 if thop else 0  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        return flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
    except Exception:
        return 0


4.10 可选修改

有些读者的数据集部分图片比较特殊,在验证的时候会导致形状不匹配的报错,如果大家在验证的时候报错形状不匹配的错误可以固定验证集的图片尺寸,方法如下 ->

找到下面这个文件ultralytics/models/yolo/detect/train.py然后其中有一个类是DetectionTrainer class中的build_dataset函数中的一个参数rect=mode == 'val'改为rect=False


五、LSKNet的yaml文件

5.1 yaml文件

大家复制下面的yaml文件,然后通过我给大家的运行代码运行即可,RT-DETR的调参部分需要后面的文章给大家讲,现在目前免费给大家看这一部分不开放。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, LSKNET_Tiny, []]  # 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 5 input_proj.2
  - [-1, 1, AIFI, [1024, 8]] # 6
  - [-1, 1, Conv, [256, 1, 1]]  # 7, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 8
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 9 input_proj.1
  - [[-2, -1], 1, Concat, [1]] # 10
  - [-1, 3, RepC3, [256, 0.5]]  # 11, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]]   # 12, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]]  # 14 input_proj.0
  - [[-2, -1], 1, Concat, [1]]  # 15 cat backbone P4
  - [-1, 3, RepC3, [256, 0.5]]    # X3 (16), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]]   # 17, downsample_convs.0
  - [[-1, 12], 1, Concat, [1]]  # 18 cat Y4
  - [-1, 3, RepC3, [256, 0.5]]    # F4 (19), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]]   # 20, downsample_convs.1
  - [[-1, 7], 1, Concat, [1]]  # 21 cat Y5
  - [-1, 3, RepC3, [256, 0.5]]    # F5 (22), pan_blocks.1

  - [[16, 19, 22], 1, RTDETRDecoder, [nc, 256, 300, 4, 8, 3]]  # Detect(P3, P4, P5)


5.2 运行文件

大家可以创建一个train.py文件将下面的代码粘贴进去然后替换你的文件运行即可开始训练。

import warnings
from ultralytics import RTDETR
warnings.filterwarnings('ignore')

if __name__ == '__main__':
    model = RTDETR('替换你想要运行的yaml文件')
    # model.load('') # 可以加载你的版本预训练权重
    model.train(data=r'替换你的数据集地址即可',
                cache=False,
                imgsz=640,
                epochs=72,
                batch=4,
                workers=0,
                device='0',
                project='runs/RT-DETR-train',
                name='exp',
                # amp=True
                )


5.3 成功训练截图

下面是成功运行的截图(确保我的改进机制是可用的),已经完成了有1个epochs的训练,图片太大截不全第2个epochs了。 


六、全文总结

从今天开始正式开始更新RT-DETR剑指论文专栏,本专栏的内容会迅速铺开,在短期呢大量更新,价格也会乘阶梯性上涨,所以想要和我一起学习RT-DETR改进,可以在前期直接关注,本文专栏旨在打造全网最好的RT-DETR专栏为想要发论文的家进行服务。

 专栏链接:RT-DETR剑指论文专栏,持续复现各种顶会内容——论文收割机RT-DETR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/339741.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

掌上单片机实验室 – 低分辨率编码器测速方式完善(24)

一、背景 本以为“掌上单片机实验室”这一主题已告一段落&#xff0c;可最近在测试一批新做的“轮式驱动单元”时&#xff0c;发现原来的测速算法存在问题。 起因是&#xff1a;由于轮式驱动单元的连线较长&#xff0c;PCB体积也小&#xff0c;导致脉冲信号有干扰&#xff0c;加…

Linux系统监控:保障稳定性与性能的关键

Linux操作系统作为广泛应用于服务器和嵌入式设备的开源操作系统&#xff0c;对于系统监控的需求尤为重要。通过对Linux系统进行有效的监控&#xff0c;管理员可以实时了解系统的运行状态、识别潜在问题并采取相应的措施。本文将介绍Linux系统监控的基本原理、常用工具和关键指标…

央视推荐的护眼台灯是哪款?学生专用台灯第一品牌

近视问题在我国十分严重&#xff0c;据相关调查数据显示&#xff0c;我国有7亿近视人口。特别是在现代&#xff0c;青少年成为近视高发人群&#xff0c;其中大部分近视的原因与长时间不正确用眼导致的眼睛疲劳有关。台灯作为许多家庭中的小家电&#xff0c;不论是上班族还是孩子…

一文搞懂Microsoft Copilot品种及定价说明

Microsoft Copilot 是一个 AI 助手&#xff0c;提供跨 Microsoft Cloud 的创新解决方案。Copilot 使复杂的任务更易于管理&#xff0c;从而促进协作环境并增强用户体验。 目前Copilot一共有这么几种&#xff1a; 一、必应中的copilot 在edge浏览器侧边栏中使用&#xff0c;这…

ESP32-TCP服务端(Arduino)

将ESP32设置为TCP服务器 介绍 TCP&#xff08;Transmission Control Protocol&#xff09;传输控制协议&#xff0c;是一种面向连接的&#xff08;一个客户端对应一个服务端&#xff09;、可靠的传输层协议。在TCP的工作原理中&#xff0c;它会将消息或文件分解为更小的片段&a…

通俗易懂理解小波池化/WaveCNet

重要说明&#xff1a;本文从网上资料整理而来&#xff0c;仅记录博主学习相关知识点的过程&#xff0c;侵删。 一、参考资料 github代码&#xff1a;WaveCNet 通俗易懂理解小波变换(Wavelet Transform) 二、相关介绍 关于小波变换的详细介绍&#xff0c;请参考另一篇博客&…

大模型学习之书生·浦语大模型6——基于OpenCompass大模型评测

基于OpenCompass大模型评测 关于评测的三个问题Why/What/How Why What 有许多任务评测&#xff0c;包括垂直领域 How 包含客观评测和主观评测&#xff0c;其中主观评测分人工和模型来评估。 提示词工程 主流评测框架 OpenCompass 能力框架 模型层能力层方法层工具层 支持丰富…

使用Go发送HTTP GET请求

在Go语言中&#xff0c;我们可以使用net/http包来发送HTTP GET请求。以下是一个简单的示例&#xff0c;展示了如何使用Go发送HTTP GET请求并获取响应。 go复制代码 package main import ( "fmt" "io/ioutil" "net/http" …

用BK7251播放音乐

单片机的第一道难关无疑是烧录&#xff0c;如果烧录解决了&#xff0c;那么就有资格挑战各种坑了。 BK7251播放MP3 一、折腾材料 1、软件SDK&#xff1a; bk7251_audio_release_20190826_0701&#xff08;BK7251 rtt sdk&#xff09;&#xff0c;可以从github&#xff0c;gite…

HCIP网络的类型

一.网络类型&#xff1a; 点到点 BMA&#xff1a;广播型多路访问 -- 在一个MA网络中同时存在广播&#xff08;泛洪&#xff09;机制 NBMA&#xff1a;非广播型多路访问 -- 在一个MA网络中&#xff0c;没有泛洪机制-----不怎么使用了 MA&#xff1a;多路访问 -- 在一个…

基于光口的以太网 udp 回环实验

文章目录 前言一、系统框架整体设计二、系统工程及 IP 创建三、UDP回环模块修改说明四、接口讲解五、顶层模块设计六、下载验证前言 本章实验我们通过网络调试助手发送数据给 FPGA,FPGA通过光口接收数据并将数据使用 UDP 协议发送给电脑。 提示:任何文章不要过度深思!万事万…

电工技术实验-电路元件伏安特性测绘

一、 实验目的 1、学会识别常用电路元件的方法 2、验证线性电阻、非线性电阻元件的伏安特性 3、熟悉实验台上直流电工仪表和设备的使用方法 二、实验器材 可调直流稳压电源、直流数字毫安表、直流数字电压表、万用表 二极管、稳压管、白炽灯、线性电阻 三、实验原理 任…

低压防雷箱综合选型应用方案

低压防雷箱是一种用于保护低压配电系统免受雷电过电压的影响的装置&#xff0c;它主要由防雷箱模块、浪涌保护器SPD、接地线等组成。本文将介绍低压防雷箱的作用原理和行业应用解决方案&#xff0c;以及低压防雷箱的选型方法。 低压防雷箱的作用原理 低压防雷箱的作用原理是利…

革新区块链:代理合约与智能合约升级的未来

作者 张群&#xff08;赛联区块链教育首席讲师&#xff0c;工信部赛迪特聘资深专家&#xff0c;CSDN认证业界专家&#xff0c;微软认证专家&#xff0c;多家企业区块链产品顾问&#xff09;关注张群&#xff0c;为您提供一站式区块链技术和方案咨询。 代理合约&#xff08;Prox…

职业规划,软件开发工程师的岗位任职资格

软件工程师是指从事软件开发的人&#xff0c;主要的工作涉及到项目培训和项目设计两个方面。在实际工作中&#xff0c;软件工程师是一个广义的概念&#xff0c;包括了很多与软件相关的人员。除开最基础的编程语言&#xff0c;还有数据库语言等等。从事这份工作&#xff0c;需要…

多标签节点分类

Multi-Label Node Classification on Graph-Structured Data,TMLR’23 Code 学习笔记 图结构数据的多标签分类 节点表示或嵌入方法 通常会生成查找表&#xff0c;以便将相似的节点嵌入的更近。学习到的表示用作各种下游预测模块的输入特征。 表现突出的方法是基于随机游走(ran…

【Spring 篇】MyBatis注解开发:编写你的数据乐章

欢迎来到MyBatis的音乐殿堂&#xff01;在这个充满节奏和韵律的舞台上&#xff0c;注解是我们编写数据乐章的得力助手。无需繁琐的XML配置&#xff0c;通过简单而强大的注解&#xff0c;你将能够轻松地与数据库交互。在这篇博客中&#xff0c;我们将深入探讨MyBatis注解开发的精…

MySQL数据库 | 事务中的一些问题(重点)

文章目录 什么是事务&#xff1f;事务的几个特性(ACID) -重点原子性(Atomicity)一致性(Consistency)隔离性(Isolation)持久性(Durability) Mysql中事务操作隐式事务显式事务 savepoint关键字只读事务事务中的一些问题&#xff08;重点&#xff09;隔离级别脏读解决办法 幻读解决…

C语言实战系列一:经典贪食蛇

C语言学习必须实战&#xff0c;并且学完语法后就必须立即用实战来巩固。一般需要10来个比较复杂的程序才能掌握C语言。今天就教大家第一个小程序&#xff0c;贪食蛇。 首先上代码 一、代码 #include <stdio.h> #include <stdlib.h> #include <curses.h> #…

Leetcode的AC指南 —— 栈与队列:20. 有效的括号

摘要&#xff1a; **Leetcode的AC指南 —— 栈与队列&#xff1a;20. 有效的括号 **。题目介绍&#xff1a;给定一个只包括 ‘(’&#xff0c;‘)’&#xff0c;‘{’&#xff0c;‘}’&#xff0c;‘[’&#xff0c;‘]’ 的字符串 s &#xff0c;判断字符串是否有效。 有效字…