【NeurIPS 2023】PromptIR: Prompting for All-in-One Blind Image Restoration

PromptIR: Prompting for All-in-One Blind Image Restoration, NeurIPS 2023

论文:https://arxiv.org/abs/2306.13090

代码:https://github.com/va1shn9v/promptir

解读:即插即用系列 | PromptIR:MBZUAI提出一种基于Prompt的全能图像恢复网络 - 知乎 (zhihu.com)

摘要

图像恢复是从其受损版本中恢复高质量清晰图像的过程。deep-learning方法显著提升了图像恢复性能,然而,它们在不同类型和级别的退化上的泛化能力有限。这限制了它们在实际应用中的使用,因为需要针对每种具体的退化进行单独训练模型,并了解输入图像的退化类型才能应用相应的模型。本文介绍了一种基于提示的学习方法,称为PromptIR,用于全能图像恢复,可以有效地从各种类型和级别的退化中恢复图像。具体而言,本文方法使用提示来编码退化特定信息,并动态引导恢复网络。 PromptIR提供了一个通用且高效的插件模块,只需少量轻量级提示即可用于恢复各种类型和级别的受损图像,无需事先了解图像中存在的损坏信息。

动机

图像恢复过程中,会由于各种客观因素或限制(相机设备、环境条件)出现各种退化现象。deep-learning方法虽然有用,但在特定的退化类型和程度之外缺乏泛化性。因此,迫切需要开发一种能够有效恢复各种类型和程度的退化图像的一体化方法。

AirNet,采用对比学习解决一体化恢复任务,但需要训练额外的编码器,会增加训练负担。

为此,本文提出了一种基于提示学习的方法来执行一体化图像恢复。该方法利用提示(一组可调参数),用于编码关于各种图像退化类型的重要区分信息。通过将提示与主恢复网络的特征表示相互作用,动态地增强表示,以获得具有退化特定知识的适应性,这种适应性使网络能够通过动态调整其行为有效地恢复图像。

该图显示了在PromptIR和AirNet中使用的退化嵌入的tSNE图。不同的颜色表示不同的退化类型。每个任务的嵌入更好地聚集在一起,显示了提示标记学习具有区分性的退化上下文的有效性,从而有助于恢复过程。

 

所提的Prompt 方法

本文PromptIR,提出了一个即插即用的提示模块,它隐式地预测与退化条件相关的提示,以引导具有未知退化的输入图像的恢复过程。来自提示的指导被注入到网络的多个解码阶段,其中包含少量可学习参数。

 

贡献

  • 提出了一种基于提示的一体化恢复框架PromptIR,它仅依赖于输入图像来恢复清晰图像,而不需要任何关于图像中存在的退化的先验知识。
  • 本文prompt block 是一个可轻松集成到任何现有恢复网络中的插件模块。它由提示生成模块(PGM)和提示交互模块(PIM)组成。提示块的目标是生成与输入条件相关的提示(通过PGM),这些提示具有有用的上下文信息,以指导恢复网络(通过PIM)有效地消除输入图像中的破坏。
  • 本文实验证明了PromptIR的动态适应行为,在包括图像降噪、去雨和去雾在内的各种图像恢复任务上实现了最先进的性能。

方法 PromptIR

“一体式”图像恢复的目标是,学习单个模型M,以从退化的图像中恢复干净图像,该图像已使用退化方式D退化,而没有关于D的先验信息。虽然该模型对退化方式“不可见”,可以通过提供关于退化类型的隐含上下文信息来增强其在恢复干净图像方面的性能。基于提示学习的图像恢复框架PromptIR,用于在恢复干净图像的同时用退化类型的相关知识补充模型。关键元素是 提示块prompt block。

PromptIR使用提示块来生成可学习的提示参数,并在恢复过程中利用这些提示来指导模型。框架通过逐级编码器-解码器将特征逐步转换为深层特征,并在解码器中引入提示块来辅助恢复过程。提示块在解码器的每个级别中连接,隐式地为输入特征提供关于退化类型的信息,以实现引导恢复。总体来说,PromptIR框架通过逐级编码和解码以及引入提示块的方式实现图像恢复任务。 

Prompt Block

提示块 prompt block,由两个模块组成:提示生成模块(PGM)和提示交互模块(PIM)。

  • 提示生成模块使用输入特征Fl和提示组件生成与输入条件相关的提示P。
  • 提示交互模块通过Transformer块使用生成的提示动态调整输入特征。提示与解码器特征在多个级别交互,以丰富特定于退化的上下文信息。 

给N个prompt-components P_c \in R^{N*\hat{H}*\hat{W}*\hat{C}}和 input feature F_1\in R^{\hat{H}*\hat{W}*\hat{C}}, prompt block 可表示为:

 

 其中,PGM表示提示生成模块,PIM表示提示交互模块。

