免费阅读篇 | 芒果YOLOv8改进114:上采样Dysample:顶会ICCV2023,轻量级图像增采样器,通过学习采样来学习上采样,计算资源需求小

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程
🚀🚀🚀

DySample是一个超轻量级和有效的动态上采样器,是一种更简洁、更高效的方式,用于提升图像分辨率。相较于传统的CARAFE和SAPA方法,DySample对计算资源的需求更小,能够在不增加额外负担的情况下实现图像分辨率的提升。

该篇博客为免费阅读内容,YOLOv8 + 上采样Dysample 改进内容🚀🚀🚀

文章目录

      • 1. Dysample 论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 代码修改部分
      • 2.3 YOLOv8-Dysample网络配置文件
      • 2.4 运行代码
      • 改进说明


1. Dysample 论文

我们介绍了 DySample,这是一款超轻量级且有效的动态上采样器。虽然最近基于内核的动态上采样器(如 CARAFE、FADE 和 SAPA)已经取得了令人印象深刻的性能提升,但它们引入了大量工作负载,这主要是由于耗时的动态卷积和用于生成动态内核的额外子网。此外,FADE和SAPA对高分辨率特征引导的需求在某种程度上限制了它们的应用场景。为了解决这些问题,我们绕过了动态卷积,从点采样的角度制定了上采样,这样可以节省资源,并且可以通过 PyTorch 中的标准内置函数轻松实现。我们首先展示了一个朴素的设计,然后演示了如何逐步加强其上采样行为,以实现我们新的上采样器 DySample。与以前基于内核的动态上采样器相比,DySample 不需要定制的 CUDA 包,参数、FLOP、GPU 内存和延迟也少得多。除了轻量级特性外,DySample 在五项密集预测任务中的表现优于其他上采样器,包括语义分割、目标检测、实例分割、全景分割和单目深度估计。
在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/2308.15085.pdf


2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 dysample.py文件,新增以下代码

import torch
import torch.nn as nn
import torch.nn.functional as F


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class DySample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)
2.2 代码修改部分

第一步:
ultralytics/nn/tasks.py文件开头部分中,新增:定义在 dysample.py 里面的模块

from ultralytics.nn.modules.dysample import DySample

直接复制这段代码即可


如下图所示:
在这里插入图片描述


然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:
            args = [ch[f]]

在这句上面加一个

        elif m in [DySample]:
            args = [ch[f], *args[0:]]

直接复制这段代码即可


2.3 YOLOv8-Dysample网络配置文件

新增YOLOv8-Dysample.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, DySample, []]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, DySample, []]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

直接复制这段代码即可

2.4 运行代码

直接替换YOLOv8-Dysample.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/464680.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DDos攻击如何被高防服务器有效防范?

德迅云安全-领先云安全服务与解决方案提供商 什么是DDos攻击? DDos攻击是一种网络攻击手段,旨在通过使目标系统的服务不可用或中断,导致无法正常使用网络服务。DDos攻击可以采取多种方式实施,包括洪水攻击、压力测试、UDP Flood…

HTML静态网页成品作业(HTML+CSS)——游戏战地介绍设计制作(4个页面)

🎉不定期分享源码,关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 🏷️本套采用HTMLCSS,未使用Javacsript代码,共有4个页面。 二、作品演示 三、代…

关于PXIE3U18槽背板原理拓扑关系

如今IT行业日新月异,飞速发展,随之带来的是数据吞吐量的急剧升高。大数据,大存储将成为未来数据通信的主流,建立快速、大容量的数据传输通道将成为电子系统的关键。随着集成技术和互连技术的发展,新的串口技术&#xf…

【QT+QGIS跨平台编译】之七十七:【QGIS_Gui跨平台编译】—【错误处理:字符串错误】

文章目录 一、字符串错误二、处理方法三、涉及到的文件一、字符串错误 常量中有换行符错误:(也有const char * 到 LPCWSTR 转换的错误) 二、处理方法 需要把对应的文档用记事本打开,另存为 “带有BOM的UTF-8” 三、涉及到的文件 src\gui\qgsadvanceddigitizingdockwidge…

ClickHouse中的设置的分类

ClickHouse中的各种设置 ClickHouse中的设置有几百个,下面对这些设置做了一个简单的分类。

【Godot 4.2】常见几何图形、网格、刻度线点求取函数及原理总结

概述 本篇为ShapePoints静态函数库的补充和辅助文档。ShapePoints函数库是一个用于生成常见几何图形顶点数据(PackedVector2Array)的静态函数库。生成的数据可用于_draw和Line2D、Polygon2D等进行绘制和显示。因为不断地持续扩展,ShapePoint…

Orbit 使用指南 03 | 与刚体交互 | Isaac Sim | Omniverse

