详解SwinIR的论文和代码(SwinIR: Image Restoration Using Swin Transformer)

paper:https://arxiv.org/abs/2108.10257
code:https://github.com/JingyunLiang/SwinIR

目录

  • 1. Swin Transformer layers
    • 1.1 局部注意力
    • 1.2 移动窗口机制
    • 1.3 关键代码理解
  • 2. 整体网络结构
    • 2.1 浅层特征提取
    • 2.2 深层特征提取
    • 2.3 图像重建
  • 3.总结

SwinIR将Swin transformer1应用到low level领域的图像增强任务,结合卷积设计了网络结构,在以下三个任务上取得了很好的效果:图像超分辨率(包括classical、lightweight和real-world SR)、图像去噪(包括灰度图和彩色图像去噪)和 JPEG压缩失真去除。本文将结合代码对SwinIR进行详解。

SwinIR的网络结构并不复杂,关键部件就是Swin Transformer layers(STL)卷积层残差连接。卷积和残差连接大家都比较熟悉了,因此我首先结合代码介绍一下swin transformer层,然后自底向上的介绍SwinIR的全貌

1. Swin Transformer layers

SwinIR使用的Swin Transformer layers(STL)是在swin transformer中提出的,并未有改动。STL基于原始的多头注意力transformer层进行优化,主要的不同点在于:1. 局部注意力(local attention);2. 移动窗口机制(shifted window mechanism);

1.1 局部注意力

原始的全局注意力会将图像分成若干个patch,所有的patch之间做自注意力计算;所谓的局部注意力就是首先将图像划分成若干个window,每个window内在进行patch的划分,然后在window内部进行自注意力的计算,而不在一个window内的patch是没有交互的。也就是说,只考虑一个window内的patch,他们之间的计算和全局注意力操作是一样的。

理解局部注意力具体是怎么做的,很好的一个办法是看代码和分析tensor在不同层之间的shape整理出来。下面是我整理的tensor shape变化:

请添加图片描述
其中,b: batchsize, h: 输入高, w:输入宽, ws: 窗口大小, C: channel数, num_heads:attention的head数

1.2 移动窗口机制

由于基于窗口的多头注意力(W-MSA)没有考虑跨窗口的连接,模型建模长距离关联的能力受损。因此swin transformer提出了移动窗口多头注意力机制(SW-MSA),可在保证计算高效性的前提下,扩大感受野。

如下图所示,W-MSA的窗口大小为M*M(图中M=4),那么SW-MSA的窗口划分将向右下移动 ⌊ M / 2 ⌋ ∗ ⌊ M / 2 ⌋ \lfloor M/2 \rfloor *\lfloor M/2 \rfloor M/2M/2

请添加图片描述

但是经过位移之后,窗口数量会变多,由原来的 ⌊ h / M ⌋ ∗ ⌊ w / M ⌋ \lfloor h/M \rfloor *\lfloor w/M \rfloor h/Mw/M变成 ( ⌊ h / M ⌋ + 1 ) ∗ ( ⌊ w / M ⌋ + 1 ) (\lfloor h/M \rfloor + 1) *(\lfloor w/M \rfloor +1) (⌊h/M+1)(⌊w/M+1),而且窗口大小不一致。因此swin transformer提出了循环位移,减少窗口数量,同时可以获得相同大小的窗口进行并行计算。循环位移如下图所示。
请添加图片描述

在代码中,循环位移通过torch.roll实现,shifts为负,代表从下往上移动,从右往左移动,最上和最左循环移动到最下和最右。

# cyclic shift
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

关于torch.roll可参考:https://blog.csdn.net/weixin_42899627/article/details/116095067
如上图所示,经过循环移位后,有三个窗口中有一些patch是本不相邻的,它们不应该做自注意力,所以swin transformer建立了mask机制来完成最终的注意力计算。

关于mask的理解可参考https://github.com/microsoft/Swin-Transformer/issues/38

1.3 关键代码理解

下面来看一下关键代码及注释,首先是WindowAttention的forward函数:

def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape  # 此处的输入是经过window partition的
        # self.qkv(x): num_windows*B, window_size*window_size, 3*C
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # 通过一个全连接层获取所有头的qkv,(3, num_windows*B, num_heads, window_size*window_size, C // num_heads)
        q, k, v = qkv[0], qkv[1], qkv[2] 
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1)) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
		# 可学习的相对位置bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0) # num_windows*B, num_heads, window_size*window_size, window_size*window_size

        if mask is not None:
            nW = mask.shape[0]
            # 将mask和attn相加,mask只有两种取值0和-100,因此为0时对attn无影响,为-100时,self.softmax(attn)将变为接近于0
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
            attn = self.softmax(attn) # num_windows*B, num_heads, window_size*window_size, window_size*window_size
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        # v:num_windows*B, num_heads, window_size*window_size, C // num_heads
        # attn:num_windows*B, num_heads, window_size*window_size, window_size*window_size
        # attn @ v: num_windows*B, num_heads, window_size*window_size, C // num_heads
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # num_windows*B, window_size*window_size, C
        x = self.proj(x) # 全连接层
        x = self.proj_drop(x)
        return x

