【Block总结】DynamicFilter,动态滤波器降低计算复杂度,替换传统的MHSA|即插即用

论文信息

标题: FFT-based Dynamic Token Mixer for Vision

论文链接: https://arxiv.org/pdf/2303.03932

关键词: 深度学习、计算机视觉、对象检测、分割

GitHub链接: https://github.com/okojoalg/dfformer

在这里插入图片描述

创新点

本论文提出了一种新的标记混合器(token mixer),称为动态滤波器(Dynamic Filter),旨在解决多头自注意力(MHSA)模型在处理高分辨率图像时的计算复杂度问题。传统的MHSA模型在输入特征图中像素数量的平方上具有计算复杂度,导致处理速度缓慢。通过引入基于快速傅里叶变换(FFT)的动态滤波器,论文展示了在保持性能的同时显著降低计算复杂度的可能性。

方法

论文中提出的动态滤波器结合了全局操作的优点,类似于MHSA,但在计算效率上更具优势。具体方法包括:

  • FFT-based Token Mixer: 通过FFT实现全局操作,降低计算复杂度。
  • DFFormer和CDFFormer模型: 这两种新型图像识别模型利用动态滤波器进行图像分类和其他下游任务。
    在这里插入图片描述

动态滤波器如何具体降低MHSA模型的计算复杂度?

动态滤波器通过引入基于快速傅里叶变换(FFT)的机制,显著降低了多头自注意力(MHSA)模型的计算复杂度。以下是其具体工作原理和优势:

计算复杂度问题

传统的MHSA模型在处理输入特征图时,其计算复杂度与特征图中像素数量的平方成正比。这意味着,当输入图像的分辨率增加时,计算需求会急剧上升,导致处理速度变慢,尤其是在高分辨率图像的情况下。

动态滤波器的工作原理

  1. 频域转换: 动态滤波器首先利用FFT将输入特征图转换到频域。FFT是一种高效的算法,可以将计算复杂度降低到 O ( N log ⁡ N ) O(N \log N) O(NlogN),其中 N N N是数据的长度。这一转换使得后续的操作可以在频域中进行,从而减少了计算量。

  2. 动态生成滤波器: 在频域中,动态滤波器通过一个多层感知机(MLP)动态生成每个特征通道的滤波器。这些滤波器是根据输入特征图的内容进行调整的,能够更好地捕捉到图像中的重要信息。

  3. 频域操作: 生成的滤波器在频域中应用于特征图,进行全局信息的捕捉。通过这种方式,动态滤波器能够有效地进行全局操作,同时避免了MHSA中计算复杂度的急剧增加。

  4. 逆FFT转换: 最后,经过滤波的频域特征图通过逆FFT转换回空间域,得到最终的输出结果。

优势

  • 降低计算复杂度: 通过在频域中进行操作,动态滤波器显著降低了MHSA模型的计算复杂度,使得处理高分辨率图像时的速度得以提升。

  • 提高内存效率: 动态滤波器的设计使得模型在处理时占用更少的内存,适合在资源有限的环境中运行。

  • 保持性能: 尽管计算复杂度降低,动态滤波器仍然能够保持与MHSA相似的性能,尤其是在图像分类和其他视觉任务中表现出色。

效果

实验结果表明,DFFormer和CDFFormer在高分辨率图像识别任务中表现出色,具有显著的吞吐量和内存效率。具体而言,这些模型在处理高分辨率图像时的性能优于传统的MHSA模型,显示出动态滤波器在实际应用中的潜力。

实验结果

论文通过一系列实验验证了提出模型的有效性,包括:

  • 图像分类: DFFormer和CDFFormer在标准数据集上的表现接近或超过了现有的最先进模型。
  • 下游任务分析: 通过对比实验,展示了动态滤波器在不同视觉任务中的适用性和优势。

总结

本论文的研究表明,基于FFT的动态滤波器是一种值得认真考虑的标记混合器选项,尤其是在处理高分辨率图像时。通过降低计算复杂度,动态滤波器不仅提高了模型的处理速度,还保持了良好的性能,推动了计算机视觉领域的进一步发展。研究结果为未来的视觉模型设计提供了新的思路和方向。

代码

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple


class StarReLU(nn.Module):
    """
    StarReLU: s * relu(x) ** 2 + b
    """

    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(scale_value * torch.ones(1),
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(bias_value * torch.ones(1),
                                 requires_grad=bias_learnable)

    def forward(self, x):
        return self.scale * self.relu(x) ** 2 + self.bias


class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks.
    Mostly copied from timm.
    """

    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0.,
                 bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class DynamicFilter(nn.Module):
    def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, num_filters=4, size=14, weight_resize=False,
                 **kwargs):
        super().__init__()
        size = to_2tuple(size)
        self.size = size[0]
        self.filter_size = size[1] // 2 + 1
        self.num_filters = num_filters
        self.dim = dim
        self.med_channels = int(expansion_ratio * dim)
        self.weight_resize = weight_resize
        self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias)
        self.act1 = act1_layer()
        self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels)
        self.complex_weights = nn.Parameter(
            torch.randn(self.size, self.filter_size, num_filters, 2,
                        dtype=torch.float32) * 0.02)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)

    def forward(self, x):
        B, H, W, _ = x.shape

        routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters,
                                                          -1).softmax(dim=1)
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        if self.weight_resize:
            complex_weights = resize_complex_weight(self.complex_weights, x.shape[1],
                                                    x.shape[2])
            complex_weights = torch.view_as_complex(complex_weights.contiguous())
        else:
            complex_weights = torch.view_as_complex(self.complex_weights)
        routeing = routeing.to(torch.complex64)
        weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
        if self.weight_resize:
            weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
        else:
            weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')

        x = self.act2(x)
        x = self.pwconv2(x)
        return x
def resize_complex_weight(origin_weight, new_h, new_w):
    h, w, num_heads = origin_weight.shape[0:3]  # size, w, c, 2
    origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2)
    new_weight = torch.nn.functional.interpolate(
        origin_weight,
        size=(new_h, new_w),
        mode='bicubic',
        align_corners=True
    ).permute(0, 2, 3, 1).reshape(new_h, new_w, num_heads, 2)
    return new_weight


if __name__ == "__main__":
    # 如果GPU可用,将模块移动到 GPU
    input_size=20
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, height, width,channels)
    x = torch.randn(1, input_size , input_size, 32).to(device)
    # 初始化 pconv 模块
    dim = 32
    block = DynamicFilter(dim=dim,size=input_size)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962019.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

设计模式Python版 原型模式

文章目录 前言一、原型模式二、原型模式示例三、原型管理器 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对…

一文讲解Java中的BIO、NIO、AIO之间的区别

BIO、NIO、AIO是Java中常见的三种IO模型 BIO:采用阻塞式I/O模型,线程在执行I/O操作时被阻塞,无法处理其他任务,适用于连接数比较少的场景;NIO:采用非阻塞 I/O 模型,线程在等待 I/O 时可执行其…

使用 postman 测试思源笔记接口

思源笔记 API 权鉴 官方文档-中文:https://github.com/siyuan-note/siyuan/blob/master/API_zh_CN.md 权鉴相关介绍截图: 对应的xxx,在软件中查看 如上图:在每次发送 API 请求时,需要在 Header 中添加 以下键值对&a…

AWTK 骨骼动画控件发布

Spine 是一款广泛使用的 2D 骨骼动画工具,专为游戏开发和动态图形设计设计。它通过基于骨骼的动画系统,帮助开发者创建流畅、高效的角色动画。本项目是基于 Spine 实现的 AWTK 骨骼动画控件。 代码:https://gitee.com/zlgopen/awtk-widget-s…

新年手搓--本地化部署DeepSeek-R1,全程实测

1.环境准备安装ollma ollma官网链接: Download Ollama on Linux ubuntu命令行安装: curl -fsSL https://ollama.com/install.sh | sh 选择运行模型,用7b模型试一下(模型也差不多5G): ollama run deepseek-r1:7b 运行qwen: ollama run qwen2.5:7b 2.为方便运行…

STM32使用VScode开发

文章目录 Makefile形式创建项目新建stm项目下载stm32cubemx新建项目IED makefile保存到本地arm gcc是编译的工具链G++配置编译Cmake +vscode +MSYS2方式bilibiliMSYS2 统一环境配置mingw32-make -> makewindows环境变量Cmake CmakeListnijia 编译输出elfCMAKE_GENERATOR查询…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】1.21 索引宗师:布尔索引的七重境界

1.21 索引宗师:布尔索引的七重境界 目录 #mermaid-svg-Iojpgw5hl0Ptb9Ti {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-Iojpgw5hl0Ptb9Ti .error-icon{fill:#552222;}#mermaid-svg-Iojpgw5hl0Ptb9Ti .…

毕业设计--具有车流量检测功能的智能交通灯设计

摘要: 随着21世纪机动车保有量的持续增加,城市交通拥堵已成为一个日益严重的问题。传统的固定绿灯时长方案导致了大量的时间浪费和交通拥堵。为解决这一问题,本文设计了一款智能交通灯系统,利用车流量检测功能和先进的算法实现了…

课题推荐:基于matlab,适用于自适应粒子滤波的应用

自适应粒子滤波(Adaptive Particle Filter, APF)是一种用于状态估计的有效方法,特别适用于非线性和非高斯系统。 文章目录 应用场景MATLAB 代码示例代码说明结果扩展说明 以下是一个基于自适应粒子滤波的简单应用示例,模拟一个一维…

Redis(5,jedis和spring)

在前面的学习中,只是学习了各种redis的操作,都是在redis命令行客户端操作的,手动执行的,更多的时候就是使用redis的api(),进一步操作redis程序。 在java中实现的redis客户端有很多,…

基于聚类与相关性分析对马来西亚房价数据进行分析

碎碎念:由于最近太忙了,更新的比较慢,提前祝大家新春快乐,万事如意!本数据集的下载地址,读者可以自行下载。 1.项目背景 本项目旨在对马来西亚房地产市场进行初步的数据分析,探索各州的房产市…

(2025 年最新)MacOS Redis Desktop Manager中文版下载,附详细图文

MacOS Redis Desktop Manager中文版下载 大家好,今天给大家带来一款非常实用的 Redis 可视化工具——Redis Desktop Manager(简称 RDM)。相信很多开发者都用过 Redis 数据库,但如果你想要更高效、更方便地管理 Redis 数据&#x…

智慧园区管理系统为企业提供高效运作与风险控制的智能化解决方案

内容概要 快鲸智慧园区管理系统,作为一款备受欢迎的智能化管理解决方案,致力于为企业提供高效的运作效率与风险控制优化。具体来说,这套系统非常适用于工业园、产业园、物流园、写字楼及公寓等多种园区和商办场所。它通过数字化与智能化的手…

C++中常用的排序方法之——冒泡排序

成长路上不孤单😊😊😊😊😊😊 【14后😊///计算机爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于C中常用的排序方法之——冒泡排序的…

基础项目实战——3D赛车(c++)

目录 前言一、渲染引擎二、关闭事件三、梯形绘制四、轨道绘制五、边缘绘制六、草坪绘制七、前后移动八、左右移动​九、曲线轨道​十、课山坡轨道​十一、循环轨道​十二、背景展示​十三、引入速度​十四、物品绘制​十五、课数字路障​十六、分数展示​十七、重新生成​十八、…

深度学习指标可视化案例

TensorBoard 代码案例:from torch.utils.tensorboard import SummaryWriter import torch import torchvision from torchvision import datasets, transforms# 设置TensorBoard日志路径 writer SummaryWriter(runs/mnist)# 加载数据集 transform transforms.Comp…

AI时序预测: iTransformer算法代码深度解析

在之前的文章中,我对iTransformer的Paper进行了详细解析,具体文章如下: 文章链接:深度解析iTransformer:维度倒置与高效注意力机制的结合 今天,我将对iTransformer代码进行解析。回顾Paper,我…

Java内存模型 volatile 线程安全

目录 Java内存模型可见性例子和volatilevolatile如何保证可见性原子性与单例模式i非原子性 线程安全 Java内存模型 参考学习: Java Memory Model外文文档 CPU与内存,可参考:https://blog.csdn.net/qq_26437925/article/details/145303267 Java线程与内…

【FreeRTOS 教程 三】协程状态、优先级、实现及调度

目录 一、协程介绍: (1)协程的特点: (2)协程的优势: 二、协程状态: (1)协程状态说明: (2)协程状态图示:…

堆的存储(了解)

由于堆是⼀个完全⼆叉树,因此可以⽤⼀个数组来存储。(如果不清楚大家可以回顾⼆叉树的存储(上)c文章里的顺序存储) 结点下标为 i : 如果⽗存在,⽗下标为 i/2 ; 如果左孩⼦存在&…