Glow: Generative Flow with Invertible 1×1 Convolutions论文解析及实现(二)

Glow: Generative Flow with Invertible 1×1 Convolutions
代码github: https://github.com/rosinality/glow-pytorch添加链接描述

1 模型架构如下

1.1 左边图flow模型

Flow model
① ActNorm
② InvConv2dLU
③ AffineCoupling

1.2 右边模型结构Glow模型

Glow Model
Block (L-1)个
Flow K

在这里插入图片描述
在这里插入图片描述

2 Flow层

在这里插入图片描述

2.1 ActNorm

在这里插入图片描述

1)ActNorm 就是一个线性函数:

		# y=a*x +b 
        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))   # b
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))  #  a 

2) 参数初始化:
取第一批数据,计算均值方差,初始化 a 和b

 flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = ( flatten.mean(1) .unsqueeze(1) .unsqueeze(2).unsqueeze(3)
                .permute(1, 0, 2, 3)  )
            std = ( flatten.std(1)  .unsqueeze(1) .unsqueeze(2) .unsqueeze(3)
                .permute(1, 0, 2, 3))
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

3)log deteminant

log_abs = logabs(self.scale)
  logdet = height * width * torch.sum(log_abs)
  1. 反函数
 output / self.scale - self.loc

2.2 Invertible 1*1 convolution

在这里插入图片描述

2.2.1使用LU 分解计算行列式

1)权重计算

        weight = (
            self.w_p
            @ (self.w_l * self.l_mask + self.l_eye)
            @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
        )

在这里插入图片描述

2)前向计算

  out = F.conv2d(input, weight)

3) 行列式计算

logdet = height * width * torch.sum(self.w_s)

在这里插入图片描述
4)反函数计算

 F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))

2.2.2 行列式直接计算方法

在这里插入图片描述

2.3 AffineCoupling

1)网络结构

        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),
        )

2)前向公式

in_a, in_b = input.chunk(2, 1)
log_s, t = self.net(in_a).chunk(2, 1)
# s = torch.exp(log_s)
s = F.sigmoid(log_s + 2)
# out_a = s * in_a + t
out_b = (in_b + t) * s

torch.cat([in_a, out_b], 1)

3)log 行列式

 logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)

3)反函数

out_a, out_b = output.chunk(2, 1)


log_s, t = self.net(out_a).chunk(2, 1)
# s = torch.exp(log_s)
s = F.sigmoid(log_s + 2)
# in_a = (out_a - t) / s
in_b = out_b / s - t

torch.cat([out_a, in_b], 1)

在这里插入图片描述

2.4 Flow 模型

1)前向计算及log行列式

out, logdet = self.actnorm(input)
out, det1 = self.invconv(out)
out, det2 = self.coupling(out)
logdet = logdet + det1+det2

2)反函数

        input = self.coupling.reverse(output)
        input = self.invconv.reverse(input)
        input = self.actnorm.reverse(input)

3 Glow 层

在这里插入图片描述

Flow * k
Block * (L-s)

3.1 Block

3.1.1 squeeze

        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)

3.1.2 step of flow

        for i in range(n_flow):
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))

3.1.3 split

            out, z_new = out.chunk(2, 1)
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)

3.1.4block 层

1)反函数

   input = output

        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)

            else:
                input = eps

        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)

            else:
                zero = torch.zeros_like(input)
                # zero = F.pad(zero, [1, 1, 1, 1], value=1)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z

        for flow in self.flows[::-1]:
            input = flow.reverse(input)

        b_size, n_channel, height, width = input.shape

        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(
            b_size, n_channel // 4, height * 2, width * 2
        )

        return unsqueezed

3.2 Glow层

1)前向计算

     self.blocks = nn.ModuleList()
        n_channel = in_channel
        for i in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))

2)反函数

        for i, block in enumerate(self.blocks[::-1]):
            if i == 0:
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)

            else:
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)

        return input

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/51263.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】-进程概念及进程状态(僵尸进程和孤儿进程)

💖作者:小树苗渴望变成参天大树🎈 🎉作者宣言:认真写好每一篇博客💤 🎊作者gitee:gitee✨ 💞作者专栏:C语言,数据结构初阶,Linux,C 动态规划算法🎄 如 果 你 …

更安全,更省心丨DolphinDB 数据库权限管理系统使用指南

在数据库产品使用过程中,为保证数据不被窃取、不遭破坏,我们需要通过用户权限来限制用户对数据库、数据表、视图等功能的操作范围,以保证数据库安全性。为此,DolphinDB 提供了具备以下主要功能的权限管理系统: 提供用户…

OpenMP

官方文档:OpenMP | LLNL HPC Tutorials OpenMP总览 统一内存访问:OpenMP、Pthreads 非统一内存访问:MPI OpenMP与Pthread OpenMP原理 串行区到达并行区后会派生多个线程,并行区代码执行完后进行线程合并,剩下主线程 编…

Linux - PostgreSQL 适用于9.x 以上的 tar.gz 源码安装与理解 - 报错集锦

这里写目录标题 序言主要内容bash 配置文件个人理解关于初始化 PostgreSQL 数据库的理解 启动方法检查服务器是否在PostgreSQL中运行关闭 postgresql 数据库方法参考链接 序言 PostgreSQL 9.x 以下版本笔者没用过,具体操作看参考链接,笔者就不记录重复操…

MODBUS-TCP转Ethernet IP 网关连接空压机 配置案例

本案例是工业现场应用捷米特JM-EIP-TCP的Ethernet/IP转Modbus-TCP网关连接欧姆龙PLC与空压机的配置案例。使用设备:欧姆龙PLC,捷米特JM-EIP-TCP网关, ETHERNET/IP 的电气连接 ETHERNET/IP 采用标准的 T568B 接法,支持直连和交叉接…

