SMA2:代码实现详解——Image Encoder篇(Hiera章)

SMA2:代码实现详解——Image Encoder篇(Hiera)

写在前面

大家在SMA2:代码实现详解——Image Encoder篇(FpnNeck)下的留言我已收到,感谢大家的支持,后面如果遇到比较难以讲清的部分可能会使用视频的形式。博主最近要准备秋招,更新可能会慢许多,希望大家能谅解。

言归正传,在SMA2:代码实现详解——Image Encoder篇(FpnNeck)中,我们已经知道了SMA2的整体架构,并且介绍了Image Encoder组件中的FpnNeck。这一篇博客我们就来详细介绍Image Encoder的基本骨架backbone——Hiera

Hiera介绍

Hiera是文章Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles中提出的一种分层视觉Transformer架构。它不仅可以处理图像,而且这个架构可以应用于视频。Hiera是一个纯粹的简单分层ViT模型,不存在任何卷积、移位或者十字窗口操作,仅有Transformer结构组件。它比之前跨多个模型大小、领域和任务的工作更快、更准确。
在这里插入图片描述

Hiera与MAE(Masked AutoEncoder)

MAE(Masked AutoEncoder, 掩码自编码器)

图像MAE由论文Masked Autoencoders Are Scalable Vision Learners提出,它表明,MAE是计算机视觉的可扩展自监督学习器。方法非常简单:屏蔽输入图像的随机Patch并重建丢失的像素。它基于两个核心设计。首先,作者开发了一种非对称编码器-解码器架构,其中的编码器仅对Patch的可见子集(没有掩码标记)进行操作,而轻量级解码器可根据潜在表示和掩码标记重建原始图像。作者发现屏蔽高比例的输入图像(例如 75%)会产生一项不简单且有意义的自我监督任务。将这两种设计结合起来能够高效且有效地训练大型模型:加速训练(3 倍或更多)并提高准确性。可扩展方法允许学习泛化良好的高容量模型:例如,在仅使用 ImageNet-1K 数据的方法中,普通 ViT-Huge 模型实现了最佳准确率 (87.8%)。下游任务中的传输性能优于监督预训练,并显示出有希望的扩展行为。

Hiera便使用了MAE的方式进行训练。

Hiera架构

在这里插入图片描述

选择使用像MAE(如图所示)这样的强代理任务(pretext task)来教导模型。 Hiera完全由标准ViT块组成。为了提高效率,在前两个阶段使用“掩模单元”内的局部注意力,其余阶段使用全局注意力(Global Attention)。在每个阶段转换中,Q和跳跃连接的特征通过线性层加倍,空间维度通过2×2最大池池化。

SMA2中Hiera(HieraDet)的实现

class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(self, ...):
        ...
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            ...
            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )
            embed_dim = dim_out
            self.blocks.append(block)

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

首先,Hiera先将图片划分并映射为patch嵌入向量(上述代码62行),然后计算位置信息并相加(代码第66行)。值得注意的是,SMA2在实现Hiera中位置嵌入时,参照了Window Attention is Bugged: How not to Interpolate Position Embeddings一文,他们发现在使用窗口注意力的同时插值位置嵌入是错误的。Hiera和ViTDet两者确实都存在此错误。于是作者提出了一种简单的绝对窗口位置嵌入策略,它彻底解决了Hiera中的错误,并提高了ViTDet中模型的速度和性能。

代码的68-75行实际上就是Hiera主体ViT块的处理,值得关注的只有带有Q pooling的ViT块,这是在MultiScaleBlock中实现的。

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int):  embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x

PatchEmbed模块将图片的形状(B,C,H,W)转化为更常见的适用于Transformer处理的形状(B, H, W, C),因为后面经过VIT块时会要求(B,L,C)的形式。实际上,这个模块的卷积映射继承了ViT的做法,直接利用了卷积的特性,通过指定Kernel_size与strides隐式划分了窗口,并且完成了线性变换得到patch enmbedding

值得注意的是位置嵌入的计算:

class Hiera(nn.Module):
 
    def __init__(...)
        super().__init__()
        ...
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )
        ...

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

