Mamba-yolo|结合Mamba注意力机制的视觉检测

一、本文介绍

 PDF地址:https://arxiv.org/pdf/2405.16605v1

代码地址:GitHub - LeapLabTHU/MLLA: Official repository of MLLA

Demystify Mamba in Vision: A Linear AttentionPerspective一文中引入Baseline Mamba,指明Mamba在处理各种高分辨率图像的视觉任务有着很好的效率。发现了强大的Mamba和线性注意力Transformer( linear attention Transformer)非常相似,然后就分析了两者之间的异同。将Mamba模型重述为linear attention Transformer的变体,并且主要有六大差异,分别是:input gate, forget gate,shortcut, no attention normalization, single-head, and modified block design。作者对每个设计都细致的分析了优缺点,评估了性能,最终发现forget gate和block design是Mamba这么给力的主要贡献点。基于以上发现,作者提出了一个类似mamba的线性注意力模型,Mamba-Like Linear Attention (MLLA) ,相当于取其精华,去其糟粕,把mamba两个最为关键的优点设计结合到线性注意力模型当中,具有可并行计算和快速推理的特点。本文将结合YOlOV8检测模型通过添加MLLA模块提升检测精度。

二、宏观架构设计

线性注意 Transformer 模型通常采用图 (a) 中的设计,它由线性注意力模块和 MLP 模块组成。相比之下,Mamba 通过结合 H3和 Gated Attention这两个设计来改进,得到如图 (b) 所示的架构。改进的 Mamba Block 集成了多种操作,例如选择性 SSM、深度卷积、线性映射、激活函数、门控机制等,并且往往比传统的 Transformer 设计更有效。

MLLA (Mamba-Like Linear Attention)的则是通过将Mamba模型的一些核心设计融入线性注意力机制,从而提升模型的性能。具体来说,MLLA主要整合了Mamba中的"忘记门”(forget gate9)和模块设计(block design)这两个关键因素,这些因素被认为是Mamba成功的主要原因。
以下是对MLLA原理的详细分析:
1.忘记门(Forget Gate)
1.忘记门提供了局部偏差和位置信息。所有的忘记门元素严格限制在0到1之间,这意味着模型在接收到当前输入后会持续衰减失前的隐藏状态。这种特性确保了模型对输入序列的顺序敏感。
2.忘记门的局部偏差和位置信息对于图像处理任务来说非常重要,尽管引入忘记门会导致计算需要采用递归的形式,从而降低并行计算的效率。
2.模块设计(Block Design)
1.Mamba的模块设计在保持相似的浮点运算次数(FLOPS)的同时,通过替换注意力子模块为线性注意力来提升性能。结果表明,采用这种模块设计能够显著提高模型的表现。
3.线性注意力的改进:
1.线性注意力被重新设计以整合忘记门和模块设计,这种改进后的模型被称为MLLA。实验结果显示,MLLA在图像分类和高分辨率密集预测任务中均优于各种视觉Mamba模型
4.并行计算和快速推理速度:
1.MLLA通过使用位置编码(ROPE)来替代忘记门,从而在保持并行计算和快速推理速度的同时,提供必要的位置信息。这使得MLLA在处理非自回归的视觉任务时更加有效

结合yolov8改进

核心代码
 

import torch
import torch.nn as nn
 
__all__ = ['MLLAttention']
 
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
 
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
 
 
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):
        super(ConvLayer, self).__init__()
        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=(stride, stride),
            padding=(padding, padding),
            dilation=(dilation, dilation),
            groups=groups,
            bias=bias,
        )
        self.norm = norm(num_features=out_channels) if norm else None
        self.act = act_func() if act_func else None
 
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.conv(x)
        if self.norm:
            x = self.norm(x)
        if self.act:
            x = self.act(x)
        return x
 
 
