RCG自条件是如何添加到 Pixel Generator上的?

在自条件的训练过程中,需要将图像经过Pretrained encoder的表征Rep输入进已有的Pixel Generator上,目前RCG是向四种Pixel Generator上加入了自条件,关于它是如何将rep加到Pixel Generator上的,我来总结一下:

一、Pixel Generator: MAGE

在MAGE中,是使用rep替换embedding的 fake class token做的:

  1. 得到CFG的混合表征
  2. 将混合表征替换embedding的 fake class token
  3. 输入进ViT block

        # replace fake class token with rep
        if self.use_rep:
            # cfg(class free guidance) by masking representation
            drop_rep_mask = torch.rand(bsz) < self.rep_drop_prob
            drop_rep_mask = drop_rep_mask.unsqueeze(-1).cuda().float()
            # 这里相当于cfg, O = αU + (1-a)C, 最终输出是由条件生成C(rep)和无条件生成U(fake_latent)的线性外推获得
            rep = drop_rep_mask * self.fake_latent + (1 - drop_rep_mask) * rep

            rep = self.latent_prior_proj(rep)
            # 将rep赋值给embedding的(将图像的rep替换seq的第0维度,相当于替换了seq的fake class token),其实并没有对原始的MAGE做什么改变,只是将原来可学习的fake token换为了rep,从而输入进encoder
            # input_embeddings_after_drop:(64,129,768) <-- rep:(64,768)
            input_embeddings_after_drop[:, 0] = rep
        # class-conditional MAGE
        if self.use_class_label:
            class_emb = self.class_emb(class_label)
            input_embeddings_after_drop[:, 0] = class_emb

        # apply Transformer blocks
        x = input_embeddings_after_drop
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # print("Encoder representation shape:", x.shape)

        return x, gt_indices, token_drop_mask, token_all_mask

二、Pixel Generator: DiT

从forward函数中,可以看到,

  1. 先使用CFG得到rep的混合表征rep
  2. rep加到timestep中 (16,1024) + (16,1024) =(16,1024)维度得到c。
  3. 然后将这个c作为融合条件输入进去噪block(transformer block)中。
    def forward(self, x, t, y, rep=None):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        # rep cond
        if rep is not None:
            # 1、get the CFG mixture rep
            if self.training:
                drop_rep_mask = torch.rand(x.size(0)) < self.rep_dropout_prob
                drop_rep_mask = drop_rep_mask.unsqueeze(-1).cuda().float()
                rep = drop_rep_mask * self.fake_latent + (1 - drop_rep_mask) * rep
            rep = self.rep_embedder(rep)
            # 2】直接将rep加到timestep t上从而作为下一步的输入 -->(16,1024)
            c = t + rep
        else:
            c = t + y  # (N, D)
        # 3、进一步处理
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

三、Pixel Generator: ADM

这里和DiT的处理方式是一样的,直接将rep与timestep相加,然后输入进U-Net进行去噪

U-Net的forward(): 

model_output = model(x_t, self._scale_timesteps(t), rep=rep, **model_kwargs)
    def forward(self, x, timesteps, y=None, rep=None):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        assert (rep is not None) == self.rep_cond
        # 将timestep embedding
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        # 将timestep的embedding和rep相加,然后输入进U-Net
        if self.rep_cond:
            emb = emb + self.rep_proj(rep)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
        h = h.type(x.dtype)
        return self.out(h)

四、Pixel Generator: LDM

整体来说没有什么太多的问题,就是将LDM中的condition换成了包含了/未包含condition信息的 rep

  1. 得到encoder后的图像x(4,4,32,32),表征rep c(4,1,256)
  2. 将带有condition信息的rep替换DDPM的原始condition
  3. 将encoder后的图像x(4,4,32,32),表征rep c(4,1,256), timestep t(4)输入进DDPM的后向过程求loss
    def forward(self, x, c, batch=None, gen_img=False, *args, **kwargs):
        if gen_img:
            return self.gen_imgs()
        # 1、得到encoder后的图像x(4,4,32,32),表征rep c(4,1,256)
        if batch is not None:
            x, c = self.get_input(batch, self.first_stage_key)
            if self.rep_cond:
                rep = c['rep']
            c = {'class_label': c['class_label']}
        t = torch.randint(0, self.num_timesteps, (x.shape[0],)).cuda().long()
        if self.model.conditioning_key is not None:
            assert c is not None
            # 将图像的label变为可学习的
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].cuda()
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        # 2、将带有condition信息的rep替换DDPM的原始condition
        if self.rep_cond:
            c = rep
            c = c.unsqueeze(1)
        # 3、将encoder后的图像x(4,4,32,32),表征rep c(4,1,256), timestep t(4)输入进DDPM的后向过程求loss
        loss, loss_dict = self.p_losses(x, c, t, *args, **kwargs)
        if self.use_ema and batch is not None:
            self.model_ema(self.model)
        return loss, loss_dict

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/505429.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[SpringCloud] Feign Client 的创建 (一) (四)