Prompt Generation Module (PGM)

提示组件 P_c是一组可学习的参数,与输入特征交互,嵌入了退化信息。一种直接的特征-提示交互方法是直接使用学习到的提示来校准特征,可能会产生次优结果,因为它对输入内容是无知的。本文提出了提示生成模块(PGM),它从输入特征中动态预测基于注意力的权重,并将这些权重应用于提示组件,生成与输入条件相关的提示P。此外,PGM创建了一个共享空间,促进了提示组件之间的相关知识共享。PGM公式表达为:

Prompt Interaction Module (PIM)

PIM的主要目标是实现输入特征F_1和提示P之间的交互,以实现有指导的恢复过程。

在PIM中,沿着通道维度将生成的提示与输入特征进行拼接。接下来将拼接后的表示通过一个Transformer块进行处理,该块利用提示中编码的退化信息来转换输入特征。

​​本文的主要贡献是提示块,它是一个插件模块,与具体的架构无关。PromptIR框架中,使用了现有的Transformer块。Transformer块由两个顺序连接的子模块组成:多转置卷积头转置注意力(MDTA)和门控转置卷积前馈网络(GDFN)。MDTA在通道而不是空间维度上应用自注意操作,并具有线性复杂度。GDFN的目标是以可控的方式转换特征,即抑制信息较少的特征,只允许有用的特征在网络中传播。PIM的整体过程为:

实验

主要实验与可视化样例 

全能恢复设置下的比较:

去雾、去雨、去噪实验比较: 

 

去雾、去雨、去噪可视化比较:

消融实验 

关键代码

PGM

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.py

##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    def __init__(self,prompt_dim=128,prompt_len=5,prompt_size = 96,lin_dim = 192):
        super(PromptGenBlock,self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1,prompt_len,prompt_dim,prompt_size,prompt_size))
        self.linear_layer = nn.Linear(lin_dim,prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim,prompt_dim,kernel_size=3,stride=1,padding=1,bias=False)
        

    def forward(self,x):
        B,C,H,W = x.shape
        emb = x.mean(dim=(-2,-1))
        prompt_weights = F.softmax(self.linear_layer(emb),dim=1)
        prompt = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * self.prompt_param.unsqueeze(0).repeat(B,1,1,1,1,1).squeeze(1)
        prompt = torch.sum(prompt,dim=1)
        prompt = F.interpolate(prompt,(H,W),mode="bilinear")
        prompt = self.conv3x3(prompt)

        return prompt

PromptIR

# https://github.com/va1shn9v/PromptIR/blob/main/net/model.py

class PromptIR(nn.Module):
    def __init__(self, 
        inp_channels=3, 
        out_channels=3, 
        dim = 48,
        num_blocks = [4,6,6,8], 
        num_refinement_blocks = 4,
        heads = [1,2,4,8],
        ffn_expansion_factor = 2.66,
        bias = False,
        LayerNorm_type = 'WithBias',   ## Other option 'BiasFree'
        decoder = False,
    ):

        super(PromptIR, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)
        
        
        self.decoder = decoder
        
        if self.decoder:
            self.prompt1 = PromptGenBlock(prompt_dim=64,prompt_len=5,prompt_size = 64,lin_dim = 96)
            self.prompt2 = PromptGenBlock(prompt_dim=128,prompt_len=5,prompt_size = 32,lin_dim = 192)
            self.prompt3 = PromptGenBlock(prompt_dim=320,prompt_len=5,prompt_size = 16,lin_dim = 384)
        
        
        self.chnl_reduce1 = nn.Conv2d(64,64,kernel_size=1,bias=bias)
        self.chnl_reduce2 = nn.Conv2d(128,128,kernel_size=1,bias=bias)
        self.chnl_reduce3 = nn.Conv2d(320,256,kernel_size=1,bias=bias)



        self.reduce_noise_channel_1 = nn.Conv2d(dim + 64,dim,kernel_size=1,bias=bias)
        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.down1_2 = Downsample(dim) ## From Level 1 to Level 2

        self.reduce_noise_channel_2 = nn.Conv2d(int(dim*2**1) + 128,int(dim*2**1),kernel_size=1,bias=bias)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3

        self.reduce_noise_channel_3 = nn.Conv2d(int(dim*2**2) + 256,int(dim*2**2),kernel_size=1,bias=bias)
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])])
        
        self.up4_3 = Upsample(int(dim*2**2)) ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**1)+192, int(dim*2**2), kernel_size=1, bias=bias)
        self.noise_level3 = TransformerBlock(dim=int(dim*2**2) + 512, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level3 = nn.Conv2d(int(dim*2**2)+512,int(dim*2**2),kernel_size=1,bias=bias)


        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])])


        self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.noise_level2 = TransformerBlock(dim=int(dim*2**1) + 224, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level2 = nn.Conv2d(int(dim*2**1)+224,int(dim*2**2),kernel_size=1,bias=bias)


        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])])
        
        self.up2_1 = Upsample(int(dim*2**1))  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

        self.noise_level1 = TransformerBlock(dim=int(dim*2**1)+64, num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type)
        self.reduce_noise_level1 = nn.Conv2d(int(dim*2**1)+64,int(dim*2**1),kernel_size=1,bias=bias)


        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])])
        
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)])
                    
        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img,noise_emb = None):

        inp_enc_level1 = self.patch_embed(inp_img)

        out_enc_level1 = self.encoder_level1(inp_enc_level1)
        
        inp_enc_level2 = self.down1_2(out_enc_level1)

        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)

        out_enc_level3 = self.encoder_level3(inp_enc_level3) 

        inp_enc_level4 = self.down3_4(out_enc_level3)        
        latent = self.latent(inp_enc_level4)
        if self.decoder:
            dec3_param = self.prompt3(latent)

            latent = torch.cat([latent, dec3_param], 1)
            latent = self.noise_level3(latent)
            latent = self.reduce_noise_level3(latent)
                        
        inp_dec_level3 = self.up4_3(latent)

        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)

        out_dec_level3 = self.decoder_level3(inp_dec_level3) 
        if self.decoder:
            dec2_param = self.prompt2(out_dec_level3)
            out_dec_level3 = torch.cat([out_dec_level3, dec2_param], 1)
            out_dec_level3 = self.noise_level2(out_dec_level3)
            out_dec_level3 = self.reduce_noise_level2(out_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)

        out_dec_level2 = self.decoder_level2(inp_dec_level2)
        if self.decoder:
           
            dec1_param = self.prompt1(out_dec_level2)
            out_dec_level2 = torch.cat([out_dec_level2, dec1_param], 1)
            out_dec_level2 = self.noise_level1(out_dec_level2)
            out_dec_level2 = self.reduce_noise_level1(out_dec_level2)
        
        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)


        out_dec_level1 = self.output(out_dec_level1) + inp_img


        return out_dec_level1

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/210738.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