代码第18行是计算全局的可学习位置嵌入。第19行加号的右边window_embed.tile(...)是计算每个window内的局部位置编码,每个window的位置编码都是相同的。我们可以使用matplotlib做一个可视化的样例,可能更容易理解。示例如下(由于代码中是零初始化,不太好展示,这里我选择随机初始化来展示):

在这里插入图片描述

从左到右依次为全局编码、局部编码和最终位置编码。

接下来我们来看MultiScaleBlock的实现:

class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows, (Hp, Wp)

def do_pool():... #(B, H, W, C) -> (B, H', W' C)

MultiScaleBlockMultiScaleAttentionMLP构成,有经验的小伙伴看到注意力机制和MLP,显然得出它是一个Transformer。第60-63行代码就是根据每个stage给定的window size划分patch。
在这里插入图片描述

而且针对于每个Stage的交界,都使用Q pooling,这在MultiScaleAttention中实现。

class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x

代码19-23以及31-41都是比较传统的自注意力机制的计算了。
而所谓的Q pooling在26-30行,只是对Q向量转换为宽高的形状(B, H*W,)->(B, H, W, …),然后进行池化。其实对于H和W,它们应该是我们之前指定的window size。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/876043.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python 课程9-資料庫操作

前言 在现代软件开发中,数据库是核心组件之一,它负责数据的存储、管理和检索。无论是简单的应用程序还是复杂的企业级系统,数据库操作都是必不可少的。本教程将深入讲解如何使用 Python 进行数据库操作,涵盖使用 sqlite3 进行本地…

OpenHarmony(鸿蒙南向开发)——轻量系统STM32F407芯片移植案例

往期知识点记录: 鸿蒙(HarmonyOS)应用层开发(北向)知识点汇总 鸿蒙(OpenHarmony)南向开发保姆级知识点汇总~ OpenHarmony(鸿蒙南向开发)——轻量和小型系统三方库移植指南…

Android SystemUI组件(06)导航栏创建分析虚拟按键

该系列文章总纲链接:专题分纲目录 Android SystemUI组件 本章关键点总结 & 说明: 说明:本章节持续迭代之前章节的思维导图,主要关注左侧SystemBars分析中导航栏部分即可。 1 导航栏创建之makeStatusBarView 通过上一篇文章的…

代理IP设置后IP不变?可能的原因及解决方法

在使用代理IP时,有时会遇到代理设置后IP地址却没有变化的情况。这种问题可能会让人感到困惑,但其实背后有多种原因。本文将详细探讨这些原因,并提供相应的解决方法,帮助你顺利解决问题。 可能的原因 代理IP设置后IP地址不变的原…

Spring的核心思想

目录 一、Spring要解决的问题 二、Spring的核心结构 三、核心思想 3.1.1 什么是IOC 3.1.2 IOC解决的问题:耦合 3.1.3 IOC和DI的区别 3.2.1 什么是AOP 3.2.2 AOP解决的问题:耦合 3.2.3 为什么叫做面向切面编程 一、Spring要解决的问题 问题1&am…

maya-vray渲染蒙版

要用一个叫vrayMulWrapper的材质球,把alpha Conterbution调到-1,勾选matte surface启用蒙版物体。

爬虫逆向学习(六):补环境过某数四代

声明:本篇文章内容是整理并分享在学习网上各位大佬的优秀知识后的实战与踩坑记录 引用博客: https://blog.csdn.net/shayuchaor/article/details/103629294 https://blog.csdn.net/qq_36291294/article/details/128600583 https://blog.csdn.net/weixin_…

时序预测 | Matlab实现GA-CNN遗传算法优化卷积神经网络时间序列预测

时序预测 | Matlab实现GA-CNN遗传算法优化卷积神经网络时间序列预测 目录 时序预测 | Matlab实现GA-CNN遗传算法优化卷积神经网络时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 时序预测 | Matlab实现GA-CNN遗传算法优化卷积神经网络时间序列预测&#xff…

巴西电商市场规模、前景及支付方式(pix、Boleto)