文章目录 1.FeignClientsRegistrar2.完成配置注册2.1 registerDefaultConfiguration方法2.2 迭代稳定性2.3 registerFeignClients方法 1.FeignClientsRegistrar FeignClientsRegistrar实现ImportBeanDefinitionRegistrar接口。 2.完成配置注册 public void registerBeanDefinit…

JQ 查看图片的好插件

效果图 插件官网 https://blog.51cto.com/transfer?https://github.com/fengyuanchen/viewer 使用 <!DOCTYPE html> <html lang"en"> <head><meta charset"utf-8"><link rel"stylesheet" href"css/viewer.c…

攻防世界——catfly

这道题我觉得很难&#xff0c;我当初刷题看见这道题&#xff0c;是唯一一道直接跳过的&#xff0c;现在掌握了一点知识才回来重新看 这道题在linux运行下是这样&#xff0c;我首先猜测是和下面这个time有关&#xff0c;判断达到一定次数就会给我flag 但是我找了好久都没找到那…

(九)信息融合方式简介

目录 前言 一、什么是信息融合&#xff1f; 二、集中式信息融合与分布式信息融合 &#xff08;一&#xff09;集中式融合 &#xff08;二&#xff09;分布式融合 1.简单信息融合 2.CI&#xff08;协方差交叉&#xff09;信息融合 3.无反馈的最优分布式融合 4.有反馈的…

反应式编程(一)什么是反应式编程

目录 一、背景二、反应式编程简介2.1 定义2.2 反应式编程的优势2.3 命令式编程 & 反应式编程 三、Reactor 入门3.1 Reactor 的核心类3.2 Reactor 中主要的方法1&#xff09;创建型方法2&#xff09;转化型方法3&#xff09;其他类型方法4&#xff09;举个例子 四、Reactor …

论文笔记:GPT-4 Is Too Smart To Be Safe: Stealthy Chat with LLMs via Cipher

ICLR 2024 reviewer评分 5688 1 论文思路 输入转换为密码&#xff0c;同时附上提示&#xff0c;将加密输入喂给LLMLLM输出加密的输出加密的输出通过解密器解密 ——>这样的步骤成功地绕过了GPT-4的安全对齐【可以回答一些反人类的问题&#xff0c;这些问题如果明文问的话&…

【C++】set和map

set和map就是我们上篇博客说的key模型和keyvalue模型。它们属于是关联式容器&#xff0c;我们之前说过普通容器和容器适配器&#xff0c;这里的关联式容器就是元素之间是有关联的&#xff0c;通过上篇博客的讲解我们也对它们直接的关系有了一定的了解&#xff0c;那么下面我们先…

蓝桥杯-python-常用库归纳

目录 日期和时间 datetime模块 date日期类&#xff0c;time时间类&#xff0c;datetime日期时间类 定义date&#xff08;年&#xff0c;月&#xff0c;日&#xff09; data之间的减法 定义时间&#xff08;时&#xff0c;分&#xff0c;秒&#xff09; 定义datetime&#xf…

42.HarmonyOS鸿蒙系统 App(ArkUI)实现横屏竖屏自适应

HarmonyOS鸿蒙系统 App(ArkUI)实现横屏竖屏自适应 媒体查询作为响应式设计的核心&#xff0c;在移动设备上应用十分广泛。媒体查询可根据不同设备类型或同设备不同状态修改应用的样式。媒体查询常用于下面两种场景&#xff1a; 针对设备和应用的属性信息&#xff08;比如显示…

【Linux】进程实践项目 —— 自主shell编写

