improved-diffusion-main代码理解

目录

  • 一、 TimestepEmbedSequential
  • 二、PyTorch之Checkpoint机制
  • 三、AttentionBlock
  • 四、use_scale_shift_norm

和nanoDiffusion-main相比,improved-diffusion-main代码是相似的,但有几个不是很好理解的地方记录一下。

一、 TimestepEmbedSequential

代码中class ResBlock继承自TimestepBlock,需要执行时间步嵌入操作,其他不需要。

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            else:
                x = layer(x)
        return x
class ResBlock(TimestepBlock):

二、PyTorch之Checkpoint机制

def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)

checkpoint 是在 torch.no_grad() 模式下计算的目标操作的前向函数,这并不会修改原本的叶子结点的状态,有梯度的还会保持。只是关联这些叶子结点的临时生成的中间变量会被设置为不需要梯度,因此梯度链式关系会被断开。

三、AttentionBlock


class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
    
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])

下面这个函数是准备适合的qkv矩阵

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)-》输入转换为(b,c,N)
        qkv = self.qkv(self.norm(x))-》通过卷积转换为(b,3*c,N)
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

class QKVAttention的forward就是下面的公式:
在这里插入图片描述

四、use_scale_shift_norm

    def _forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

在深度学习中,特别是在处理如扩散模型(Diffusion Models)或任何需要精细控制输出特征的神经网络时,use_scale_shift_norm引入一种灵活的变换,这种变换通过缩放(scale)和平移(shift)来调整网络层的输出。
use_scale_shift_norm是一个布尔值(True或False),用于决定是否应用这种缩放和平移的归一化方法。如果use_scale_shift_norm为True,则执行以下步骤:

分割输出层:首先,代码将self.out_layers(一个包含网络层的列表)分割为两部分。out_norm是列表中的第一个层,负责进行某种形式的归一化或变换(尽管这里的名字是out_norm,但它可能不仅仅执行归一化,而是任何形式的变换层)。out_rest是列表中剩余的所有层,这些层将在缩放和平移之后应用。
提取缩放和平移参数:接下来,从emb_out(可能是嵌入层的输出或其他某种特征表示)中提取缩放(scale)和平移(shift)参数。这里假设emb_out的维度被设计为包含这两组参数,通过th.chunk(emb_out, 2, dim=1)沿着第二维(dim=1)将其分割成两部分,分别代表缩放和平移参数。
应用缩放和平移:然后,将h(可能是之前某个层的输出)通过out_norm层进行变换,之后使用从emb_out中提取的缩放和平移参数对结果进行调整。调整的方式是将out_norm(h)的输出乘以(1 + scale)并加上shift。这个步骤实质上是在对out_norm(h)的输出进行线性变换,以引入额外的灵活性和控制。
通过剩余层:最后,将调整后的输出h通过out_rest中剩余的层进行进一步的处理。
这种技术的一个关键优势是它能够以一种灵活且数据驱动的方式调整网络层的输出,而不需要在模型架构中硬编码特定的归一化或变换策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/776154.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

中国动物志(140卷)

中国动物志&#xff0c;共140卷&#xff0c;包括昆虫纲、鸟纲、兽纲、无脊椎动物、硬骨鱼纲等多类&#xff0c;是反映我国动物分类区系研究工作成果的系列专著&#xff0c;是研究物种多样性、探讨物种演化和系统发育的重要参考&#xff0c;是动物资源开发利用、有害物种控制、濒…

charles使用教程

安装与配置 下载链接&#xff1a;https://www.charlesproxy.com/download/ 进行移动端抓包&#xff1a; 电脑端配置&#xff1a; 关闭防火墙 Proxy–>勾选 macOS Proxy Proxy–>Proxy Setting–>填入代理端口8888–>勾选Enable transparent http proxying 安装c…

【pycharm】 Virtualenv创建venv报错

一、背景 在启动django项目时&#xff0c;需要创建venv环境&#xff0c;有时候能顺利创建成功&#xff0c;当python版本换成3.8时&#xff0c;会报错 ImportError: DLL load failed while importing _ssl: 找不到指定的模块。 二、原因和解决措施 之所以执行这个报错&#…

.NET下的开源OCR项目:解锁图片文字识别的新篇章

在数字化时代&#xff0c;从图片中高效准确地提取文字信息已成为众多应用场景的迫切需求。OCR&#xff08;Optical Character Recognition&#xff0c;光学字符识别&#xff09;技术正是满足这一需求的关键技术。对于.NET开发者而言&#xff0c;幸运的是&#xff0c;存在多个开…

SpringBoot的在线教育平台-计算机毕业设计源码68562

摘要 在数字化时代&#xff0c;随着信息技术的飞速发展&#xff0c;在线教育已成为教育领域的重要趋势。为了满足广大学习者对于灵活、高效学习方式的需求&#xff0c;基于Spring Boot的在线教育平台应运而生。Spring Boot以其快速开发、简便部署以及良好的可扩展性&#xff0c…

聚鼎科技:装饰画现在做晚不晚

在每一处光影交错的角落&#xff0c;墙上那一副副静默无言的装饰画&#xff0c;似乎总在诉说着不同的故事。如今&#xff0c;投身于装饰画的创作与收藏&#xff0c;仿佛是一场关于美和时间的赛跑&#xff0c;那么问题来了——现在开始&#xff0c;晚吗? 伴随着生活品质的提升和…

高薪程序员必修课-JVM的内存区域以及对象创建过程