class RoPE(torch.nn.Module):
    r"""Rotary Positional Embedding.
    """
 
    def __init__(self, base=10000):
        super(RoPE, self).__init__()
        self.base = base
 
    def generate_rotations(self, x):
        # 获取输入张量的形状
        *channel_dims, feature_dim = x.shape[1:-1][0], x.shape[-1]
        k_max = feature_dim // (2 * len(channel_dims))
 
        assert feature_dim % k_max == 0, "Feature dimension must be divisible by 2 * k_max"
 
        # 生成角度
        theta_ks = 1 / (self.base ** (torch.arange(k_max, dtype=x.dtype, device=x.device) / k_max))
        angles = torch.cat([t.unsqueeze(-1) * theta_ks for t in
                            torch.meshgrid([torch.arange(d, dtype=x.dtype, device=x.device) for d in channel_dims],
                                           indexing='ij')], dim=-1)
 
        # 计算旋转矩阵的实部和虚部
        rotations_re = torch.cos(angles).unsqueeze(dim=-1)
        rotations_im = torch.sin(angles).unsqueeze(dim=-1)
        rotations = torch.cat([rotations_re, rotations_im], dim=-1)
 
        return rotations
 
    def forward(self, x):
        # 生成旋转矩阵
        rotations = self.generate_rotations(x)
 
        # 将 x 转换为复数形式
        x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
 
        # 应用旋转矩阵
        pe_x = torch.view_as_complex(rotations) * x_complex
 
        # 将结果转换回实数形式并展平最后两个维度
        return torch.view_as_real(pe_x).flatten(-2)
 
 
