【前沿模型解析】潜在扩散模型 2-1 | 手撕感知图像压缩 基础块ResNet块

文章目录

  • 1 残差结构回顾
  • 2 LDM结构中的残差结构设计
    • 2.1 组归一化GroupNorm层
    • 2.2 激活函数层
    • 2.3 卷积层
    • 2.4 dropout层
  • 3 代码实现

1 残差结构回顾

残差结构应该是非常重要的基础块之一了,你肯定会在各种各样的网络模型结构里看到残差结构,他是非常强大的~

残差结构往往可以抽象为如下的结构

在这里插入图片描述

为什么要设置残差模块?

因为在深度学习中,网络层级很多复杂,非常重要的一部分是前后的梯度信息流动,否则很容易发生梯度爆炸或梯度消失

想象一下我们开车去一个遥远的目的地。我们可以选择直接开到目的地,也可以选择在途中设置几个“中转站”,但是中转站多了有可能会丢失最终目的地的信息,中途有好玩的就被吸引了,所以我们有时候要直接开车前往目的地,避免一些信息的丢失或遗忘。

2 LDM结构中的残差结构设计

第一阶段的残差块是主要的部分

负责不断将送进来的图像进行特征提取~,是多个ResnetBlock前后堆叠起来形成的

  • 其中中间部分的ResnetBlock是不改变特征图的通道和尺寸大小
  • 最后一个ResnetBlock将通道数加倍

在这里插入图片描述

2.1 组归一化GroupNorm层

不是用的传统的批归一化BN,而是用的组归一化GN

是因为BN在Batch_size比较小的时候,表现很差,而我们在图像生成的实际任务中,由于分辨率比较大,所以Batch_size往往比较小

这时候GN的效果会更好

GN实现方式就是按照通道去分组,每组各自归一化,比如图例中的通道数是128,分为32组,那么每组就是4个通道进行归一化

像了解更多归一化知识可以看文章

全面解读Group Normalization

2.2 激活函数层

没有采用传统的ReLU什么的,而是采用了一个组合形式

这是作者尝试不同实验效果比较好的

其实启发我们在自己的平时网络设计过程中,也可以进行损失函数的调整

2.3 卷积层

卷积层就很熟悉了

如果是中间层的ResNetBlock我们要实现两点

  1. 维持前后通道数不变
  2. 维持前后尺寸大小不变

所以用了这样的设计

  • 小卷积核3*3的大小,步长1,填充1

经过以后,尺寸大小可以维持不变

公式

特征图长或宽 − 卷积核尺寸 + 2 ∗ 填充尺寸 步长 + 1 \frac{特征图长或宽-卷积核尺寸+2*填充尺寸}{步长}+1 步长特征图长或宽卷积核尺寸+2填充尺寸+1

如果是最后的ResNetBlock

则进行通道数的改变

2.4 dropout层

防止过拟合的经典操作,让部分神经元失活(变为0)组织信息传递,避免模型能力过强

最后再经过一个残差输出即可

3 代码实现

from torch import nn
import torch