在centos 7系统docker上构建mysql 5.7

一、VM上已经安装centos 7.9,且已完成docker的构建 二、安装mysql5.7 安装镜像:[rootlocalhost lll]# docker pull mysql:5.7 查看镜像[rootlocalhost lll]# docker images 根据镜像id构建mysql容器,且分配端口号[rootlocalhost lll]# dock…

自定义view - 玩转字体变色

自定义View步骤: 1>:values__attrs.xml,定义自定义属性; 2>:在第三个构造方法中获取自定义属性; 3>:onMeasure【不是必须的】; 4>:onDraw:绘制代…

emacs打开git仓库下多个子工程的根目录问题解决案

emacs打开git仓库下多个子工程的根目录问题解决案 问题描述 如题所述,这个问题困扰我很久了,一直没搜到完整的解决方案。这次终于乘着空闲时间,研究了projectile.el源码找到了方案。 问题场景具体描述下: 我自己有一个私人git仓库&#x…

机器学习:GPT3

GPT3 模型过于巨大 GPT3是T5参数量的10倍! 训练GPT3的代价是$12百万美元 Zero-shot Ability GPT3的思想是不是能拿掉Fine-tune 只需要给定few-shot或者zero-shot就能干相应的任务了。 few-shot learning(no gradient descent)&#…

(学习笔记)matplotlib.pyplot模块下基本画图函数的整理

matplotlib版本:3.7.1 python版本:3.10.12 基本函数 matplotlib版本:3.7.1python版本:3.10.12 1. plt.plot()函数1.1 plt.plot(x, y)1.2 plt.plot(x, y, **kwargs) 2. plt.xlable(), plt.ylable()3. plt.title()4. plt.show()5.p…

SkyWalking链路追踪-技术文档首页

SkyWalking 文档中文版(社区提供) (skyapm.github.io)https://skyapm.github.io/document-cn-translation-of-skywalking/ SkyWalking-基本概念 SkyWalking链路追踪是一个用于分布式系统的性能监控工具,它帮助开发人员了解系统中各组件之间…

向量vector模板输出、倒置、求和

运行代码&#xff1a; //向量vector模板输出、倒置、求和 #include"std_lib_facilities.h" //定义vector<double>的输入操作符>> istream& operator>>(istream& is, vector<double>& vv) {double dd0.0;if(is >> dd)vv.p…

【如何训练一个中英翻译模型】LSTM机器翻译模型部署之ncnn(python)(五)

系列文章 【如何训练一个中英翻译模型】LSTM机器翻译seq2seq字符编码&#xff08;一&#xff09; 【如何训练一个中英翻译模型】LSTM机器翻译模型训练与保存&#xff08;二&#xff09; 【如何训练一个中英翻译模型】LSTM机器翻译模型部署&#xff08;三&#xff09; 【如何训练…

Unity光照相关知识和实践 (烘焙光照,环境光设置,全局光照)

简介 本文将会通过一个简单的场景搭建&#xff0c;介绍如何使用烘焙光照以及相关的注意事项。另外还介绍了Unity内全局光照&#xff08;GI&#xff09;的知识和GI实际在游戏内的表现效果。 Unity关于光照相关的参考文档地址&#xff1a;https://docs.unity.cn/cn/current/Man…

Linux CentOS快速安装VNC并开启服务

以下是在 CentOS 上安装并开启 VNC 服务的步骤&#xff1a; 安装 VNC 服务器软件包。运行以下命令&#xff1a; sudo yum install tigervnc-server 输出 $ sudo yum install tigervnc-server Loaded plugins: fastestmirror, langpacks Repository epel is missing name i…

计算机论文中名词翻译和解释笔记

看论文中一些英文的简写不知道中文啥意思&#xff0c;或者一个名词不知道啥意思。 于是自己做了一个个人总结。 持续更新 目录 SoftmaxDeep Learning(深度学习)循环神经网络(Recurrent Neural Network简称 RNN)损失函数/代价函数(Loss Function)基于手绘草图的三维模型检索(Ske…

【笔记】PyTorch DDP 与 Ring-AllReduce

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.cn] 文内若有错误&#xff0c;欢迎指出&#xff01; 今天我想跟大家分享的是一篇虽然有点老&#xff0c;但是很经典的文章&#xff0c;这是一个在分布式训练中会用到的一项技术&#xff0c; 实际上叫ringallreduce。 …

用html+javascript打造公文一键排版系统8:主送机关排版

公文一般在标题和正文之间还有主送机关&#xff0c;相关规定为&#xff1a; 主送机关 编排于标题下空一行位置&#xff0c;居左顶格&#xff0c;回行时仍顶格&#xff0c;最后一个机关名称后标全角冒号。如主送机关名称过多导致公文首页不能显示正文时&#xff0c;应当将主送机…

redis的并发安全问题:redis的事务VSLua脚本

redis为什么会发生并发安全问题&#xff1f; 在redis中&#xff0c;处理的数据都在内存中&#xff0c;数据操作效率极高&#xff0c;单线程的情况下&#xff0c;qps轻松破10w。反而在使用多线程时&#xff0c;为了保证线程安全&#xff0c;采用了一些同步机制&#xff0c;以及多…

20.3 HTML 表格

1. table表格 table标签是HTML中用来创建表格的元素. table标签通常包含以下子标签: - th标签: 表示表格的表头单元格(table header), 用于描述列的标题. - tr标签: 表示表格的行(table row). - td标签: 表示表格的单元格(table data), 通常位于tr标签内, 用于放置单元格中的…