接下来是SwinTransformerBlock的forward函数

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # (num_windows*B, window_size, window_size, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # num_windows*B, window_size*window_size, C

        # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

可以看到每个SwinTransformerBlock内部完成的是:
X = M S A ( L N ( X ) ) + X X = MSA(LN(X)) + X X=MSA(LN(X))+X
X = M L P ( L N ( X ) ) + X X = MLP(LN(X)) + X X=MLP(LN(X))+X
其中MSA为W-MSA和SW-MSA交替。

2. 整体网络结构

请添加图片描述
如上图所示,SwinIR包括三个modules,浅层特征提取、深层特征提取和图像重建。其中特征提取模块对所有任务都是一样的,但是图像重建对于不同的任务是不同的。

2.1 浅层特征提取

一个3×3卷积层将特征图通道转成embed_dim:(b, embed_dim, h, w)

self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

2.2 深层特征提取

深层特征提取的基本模块则是第一节中讲解的STL和卷积层和残差连接。STL和卷积组成RSTB,RSTB和卷积组成了深层特征提取。

2.3 图像重建

以下代码可以看到对于不同的任务,图像重建模块是不同的,有的采用最邻近插值+卷积,有的采用pixelshuffle+卷积,有的直接采用卷积。


if self.upsampler == 'pixelshuffle':
    # for classical SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
    # for lightweight SR
    x = self.conv_first(x)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
    # for real-world SR
    x = self.conv_first(x) # (b, embed_dim, h, w)
    x = self.conv_after_body(self.forward_features(x)) + x
    x = self.conv_before_upsample(x)
    x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
    x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
    # for image denoising and JPEG compression artifact reduction
    x_first = self.conv_first(x)
    res = self.conv_after_body(self.forward_features(x_first)) + x_first
    x = x + self.conv_last(res)

SwinIR可以很灵活配置网络的复杂度。影响W-MSA计算复杂度: 4 h w C 2 + 2 M 2 h w C 4hwC^2 + 2M^2hwC 4hwC2+2M2hwC
请添加图片描述