class MLLAttention(nn.Module):
    r""" Linear Attention with LePE and RoPE.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
    """
 
    def __init__(self, dim=3, input_resolution=[160, 160], num_heads=4, qkv_bias=True, **kwargs):
 
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.elu = nn.ELU()
        self.lepe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.rope = RoPE()
 
    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, N, C)
        """
        x = x.reshape((x.size(0), x.size(2) * x.size(3), x.size(1)))
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        # self.rope = RoPE(shape=(h, w, self.dim))
        num_heads = self.num_heads
        head_dim = c // num_heads
 
        qk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)
        q, k, v = qk[0], qk[1], x
        # q, k, v: b, n, c
 
        q = self.elu(q) + 1.0
        k = self.elu(k) + 1.0
        q_rope = self.rope(q.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k_rope = self.rope(k.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
 
        z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
        kv = (k_rope.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))
        x = q_rope @ kv * z
 
        x = x.transpose(1, 2).reshape(b, n, c)
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.lepe(v).permute(0, 2, 3, 1).reshape(b, n, c)
        x = x.transpose(2, 1).reshape((b, c, h, w))
        return x
 
    def extra_repr(self) -> str:
        return f'dim={self.dim}, num_heads={self.num_heads}'
 
 
if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 64, 160, 160)
    image = torch.rand(*image_size)
 
    # Model
    model = MLLAttention(64)
 
    out = model(image)
    print(out.size())

修改一

第一还是建立文件,我们找到如下ultralvtics/n文件夹下建立一个目录名字呢就是'Addmodules文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。

修改二

第二步我们在该目录下创建一个新的py文件名字为'  __init__ .py,然后在其内部导入我们的检测头如
下图所示。

修改三 

第三步我门中到如下文件uitralytics/nn/tasks.py进行导入和注册我们的模块

修改四

按照我的添加在parse model里添加即可。

修改5

修改6 配置yolov8-MLLA.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOP
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
 
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
 
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, MLLAttention, []]  # 22 (P5/32-large) # 添加在大目标检测层后!
 
  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

7. 训练代码

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
 
if __name__ == '__main__':
    model = YOLO('yolov8-MLLA.yaml')
    # 如何切换模型版本, 上面的ymal文件可以改为 yolov8s.yaml就是使用的v8s,
    # 类似某个改进的yaml文件名称为yolov8-XXX.yaml那么如果想使用其它版本就把上面的名称改为yolov8l-XXX.yaml即可(改的是上面YOLO中间的名字不是配置文件的)!
    # model.load('yolov8n.pt') # 是否加载预训练权重,科研不建议大家加载否则很难提升精度
    model.train(data=r"C:\Users\Administrator\PycharmProjects\yolov5-master\yolov5-master\Construction Site Safety.v30-raw-images_latestversion.yolov8\data.yaml",
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=16,
                close_mosaic=0,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='runs/train/exp21/weights/last.pt', # 如过想续训就设置last.pt的地址
                amp=True,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )

8.开启训练

专栏推荐

专栏将持续收集整理市场上深度学习的相关项目,旨在为准备从事深度学习工作或相关科研活动的伙伴,储备、提升更多的实际开发经验,每个项目实例都可作为实际开发项目写入简历,且都附带完整的代码与数据集。可通过百度云盘进行获取,实现开箱即用

正在跟新中~

深度学习落地实战_机 _ 长的博客-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/870087.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Modbus转BACnet/IP网关的技术实现与应用

引言 随着智能建筑和工业自动化的快速发展,不同通信协议之间的数据交换也变得日益重要。Modbus和BACnet/IP是两种广泛应用于自动化领域的通信协议,Modbus以其简单性和灵活性被广泛用于工业自动化,而BACnet/IP则在楼宇自动化系统中占据主导地…

华为网络模拟器eNSP安装部署教程

eNSP是图形化网络仿真平台,该平台通过对真实网络设备的仿真模拟,帮助广大ICT从业者和客户快速熟悉华为数通系列产品,了解并掌握相关产品的操作和配置、提升对企业ICT网络的规划、建设、运维能力,从而帮助企业构建更高效&#xff0…

【日常记录】【JS】JS中查询参数处理工具URLSearchParams

文章目录 1. 引言2. URLSearchParams2.1 URLSearchParams 的构造函数2.2 append() 方法2.3 delete() 方法2.4 entries() 方法2.5 forEach() 方法2.6 get() 方法2.7 getAll() 方法2.8 has() 方法2.9 keys() 方法2.10 set() 方法2.11 toString() 方法2.12 values() 方法 参考链接…

懒人精灵安卓版纯本地离线文字识别插件

目的 懒人精灵是一款可以模拟鼠标和键盘操作的自动化工具。它可以帮助用户自动完成一些重复的、繁琐的任务,节省大量人工操作的时间。懒人精灵也包含图色功能,识别屏幕上的图像,根据图像的变化自动执行相应的操作。本篇文章主要讲解下更优秀的…

2019数字经济公测大赛-VMware逃逸

文章目录 环境搭建漏洞点exp 环境搭建 ubuntu :18.04.01vmware: VMware-Workstation-Full-15.5.0-14665864.x86_64.bundle 这里环境搭不成功。。patch过后就报错,不知道咋搞 发现可能是IDA加载后的patch似乎不行对原来的patch可能有影响,重新下了patch&…

通信原理-思科实验三:无线局域网实验

实验三 无线局域网实验 一:无线局域网基础服务集 实验步骤: 进入物理工作区,导航选择 城市家园; 选择设备 AP0,并分别选择Laptop0、Laptop1放在APO范围外区域 修改笔记本的网卡,从以太网卡切换到无线网卡WPC300N 切…

Web前端:HTML篇(一)

HTML简介: 超文本标记语言(英语:HyperText Markup Language,简称:HTML)是一种用于创建网页的标准标记语言。 您可以使用 HTML 来建立自己的 WEB 站点,HTML 运行在浏览器上,由浏览器…

集合的面试题和五种集合的详细讲解

20240724 一、面试题节选二、来自于b站人人都是程序员的视频截图 (感谢人人都是程序员大佬的视频,针对于个人复习。) 一、面试题节选 二、来自于b站人人都是程序员的视频截图 hashmap: 唯一的缺点,无序&#xf…

【JavaScript】`Map` 数据结构

文章目录 一、Map 的基本概念二、常见操作三、与对象的对比四、实际应用场景 在现代 JavaScript 中,Map 是一种非常重要且强大的数据结构。与传统的对象(Object)不同,Map 允许您使用各种类型的值作为键,不限于字符串或…

DjangoRF实战-2-apps-users

1、用户模块 创建一个用户模块子应用,用来管理用户,和认证和授权。 1.1根目录创建apps, 为了使用方便,还需要再pycharm中设置一下资源路径,就可以自动提示 1.2注册子应用 1.3添加应用根目录到环境变量path python导…

7月21日,贪心练习

大家好呀,今天带来一些贪心算法的应用解题、 一,柠檬水找零 . - 力扣(LeetCode) 解析: 本题的贪心体现在对于20美元的处理上,我们总是优先把功能较少的10元作为找零,这样可以让5元用处更大 …

Golang实现免费天气预报获取(OpenWeatherMap)

最近接到公司的一个小需求,需要天气数据,所以就做了一个小接口,供前端调用 这些数据包括六个元素,如降水、风、大气压力、云量和温度。有了这些,你可以分析趋势,知道明天的数据来预测天气。 1.1 工具简介 …

Linux 安装 GDB (无Root 权限)

引入 在Linux系统中,如果你需要在集群或者远程操作没有root权限的机子,安装GDB(GNU调试器)可能会有些限制,因为通常安装新软件或更新系统文件需要管理员权限。下面我们介绍可以在没有root权限的情况下安装GDB&#xf…

vue3响应式用法(高阶性能优化)

文章目录 前言:一、 shallowRef()二、 triggerRef()三、 customRef()四、 shallowReactive()五、 toRaw()六、 markRaw()七、 shallowReadonly()小结: 前言: 翻别人代码时,总结发现极大部分使用vue3的人只会用ref和reactive处理响…

谷歌AI拿下IMO奥数银牌!6道题轻松解出4道~

本周四,谷歌DeepMind团队宣布了一项令人瞩目的成就::用 AI 做出了今年国际数学奥林匹克竞赛 IMO 的真题,并且距拿金牌仅一步之遥。这一成绩不仅标志着人工智能在数学推理领域的重大突破,也引发了全球范围内的广泛关注和…

时序分解 | Matlab基于CEEMDAN-CPO-VMD的CEEMDAN结合冠豪猪优化算法(CPO)优化VMD二次分解

时序分解 | Matlab基于CEEMDAN-CPO-VMD的CEEMDAN结合冠豪猪优化算法(CPO)优化VMD二次分解 目录 时序分解 | Matlab基于CEEMDAN-CPO-VMD的CEEMDAN结合冠豪猪优化算法(CPO)优化VMD二次分解效果一览基本介绍程序设计参考资料 效果一览…

leetcode-148. 排序链表

题目描述 给你链表的头结点 head ,请将其按 升序 排列并返回 排序后的链表 。 示例 1: 输入:head [4,2,1,3] 输出:[1,2,3,4]示例 2: 输入:head [-1,5,3,4,0] 输出:[-1,0,3,4,5]示例 3&#x…

UFO:革新Windows操作系统交互的UI聚焦代理

人工智能咨询培训老师叶梓 转载标明出处 人机交互的便捷性和效率直接影响着我们的工作和生活质量。尽管现代操作系统如Windows提供了丰富的图形用户界面(GUI),使得用户能够通过视觉和简单的点击操作来控制计算机,但随着应用程序功…

javaEE-04-Filter

文章目录 FilterFilter 的生命周期FilterConfig类FilterChain过滤器链Filter 的拦截路径 Filter Filter 过滤器它是 JavaWeb 的三大组件之一,它是 JavaEE 的规范。也就是接口,它的作用是:拦截请求,过滤响应。 Filter的工作流程图解: 以管…

HarmonyOS NEXT零基础入门到实战-第四部分

自定义组件: 概念: 由框架直接提供的称为 系统组件, 由开发者定义的称为 自定义组件。 源代码: Component struct MyCom { build() { Column() { Text(我是一个自定义组件) } } } Component struct MyHeader { build() { Row(…