送给大家一句话&#xff1a; 不管前方的路有多苦&#xff0c;只要走的方向正确&#xff0c;不管多么崎岖不平&#xff0c;都比站在原地更接近幸福。 —— 宫崎骏《千与千寻》 自主shell命令编写 1 前言2 项目实现2.1 创建命令行2.2 获取命令2.3 分割命令2.4 运行命令 3 源代码…

计算机服务器中了rmallox勒索病毒怎么办?rmallox勒索病毒解密数据恢复

网络技术的不断发展与应用&#xff0c;大大提高了企业的生产运营效率&#xff0c;越来越多的企业开始网络开展各项工作业务&#xff0c;网络在为人们提供便利的同时&#xff0c;也会存在潜在威胁。近日&#xff0c;云天数据恢复中心接到多家企业的求助&#xff0c;企业的计算机…

Python内置函数enumerate()

Python的内置函数enumerate()。在学习过程中遇到了一点小问题。记录一下。 enumerate() 是 Python 中常用的内置函数之一&#xff0c;它可以用来同时遍历序列的索引和对应的值。具体来说&#xff0c;enumerate() 接受一个可迭代对象作为参数&#xff0c;返回一个包含索引和值的…

vuees6新语法

vue的学习网站&#xff1a; https://www.runoob.com/vue2/vue-tutorial.html1.Vue的介绍 学习目标 说出什么是Vue能够说出Vue的好处能够说出Vue的特点 内容讲解 【1】Vue介绍 1.vue属于一个前端框架&#xff0c;底层使用原生js编写的。主要用来进行前端和后台服务器之间的…

Holiday Notice

Holiday Notice 放假通知 要是每个公司都能放假放的多&#xff0c;把加班折算放假落实到位&#xff0c;还怕我们不努力干活&#xff0c;巴不得把全年都干完了&#xff0c;然后休息。

HCIP【GRE VPN配置】

目录 实验要求&#xff1a; 实验配置思路&#xff1a; 实验配置过程&#xff1a; 一、按照图式配置所有设备的IP地址 &#xff08;1&#xff09;首先配置每个接口的IP地址 &#xff08;2&#xff09;配置静态路由使公网可通 二、在公网的基础上创建GRE VPN隧道&#xff0…

HarmonyOS实战开发-如何实现一个简单的健康生活应用(上)

介绍 本篇Codelab介绍了如何实现一个简单的健康生活应用&#xff0c;主要功能包括&#xff1a; 用户可以创建最多6个健康生活任务&#xff08;早起&#xff0c;喝水&#xff0c;吃苹果&#xff0c;每日微笑&#xff0c;刷牙&#xff0c;早睡&#xff09;&#xff0c;并设置任…

C++list的模拟实现

为了实现list&#xff0c;我们需要实现三个类 一、List的节点类 template<class T> struct ListNode {ListNode(const T& val T()):_pPre(nullptr),_pNext(nullptr),_val(val){}ListNode<T>* _pPre;ListNode<T>* _pNext;T _val; }; 二、List的迭代器…

2024年腾讯云服务器99元一年_老用户优惠续费不涨价

腾讯云99元一年服务器配置为轻量2核2G4M、50GB SSD盘、300GB月流量、4M带宽&#xff0c;新用户和老用户都可以购买&#xff0c;续费不涨价&#xff0c;续费价格也是99元一年。以往腾讯云优惠服务器都是新用户专享的&#xff0c;这款99元服务器老用户也可以购买&#xff0c;这是…

Spring Task 知识点详解、案例、源代码解析

简介&#xff1a;Spring Task 定时任务   所谓定时任务。就是依据我们设定的时间定时运行任务&#xff0c;就像定时发邮件一样&#xff0c;设定时间到了。邮件就会自己主动发送。 在Spring大行其道的今天&#xff0c;Spring也提供了其定时任务功能&#xff0c;Spring Task。同…

安装dalton过程中出现的pcre问题

在前面文章中&#xff0c;基于多种流量检测引擎识别pcap数据包中的威胁&#xff0c;并没有详细的说明dalton的安装。由于dalton提供了脚本./start-dalton.sh &#xff0c;执行之后会自动的安装各种依赖以及suricata&#xff0c;zeek&#xff0c;snort的容器环境。但是在实际执行…