3.总结

  1. 结构简单,性能全面超过cnn-based的方法,适用于多种任务,可做为Low-level的基线模型;
  2. 作者发现与以往基于transformer的方法不同,Swinir不需要比cnn更多的训练数据,收敛速度也更快;
  3. 结构模块化,可以方便调整出不同复杂度的模型;

  1. Liu Z, Lin Y, Cao Y, et al. Swin transformer: Hierarchical vision transformer using shifted windows[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 10012-10022. ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/167870.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

BUUCTF 秘密文件 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 深夜里,Hack偷偷的潜入了某公司的内网,趁着深夜偷走了公司的秘密文件,公司的网络管理员通过通过监控工具成功的截取Hack入侵时数据流量,但是却无法分析出Hack到底偷走…

Azure 机器学习 - 搜索中的检索增强 (RAG)

目录 一、Azure AI 信息检索系统介绍二、采用 Azure AI 搜索的 RAG 方法三、适合 Azure AI 搜索的自定义 RAG 模式四、Azure AI 搜索中的可搜索内容五、Azure AI 搜索中的内容检索构建查询响应按相关性排名适用于 RAG 方案的 Azure AI 搜索查询的示例代码 六、集成代码和 LLM七…

【MySQL】_JDBC

目录 1. JDBC原理 2. 导入JDBC驱动包 3. 编写JDBC代码实现Insert 3.1 创建并初始化一个数据源 3.2 和数据库服务器建立连接 3.3 构造SQL语句 3.4 执行SQL语句 3.5 释放必要的资源 4. JDBC代码的优化 4.1 从控制台输入 4.2 避免SQL注入的SQL语句 5. 编写JDBC代码实现…

深入Ansible

1.什么是ansible ansible是新出现的自动化运维工具,基于Python开发,集合了众多运维工具(puppet、chef、func、fabric)的优点,实现了批量系统配置、批量程序部署、批量运行命令等功能。 ansible是基于 paramiko 开发的…

11月20日星期一今日早报简报微语报早读

11月20日星期一,农历十月初八,早报微语早读。 1、T1以3-0横扫WBG,拿下S13冠军!Faker豪取第4冠; 2、天舟七号货运飞船已运抵文昌发射场,将于明年初发射; 3、“中韩之战”球票已经售罄&#xf…

没收到Win11 23H2正式版的推送怎么升级到23H2

没收到Win11 23H2正式版的推送怎么升级到23H2?用户反映自己没有收到Win11 23H2正式版的更新推送,又想升级为23H2版本。接下来小编给大家详细介绍不同的升级方法,帮助更多的用户完成Win11 23H2系统的更新,升级后就能体验到Win11 23…

解锁安全与信任的双重礼遇!JoySSL证书买二送一,买三送二

JoySSL是业内领先的SSL证书提供商,致力于为网站提供最高水平的安全性。通过使用JoySSL证书,您的网站将获得强大的加密保护,确保用户的敏感信息在传输过程中得到安全加密,有效地抵御各种网络威胁。 为何选择JoySSL证书&#xff1f…

解决龙芯loongarch64服务器编译安装Python后yum命令无法使用的问题“no module named ‘dnf‘”

引言 在使用Linux系统时,我们经常会使用yum来管理软件包。然而,有时候我们可能会遇到yum不可用的情况,其中一个原因就是Python的问题。本文将介绍Python对yum可用性的影响,并提供解决方案。 问题引发 正常情况下,安装linux系统后,yum命令是可用状态,升级Python版本后,…

pyqt5切换到pyqt6遇到问题

pyqt5切换到pyqt6变更点 FramelessWindowHint Qt.FramelessWindowHint Qt.WindowType.FramelessWindowHint globalPos event.globalPos() event.globalPosition() LeftButton Qt.LeftButton Qt.MouseButton.LeftButton StrongFocus Qt.StrongFocus Qt.FocusPolicy.Stro…

装备中国功勋企业——兰石重装,建设LTC全流程管理|基于得帆云低代码的CRM案例系列

兰石重型装备股份有限公司 兰石重型装备股份有限公司(以下简称“兰石重装”)成立于2001年,经营范围为炼油、化工、核电等能源领域所需的装备的设计、制造、安装、成套与服务;工程项目建设与服务;机械加工;检…

第94步 深度学习图像分割:DeepLabv3建模

基于WIN10的64位系统演示 一、写在前面 本期,我们继续学习深度学习图像分割系列的最后一个模型,DeepLabv3。 二、DeepLabv3简介 DeepLabv3 是 DeepLab 系列中的第三个版本,以其高准确性和能够在多个尺度上识别物体轮廓而著称。 以下是 De…

在建筑设计方面3DMax和Maya哪一个更好?

对于建筑设计可视化工具,有许多渲染程序提供了一套非常复杂的工具,使专业人员能够做到最好。ArchViz工具对建筑师和设计师来说非常重要,因为它们可以让他们更准确地识别设计发展阶段的问题,并更有效地与客户交流设计理念。 在今天…

自动化测试开发 —— 如何封装自动化测试框架?

封装自动化测试框架,测试人员不用关注框架的底层实现,根据指定的规则进行测试用例的创建、执行即可,这样就降低了自动化测试门槛,能解放出更多的人力去做更深入的测试工作。本篇文章就来介绍下,如何封装自动化测试框架…

图片OCR转表格:终极攻略,优缺点全解析

随着人工智能技术的不断发展,图片OCR识别已经成为了许多领域中不可或缺的一部分。将图片中的文字识别并转化为文本,可以大大提高数据处理的效率和准确性。而将图片中的表格识别并转化为可编辑的表格格式,则可以更加方便地对数据进行处理和分析…

Iris for Mac:简单易用的录屏神器,让你的演示更出色

你是否曾需要在Mac上进行屏幕录制,但却苦于找不到一款简单易用的工具?那么,今天我们将为你介绍一款全新的屏幕录制软件——Iris for Mac。这款软件不仅功能强大,而且操作简单,可以轻松满足你的各种录屏需求。 一、简单…

auth 对象创建 session 时出现 异常 详细解析

在使用 requests 库的 1.0.3 版本时,尝试使用 auth 对象创建 session 时出现了 TypeError 异常。然而,在 0.14.2 版本中,这个操作是正常工作的。下面是具体的错误信息: python import requests requests.version 1.0.3 auth re…

Leetcode—53.最大子数组和【中等】

2023每日刷题&#xff08;三十四&#xff09; Leetcode—53.最大子数组和 前缀和算法思想 参考灵茶山艾府 实现代码 #define MAX(a, b) ((a > b) ? (a) : (b)) #define MIN(a, b) ((a < b) ? (a) : (b)) int maxSubArray(int* nums, int numsSize) {int ans INT_…

27岁想转行IT,还来得及吗?

“种一棵树最好的时间是十年前&#xff0c;其次是现在&#xff01; ” 任何时候&#xff0c;只要开始&#xff0c;一切都来得及&#xff0c;如果你真的想要转行&#xff0c;何时都不会晚&#xff0c;只要你付出相应的努力&#xff0c;坚持下去&#xff0c;一切都会朝着自己希望…

easyExcel注解详情

前言11个注解字段注解 类注解基础综合示例补充颜色总结 11个注解 ExcelProperty ColumnWith 列宽 ContentFontStyle 文本字体样式 ContentLoopMerge 文本合并 ContentRowHeight 文本行高度 ContentStyle 文本样式 HeadFontStyle 标题字体样式 HeadRowHeight 标题高度 HeadStyle…

深入浅出讲解python闭包

一、定义 在 Python 中&#xff0c;当一个函数内部定义的函数引用了外部函数的局部变量时&#xff0c;就形成了一个闭包。这个内部函数可以访问并修改外部函数的局部变量&#xff0c;而这些局部变量的状态会一直被保存在闭包中&#xff0c;即使外部函数已经执行完毕。 这种机…