一、巴西电商市场分析 作为拉丁美洲最大经济体,巴西在拉丁美洲经济中占据领先地位,根据巴西地理与统计研究所(IBGE)的数据,2023年巴西GDP达到2.2万亿美元,跃居世界第九大经济体。数字化进程以及经济多元化推进正在推动该国中产阶…

TiDB 数据库核心原理与架构_Lesson 01 TiDB 数据库架构概述课程整理

作者: 尚雷5580 原文来源: https://tidb.net/blog/beeb9eaf 注:本文基于 TiDB 官网 董菲老师 《TiDB 数据库核心原理与架构(101) 》系列教程之 《Lesson 01 TiDB 数据库架构概述》内容进行整理和补充。 课程链接:…

PowerBI 关于FILTERS函数和VALUES函数

本人是powerbi新手,最近在使用Filters()函数和Values()函数时,有点不太明白它们之间的区别,u有时它们得到的结果是一样的,有时却不一样。 官方文档里,Filters()是表示返回直接作为筛选器应用到 columnName 的值 FILT…

凸优化学习(1)——什么是凸优化、凸集、凸函数

🍅 写在前面 👨‍🎓 博主介绍:大家好,这里是hyk写算法了吗,一枚致力于学习算法和人工智能领域的小菜鸟。 🔎个人主页:主页链接(欢迎各位大佬光临指导) ⭐️近…

Python之NumPy超详细学习指南:从入门到精通(上篇)

文章目录 Python NumPy学习指南:从入门到精通第一部分:NumPy简介与安装1. 什么是NumPy?2. 安装NumPy使用pip安装:使用Anaconda安装: 第二部分:NumPy数组基础1. NumPy数组的创建从列表创建一维数组&#xff…

OpenCV结构分析与形状描述符(14)拟合直线函数fitLine()的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 拟合一条直线到2D或3D点集。 fitLine 函数通过最小化 ∑ i ρ ( r i ) \sum_i \rho(r_i) ∑i​ρ(ri​)来拟合一条直线到2D或3D点集&#xff0c…

FishAudio发布了 Fish Speech V1.4

还记得今年OpenAI 刚推出 gpt4o 不久,开源界就出现了 ChatTTS 和 FishSpeech 这些不错的 TTS 项目。 而 Fish Speech V1.4 是一个领先的文本到语音(TTS)模型,它是在 700,000 小时的多语言音频数据基础上训练出来的。 该模型支持八…

K8s 之Pod的定义及详细资源调用案例

资源管理介绍 在kubernetes中,所有的内容都抽象为资源,用户需要通过操作资源来管理kubernetes。kubernetes的本质上就是一个集群系统,用户可以在集群中部署各种服务所谓的部署服务,其实就是在kubernetes集群中运行一个个的容器&a…

[XILINX] 正点原子ZYNQ7015开发板!ZYNQ 7000系列、双核ARM、PCIe2.0、SFPX2,性能强悍,资料丰富!

正点原子ZYNQ7015开发板!ZYNQ 7000系列、双核ARM、PCIe2.0、SFPX2,性能强悍,资料丰富! 正点原子Z15 ZYNQ开发板,搭载Xilinx Zynq7000系列芯片,核心板主控芯片的型号是XC7Z015CLG485-2。开发板由核心板&…

GLSL 棋盘shader

今日永杰开金 float size 100.;vec2 checkerboard mod(floor(gl_FragCoord.xy / size), 2.);float c mod(checkerboard.x checkerboard.y, 2.);gl_FragColor vec4(vec3(c), 1);或 vec2 uv floor(S * p.xy * vec2(iResolution.x / iResolution.y, 1) / iResolution.xy); …

【主机入侵检测】Wazuh规则详解

前言 Wazuh 规则是一组用XML格式编写的条件,它们定义了应该如何解释日志数据。这些规则由Wazuh Manager使用,用于在日志消息中检测特定的模式或行为,并相应地生成警报或响应。它们在威胁检测中扮演着至关重要的角色,因为它们允许系…

Golang | Leetcode Golang题解之第397题整数替换

题目: 题解: func integerReplacement(n int) (ans int) {for n ! 1 {switch {case n%2 0:ansn / 2case n%4 1:ans 2n / 2case n 3:ans 2n 1default:ans 2n n/2 1}}return }