JVM内存区域 在Java虚拟机&#xff08;JVM&#xff09;中&#xff0c;内存区域&#xff08;Memory Areas&#xff09;是对内存空间的逻辑划分&#xff0c;用于存储不同类型的数据和执行不同的操作。理解JVM的内存区域有助于优化程序性能、调优内存使用和排查内存相关的问题。下…

14-6 小型语言模型在商业应用中的使用指南

人工智能 (AI) 在商业领域的发展使众多工具和技术成为人们关注的焦点&#xff0c;其中之一就是语言模型。这些大小和复杂程度各异的模型为增强业务运营、客户互动和内容生成开辟了新途径。本指南重点介绍小型语言模型、它们的优势、实际用例以及企业如何有效利用它们。 基础知识…

RT-Thread和freeRTOS启动流程

一. freeRTOS启动流程 二. RT-Thread启动流程 因为RT-Thread中我们定义了补丁函数也叫做钩子函数--$Sub$$main()--作为一个新功能函数&#xff0c;可以将原有函数劫持下来&#xff0c;并在之后的程序运行中加上$Super $ $前缀来重新调用原始函数。 所以启动流程是$Sub$$main(…

谷粒商城笔记-04-项目微服务架构图简介

文章目录 一&#xff0c;网络二&#xff0c;网关1&#xff0c;网关选型2&#xff0c;认证鉴权3&#xff0c;动态路由4&#xff0c;限流5&#xff0c;负载均衡6&#xff0c;熔断降级 三&#xff0c;微服务四&#xff0c;存储层五&#xff0c;服务治理六&#xff0c;日志系统七&a…

【网络安全】Host碰撞漏洞原理+工具+脚本

文章目录 漏洞原理虚拟主机配置Host头部字段Host碰撞漏洞漏洞场景工具漏洞原理 Host 碰撞漏洞,也称为主机名冲突漏洞,是一种网络攻击手段。常见危害有:绕过访问控制,通过公网访问一些未经授权的资源等。 虚拟主机配置 在Web服务器(如Nginx或Apache)上,多个网站可以共…

软件测试面试题总结(超全的)

前面看到了一些面试题&#xff0c;总感觉会用得到&#xff0c;但是看一遍又记不住&#xff0c;所以我把面试题都整合在一起&#xff0c;都是来自各路大佬的分享&#xff0c;为了方便以后自己需要的时候刷一刷&#xff0c;不用再到处找题&#xff0c;今天把自己整理的这些面试题…

力扣热100 滑动窗口

这里写目录标题 3. 无重复字符的最长子串438. 找到字符串中所有字母异位词 3. 无重复字符的最长子串 左右指针left和right里面的字符串一直是没有重复的 class Solution:def lengthOfLongestSubstring(self, s: str) -> int:# 左右指针leftright0ans0#初始化结果tablecolle…

ctfshow-web入门-文件包含(web82-web86)条件竞争实现session会话文件包含

目录 1、web82 2、web83 3、web84 4、web85 5、web86 1、web82 新增过滤点 . &#xff0c;查看提示&#xff1a;利用 session 对话进行文件包含&#xff0c;通过条件竞争实现。 条件竞争这个知识点在文件上传、不死马利用与查杀这些里面也会涉及&#xff0c;如果大家不熟悉…

JavaScript高级程序设计(第四版)--学习记录之对象、类和面向对象编程(中)

创建对象方式 工厂模式&#xff1a;用于抽象创建特定对象的过程。可以解决创建多个类似对象的问题&#xff0c;但没有解决对象标识问题。&#xff08;即新创建的对象是什么类型&#xff09; function createPerson(name, age, job) { let o new Object(); o.name name; o.age…

广和通 OpenCPU 二次开发(二) ——通过linux编译

广和通 OpenCPU 二次开发&#xff08;二&#xff09; ——通过linux编译 一、编译命令总结 1.编译环境配置 . tools/core_launch.sh cout cmake ../.. -G Ninja 2.编译 ninja 二、命令解释 1. 执行 tools/core_launch.sh 这是一个脚本文件 core_launch.sh&#xff0c;通…

技术赋能政务服务:VR导视与AI客服在政务大厅的创新应用

在数字化转型的浪潮中&#xff0c;政务大厅作为服务民众的前沿阵地&#xff0c;其服务效率和质量直接影响着政府形象和民众满意度。然而&#xff0c;许多政务大厅仍面临着缺乏智能化导航系统的挑战&#xff0c;这不仅增加了群众的办事难度&#xff0c;也降低了服务效率。维小帮…

内核错误定位

内核打印出如下&#xff1a; 在代码目录输入&#xff1a; ./prebuilts/gcc/linux-x86/aarch64/gcc-linaro-6.3.1-2017.05-x86_64_aarch64-linux-gnu/bin/aarch64-linux-gnu-gdb kernel/vmlinux 进入gdb 命令模式 输入 l *(rk628_csi_probe0xf0) 能定位到出现问题地方。 最后就…

PLM系统:PLM系统如何重塑产品生命周期管理

PLM系统&#xff1a;重塑产品生命周期管理的未来 在当今快速变化的商业环境中&#xff0c;产品生命周期管理&#xff08;PLM&#xff09;系统正逐渐成为企业提升竞争力、加速创新并优化运营流程的关键工具。随着技术的不断进步和市场需求的日益复杂化&#xff0c;传统的手动或…

试用笔记之-汇通窗口颜色显示软件(颜色值可供Delphi编程用)

首先下载汇通窗口颜色显示软件 http://www.htsoft.com.cn/download/wdspy.rar 通过获得句柄颜色&#xff0c;显示Delphi颜色值和HTML颜色值