排序算法总结(Python、Java)

Title of Content 1 冒泡排序 Bubble sort概念排序可视化代码实现Python - 基础实现Python - 优化实现Java - 优化实现C - 优化实现C - 优化实现 2 选择排序 Selection sort概念排序可视化代码实现PythonJava 3 插入排序 Insertion sort概念 1 冒泡排序 Bubble sort 概念 解释…

面试就是这么简单,offer拿到手软(二)—— 常见65道非技术面试问题

面试系列: 面试就是这么简单,offer拿到手软(一)—— 常见非技术问题回答思路 面试就是这么简单,offer拿到手软(二)—— 常见65道非技术面试问题 文章目录 一、前言二、常见65道非技术面试问题…

算法题-统计字符个数(Python题解)

文章目录 前言思路code 前言 先前笔试做了一道算法题,题目是这样子的:(PS:不用惊讶,是的,我不打算24今年考研了,一是,当初填报的学校不是我想要去的学校(当初想一战成硕…

CAPL通过ethernetPacket发送以太网报文

文章目录 ethernetPacketCANoe帮助文档车载以太网协议函数CAPL通过ethernetPacket发送以太网报文例子ethernetPacket CANoe中,ethernetPacket类似于CAN的message. CANoe帮助文档 CANoe的帮助文档是很好的学习资料,后面会结合CANoe帮助文档来介绍车载以太网的相关内容。 车…

竞赛选题 : 题目:基于深度学习的水果识别 设计 开题 技术

1 前言 Hi,大家好,这里是丹成学长,今天做一个 基于深度学习的水果识别demo 这是一个较为新颖的竞赛课题方向,学长非常推荐! 🧿 更多资料, 项目分享: https://gitee.com/dancheng-senior/pos…

3_企业级Nginx使用-day2

企业级Nginx使用-day2 学习目标和内容 1、能够编译安装并使用第三方模块 2、能够理解location语法的作用 3、能够了解URL的rewrite重写规则 4、能够理解防盗链原理和实现 一、第三方模块使用 Nginx官方没有的功能,开源开发者定制开发一些功能,把代码公…

Linux中文件的打包压缩、解压,下载到本地——zip,tar指令等

目录 1 .zip后缀名: 1.1 zip指令 1.2 unzip指令 2 .tar后缀名 3. sz 指令 4. rz 指令 5. scp指令 1 .zip后缀名: 1.1 zip指令 语法:zip [namefile.zip] [namefile]... 功能:将目录或者文件压缩成zip格式 常用选项&#xff1a…

百度智能云文字识别使用问题解决合集

1.创建试用程序时需要16位的签名MD5 解决方法:使用Java8 201版本及以下的jdk创建签名 下载地址:http://www.codebaoku.com/jdk/jdk-oracle-jdk1-8.html#jdk8u201 生成签名代码:keytool -genkeypair -v -keystore D:\key.jks -storetype PKC…

Android实验:启动式service

目录 实验目的实验内容实验要求项目结构代码实现结果展示 实验目的 充分理解Service的作用,与Activity之间的区别,掌握Service的生命周期以及对应函数,了解Service的主线程性质;掌握主线程的界面刷新的设计原则,掌握启…

Java研学-配置文件

一 配置文件 1 作用–解决硬编码的问题 在实际开发中,有时将变量的值直接定义在.java源文件中;如果维护人员想要修改数据,无法完成(因为没有修改权限),这种操作称之为硬编码 2 执行原理: 将经常需要改变的数据定义在指定类型的文件中,通过java代码对指定的类型的文件进行操作…

(C语言)找出1-99之间的全部同构数

同构数&#xff1a;它出现在平方数的右边。例&#xff1a;5是25右边的数&#xff0c;25是625右边的数&#xff0c;即5和25均是同构数。 #include<stdio.h> int main() {for(int i 1;i < 100;i ){if((i*i % 10 i) || (i*i % 100 i))printf("%d\t%d\n",i,…

神经网络(第三周)

一、简介 1.1 非线性激活函数 1.1.1 tanh激活函数 使用一个神经网络时&#xff0c;需要决定在隐藏层上使用哪种激活函数&#xff0c;哪种用在输出层节点上。到目前为止&#xff0c;只用过sigmoid激活函数&#xff0c;但是&#xff0c;有时其他的激活函数效果会更好。tanh函数…

图文深入理解TCP三次握手

前言 TCP三次握手和四次挥手是面试题的热门考点&#xff0c;它们分别对应TCP的连接和释放过程&#xff0c;今天我们先来认识一下TCP三次握手过程&#xff0c;以及是否可以使用“两报文握手”建立连接&#xff1f;。 1、TCP是什么&#xff1f; TCP是面向连接的协议&#xff0c;…

Asp.Net Core Web Api内存泄漏问题

背景 使用Asp.Net Core Web Api框架开发网站中使用到了tcp socket通信&#xff0c;网站作为服务端开始tcp server&#xff0c;其他的客户端不断高速给它传输信息时&#xff0c;tcp server中读取信息每次申请的byte[]没有得到及时的释放&#xff0c;导致内存浪费越来越多&#…

从cmd登录mysql

说明 先看看mysql.exe文件在哪个目录下&#xff0c;为了后面的操作方便&#xff0c;可以将该文件所在的路径增加到环境变量path中。 如果不增加到path环境变量中&#xff0c;那么在cmd窗口就要切换到mysql.exe文件所在的目录下执行。 在cmd窗口查看mysql命令的帮助信息 在cm…

vue中实现纯数字键盘

一、完整 代码展示 <template><div class"login"><div class"login-content"><img class"img" src"../../assets/image/loginPhone.png" /><el-card class"box-card"><div slot"hea…

【蓝桥杯软件赛 零基础备赛20周】第6周——栈

文章目录 1. 基本数据结构概述1.1 数据结构和算法的关系1.2 线性数据结构概述1.3 二叉树简介 2. 栈2.1 手写栈2.2 CSTL栈2.3 Java 栈2.4 Python栈 3 习题 1. 基本数据结构概述 很多计算机教材提到&#xff1a;程序 数据结构 算法。 “以数据结构为弓&#xff0c;以算法为箭”…

Linux shell中的函数定义、传参和调用

Linux shell中的函数定义、传参和调用&#xff1a; 函数定义语法&#xff1a; [ function ] functionName [()] { } 示例&#xff1a; #!/bin/bash# get limit if [ $# -eq 1 ] && [ $1 -gt 0 ]; thenlimit$1echo -e "\nINFO: input limit is $limit" e…

Python程序员入门指南:就业前景

文章目录 标题Python程序员入门指南&#xff1a;就业前景Python 就业数据Python的就业前景SWOT分析法Python 就业分析 标题 Python程序员入门指南&#xff1a;就业前景 Python是一种流行的编程语言&#xff0c;它具有简洁、易读和灵活的特点。Python可以用于多种领域&#xff…

Zabbix HA高可用集群搭建

Zabbix HA高可用集群搭建 Zabbix HA高可用集群搭建一、Zabbix 高可用集群&#xff08;Zabbix HA&#xff09;二、部署Zabbix高可用集群1、两个服务端配置1.1主节点 Zabbix Server 配置1.2 备节点 Zabbix Server 配置1.3 主备节点添加监控主机1.4 查看高可用集群状态 2、两个客户…