如是我闻: “在之前的指南中,我们讨论了独立脚本( standalone script)的基本工作原理以及如何在模拟器中生成不同的对象(prims)。在指南03中,我们将展示如何创建并与刚体进行交互。为此&#xf…

机器学习周记(第三十周:文献阅读-SageFormer)2024.3.11~2024.3.17

目录 摘要 ABSTRACT 1 论文信息 1.1 论文标题 1.2 论文摘要 1.3 论文背景 2 论文模型 2.1 问题描述 2.2 模型信息 2.2.1 Series-aware Global Tokens(序列感知全局标记) 2.2.2 Graph Structure Learning(图结构学习) …

大数据面试题之SQL题

大数据面试题之SQL题 1.有一个录取学生人数表,记录的是每年录取学生人数和入学学生的学制 以下是表结构: CREATE TABLE admit ( id int(11) NOT NULL AUTO_INCREMENT, year int(255) DEFAULT NULL COMMENT ‘入学年度’, num int(255) DEFAULT NULL COMM…

交流互动系统|基于springboot框架+ Mysql+Java+Tomcat的交流互动系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 2024年56套包含java,ssm,springboot的平台设计与实现项目系统开发资源(可…

【医学图像处理】ECAT和HRRT格式转nii格式【超简单】

之前从ADNI上下载PET数据的时候发现有许多数据的格式不是DICOM的而是ECAT或者是HRRT格式,这对原本就少的PET数据是血上加霜啊。 当然只使用DICOM格式的数据也会得到不少的数据,我一开始也是只使用DICOM格式的样本,后来为了得到更多的数据&a…

2024年值得创作者关注的十大AI动画创新平台

别提找大型工作室制作动画了。如今,AI平台让我们就可以轻松制作动画。从简单的文本生动画功能到复杂的角色动作,这些平台为各种类型的创作者提供了不同的功能。 AI已经有了长足的发展,现在它可以理解复杂的人类动作和艺术意图,将简单的输入转化成丰富而详细的动画。 下面…

RoketMQ主从搭建

vim /etc/hosts# IP与域名映射,端口看自己的#nameserver 192.168.126.132 rocketmq-nameserver1 192.168.126.133 rocketmq-nameserver2# 注意主从节点不在同一个主机上 #broker 192.168.126.132 rocketmq-master1 192.168.126.133 rocketmq-master2#broker 192.168…

HarmonyOS(鸿蒙)不再适合JS语言开发

ArkTS是鸿蒙生态的应用开发语言。它在保持TypeScript(简称TS)基本语法风格的基础上,对TS的动态类型特性施加更严格的约束,引入静态类型。同时,提供了声明式UI、状态管理等相应的能力,让开发者可以以更简洁、…

用Python 3 开发的摄像头拍照程序

在当今数字化的世界中,使用摄像头进行拍照已成为日常生活的重要组成部分。无论是用于个人用途还是专业用途,能够使用电脑摄像头轻松拍照都是一项有用的技能。本文将指导您使用 Python 3 编写一个简单的程序,让您能够使用电脑摄像头拍照并将其…

如果网络不好 如何下载huggingface上的模型

很多朋友网络不太好,有时候上不了huggingface这样的国外网站; 或者网络流量不太够,想要下载一些stable diffusion模型,或者其他人工智能的大模型的时候,看到动辄几个G的模型文件,不太舍得下载;…

9. 综合案例-ATM系统 (1~7节知识综合练习)

ATM系统_综合大练习 今天的任务是对之前所有的学习的知识, 进行一个综合性的大练习. 老师说的好, 键盘敲烂 这个项目我写了大量的注释给大家参考, 如果有同学是跟着我的系列学习的, 一定动手练一练. 下面的代码只要按着敲是可以直接运行起来的, 我也把完整代码上传到了CSDN上…

Fritzing 简单使用

文章目录 1 Fritzing 资源2 Fritzing 简单使用3 添加自已的元器件3.1 面包板3.1.1 新建面包板 svg 文件3.1.2 新建面包板 3.2 原理图3.3 PCB3.4 图标3.5 使用 1 Fritzing 资源 1)官网: 开源的电子设计和原型平台:https://fritzing.org/免费开…

机试:砍树修路

问题描述 代码示例: //一坐标轴表示某道路,从0开始 到L,整数位置上都种有一颗树。现在该路修建地铁,要砍掉铁路线路上的树木。例如:L等于10,铺设4条铁路,坐标是1到2,2到3,2到8,3到…

【SpringBoot】解决数据库时间和返回时间格式不一致的问题

先看问题: 类中的属性中有Date类型的属性 数据库表中的数据: 可以看到也没问题 但是在返回实体类对象时,数据类型是这样的: 虽然数据是成功返回了,但这显然不是我们想要的结果.也不符合我们的日常使用习惯. 这个问题虽然前端,后端都能处理,但最好还是后端来进行处理.前端主…