class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:#如果分辨率改变则进行更小的卷积核
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)
    def forward(self, x):
        h = x
        h = self.norm1(h) #靠后面的时候有问题 shape [80, 256, 8, 8]
        h = h*torch.sigmoid(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = h*torch.sigmoid(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x+h

其中有几个特别的参数需要说明

  • in_channels 就是输入ResNet的网络通道数,out_channels 就是输出ResNet块的网络通道数,注意后者默认是None,这样的话就不会改变通道数,适用于中间ResNet块部分,最后一层的ResNet块要给出out_channels,方便做通道更改
  • use_conv_shortcut决定着最后一层的ResNet改变通道数的时候的方式,为True的话则是size=3的卷积核,而如果是false的话则是size为1的小卷积核
  • dropout是失活比率

注:与源码相比,简化了time_emb相关的内容,因为在源码中这部分也没有排上用场

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/523723.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

二叉搜索树、AVL树、红黑树

为者常成,行者常至 文章目录 二叉搜索树节点查找插入重头戏——删除叶子节点只有一个子节点有两个子节点 分析 平衡二叉搜索树右单旋左右双旋插入的四种情况左左右右左右右左插入操作 小结 红黑树 二叉搜索树 二叉搜索树就是在二叉树的基础上增加一些规则&#xff…

【LeetCode】排序数组——不一样的方式实现快排

目录 题目链接 颜色分类 算法原理 代码实现 排序数组 算法原理 代码实现 最小的k个数 算法原理 代码实现 题目链接 LeetCode链接:75. 颜色分类 - 力扣(LeetCode) LeetCode链接:912. 排序数组 - 力扣(L…

【三十七】【算法分析与设计】STL 练习,凌波微步,栈和排序,吐泡泡,[HNOI2003]操作系统,优先队列自定义类型

凌波微步 链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 时间限制:C/C 1 秒,其他语言 2 秒 空间限制:C/C 32768K,其他语言 65536K 64bit IO Format: %lld 题目描述 小 Z 的体型实在是太胖了&…

【论文复现|智能算法改进】改进猎人猎物优化算法在WSN覆盖中的应用

目录 1.算法原理2.改进点3.结果展示4.参考文献 1.算法原理 【智能算法】猎人猎物算法(HPO)原理及实现 【智能算法应用】猎人猎物优化算法(HPO)在WSN覆盖中的应用 2.改进点 差分进化 自适应α变异 全局最优引导的动态反向学…

中仕公考:2024年成人高考大专能考事业编吗?

关于2024年成人高考大专学历是否具备报考事业单位编制的资格,相关规定明确地指出,该学历符合国家认证标准,并可在学信网进行验证。持有成人高考大专学历的考生,在满足其他职位需求的条件下,是有资格参加事业编考试的。…

VIM支持C/C++/Verilog/SystemVerilog配置并支持Win/Linux环境的配置

作为一个芯片公司打杂人口,同时兼数字IC和软件,往往需要一个皮实耐打上天入地的编辑器… 一、先附上github路径,方便取走 git clone gitgithub.com:qqqw4549/vim_config_c_verilog.git 二、效果展示 支持ctrl]函数/模块跳转,支持…

书生·浦语大模型实战营之茴香豆:搭建你的 RAG 智能助理

书生浦语大模型实战营之茴香豆:搭建你的 RAG 智能助理 RAG(Retrieval Augmented Generation)技术,通过检索与用户输入相关的信息,并结合外部知识库来生成更准确、更丰富的回答。解决 LLMs 在处理知识密集型任务时可能遇…

学习CSS Flexbox 玩flexboxfroggy flexboxfroggy1-24关详解

欢迎来到Flexbox Froggy,这是一个通过编写CSS代码来帮助Froggy和朋友的游戏! justify-content 和 align-items 是两个用于控制 CSS Flexbox 布局的属性。 justify-content:该属性用于控制 Flexbox 容器中子项目在主轴(水平方向)…

C++算法 —— 位运算

一、基本的位运算操作 1.基础位运算操作符 << : 二进制位整体左移 >> : 二进制位整体右移 ~ : 按位取反 & &#xff1a; 按位与 | &#xff1a; 按位或 ^ : 按位异或 &#xff08;无进位相加&#xff09; 2.给一个数n&#xff0c;确定它的二进制表示中第…

聚类算法 | Kmeans:肘方法、Kmeans++、轮廓系数 | DBSCAN

目录 一. 聚类算法划分方式1. 划分式2. 层次式3. 基于密度4. 基于网络5. 基于模型 二. K-means算法1. K-means算法公式表达2. K-means算法流程3. Kmeans算法缺点3.1 肘方法3.2 k-means 算法3.2.1 k-means|| 算法 3.3 Mini Batch K-Means 算法 4. 聚类评估 三. DBSCAN算法1. DBS…

秋招学习数据库LeetCode刷题

数据库基本知识以前学过次数较多&#xff0c;今天看完一遍后都是可以理解的。直接刷Leetcode题吧 牛客上题库刷基础&#xff0c;Leetcode刷 写语句题(争取坚持每日2个sql语句题) 牛客&#xff1a;https://www.nowcoder.com/exam/intelligent?questionJobId10&tagId21015 L…

C#速览入门

C# & .NET C# 程序在 .NET 上运行&#xff0c;而 .NET 是名为公共语言运行时 (CLR) 的虚执行系统和一组类库。 CLR 是 Microsoft 对公共语言基础结构 (CLI) 国际标准的实现。 CLI 是创建执行和开发环境的基础&#xff0c;语言和库可以在其中无缝地协同工作。 用 C# 编写的…

私域必备神器来袭!朋友圈转发一键搞定!

在私域运营中&#xff0c;朋友圈的转发和管理是至关重要的环节。为了能够更好的管理和运营多个微信号的朋友圈&#xff0c;很多人都开始使用微信管理系统来辅助自己开展管理和运营工作。 接下来&#xff0c;就一起来看看微信管理系统的朋友圈管理功能吧&#xff01; 首先&…

π162U31 增强ESD功能,3.0kVrms隔离电压150kbps 六通道数字隔离器 可提高工业应用系统性能

π162U31产品概述&#xff1a; π162U31是一款数字隔离器&#xff0c;π162U31数字隔离器具有出色的性能特 征和可靠性&#xff0c;整体性能优于光耦和基于其他原理的数字隔离器 产品。 智能分压技术(iDivider技术)利用电容分压原理&#xff0c; 在不需要调制和解调的情况下&a…

【测试开发学习历程】python流程控制

前言&#xff1a;写到这里也许自己真的有些疲惫了&#xff0c;但是人生不就是像西西弗斯推石上山一样的枯燥乏味吗&#xff1f; 在python当中的流程控制主要包含以下两部分的内容&#xff1a; 条件判断 循环 1 条件判断 条件判断用if语句实现&#xff0c;if语句的几种格式…

【研发管理】产品经理知识体系-数字化战略

导读: 数字化战略对于企业的长期发展具有重要意义。实施数字化战略需要企业从多个方面进行数字化转型和优化&#xff0c;以提高效率和创新能力&#xff0c;并实现长期竞争力和增长。 目录 1、定义 2、数字化战略必要性 3、数字战略框架 4、数字化转型对产品和服务设计的影响…

京东云轻量云主机8核16G配置租用价格1198元1年、4688元三年

京东云轻量云主机8核16G服务器租用优惠价格1198元1年、4688元三年&#xff0c;配置为8C16G-270G SSD系统盘-5M带宽-500G月流量&#xff0c;华北-北京地域。京东云8核16G服务器活动页面 yunfuwuqiba.com/go/jd 活动链接打开如下图&#xff1a; 京东云8核16G服务器优惠价格 京东云…

截稿倒计时 CCF-B 推荐会议EGSR’24论文摘要4月9日提交

会议之眼 快讯 第35届EGSR 2024 (Eurographics Symposium on Rendering)即欧洲图形学渲染专题讨论会将于 2024 年 7月3日-5日在英国伦敦帝国理工学院举行&#xff01;在EGSR之前&#xff0c;将有一个全新的联合研讨会&#xff0c;即材料外观建模&#xff08;MAM&#xff09;和…

非机构化解析【包含PDF、word、PPT】

此项目是针对PDF、docx、doc、PPT四种非结构化数据进行解析&#xff0c;识别里面的文本和图片。 代码结构 ├── Dockerfile ├── requirements ├── resluts ├── test_data │ ├── 20151202033304658.pdf │ ├── 2020_World_Energy_Data.pdf │ ├── …

JVM字节码与类的加载——class文件结构

文章目录 1、概述1.1、class文件的跨平台性1.2、编译器分类1.3、透过字节码指令看代码细节 2、虚拟机的基石&#xff1a;class文件2.1、字节码指令2.2、解读字节码方式 3、class文件结构3.1、魔数&#xff1a;class文件的标识3.2、class文件版本号3.3、常量池&#xff1a;存放所…