【扩散模型(三)】IP-Adapter 源码详解1-输入篇

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 本文《【扩散模型(三)】IP-Adapter 源码详解1-输入篇》作为两个系列的交汇点,将通过对经典的 IP-Adapter 源码详细阅读,进一步加深对其原理的解释。

文章目录

  • 系列文章目录
  • 整体结构图+代码中的变量名
  • 一、IP-Adapter 做了什么?
  • 二、对应的代码实现
    • 1.模型输入
    • 2.Linear 和 LN(LayerNorm)
  • 总结


整体结构图+代码中的变量名

IP-Adapter 源码:https://github.com/tencent-ailab/IP-Adapter

本文就基于 SD1.5 的 IP-Adapter 训练代码 tutorial_train.py 为例,进行代码和结构图的解释。


在这里插入图片描述

一、IP-Adapter 做了什么?

如上图所示,插入了图中的最上面一条分支(图像输入条件分支):

  1. 蓝色的(无需训练的) Image Encoder
  2. 红色的(需训练的)Linear + LN(LayerNorm)
  3. 红色的(需训练的)、针对图像(Image Prompt)的 Cross Attention。

在论文中也提到,具体分别是:

  1. Image Encoder 是 pretrained CLIP image encoder
  2. 线性层和层归一化 Linear + LN(LayerNorm1):
    • 为了有效地分解全局图像嵌入,作者使用一个小的可训练投影网络(projection network)将图像嵌入投影到长度为N的特征序列中(在本研究中使用N=4),图像特征的维数与预训练的扩散模型中文本特征的维数相同。使用的投影网络由线性层和层归一化组成。
  3. Decoupled Cross-Attention 中,做法是在原来的 UNet 的 Cross-Attention 中加了一层 Cross-Attention。
    • 如原文提到 “we add a new cross-attention layer for each cross-attention layer in the original UNet model to insert image features.”

二、对应的代码实现

在这里插入图片描述

1.模型输入

先简单看下模型的训练时的输入,即 /path/IP-Adapter/tutorial_train.py 中 main() 函数内的 dataloader 部分,下面代码通过调用 MyDataset 类来实现了 train_dataloader 的构建。

    # dataloader
    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

对于实际训练使用的数据则为从 train_dataloader 中取的:

  1. batch[“images”]
    • 用来得到形状后,生成随机噪声。
    • 具体如下代码所示,通过 vae.encoder 得到 latents后
    • 通过 torch.randn_like(latents) 按照 latents 张量的形状生成一个随机的噪声张量 noise
  2. batch[“clip_images”]
    • 通过 image_encoder 得到 image_embeds 图像特征
  3. batch[“drop_image_embeds”]
    • 文中有提到会随机通过随机丢弃条件信息(如文本或图像嵌入),使得模型会学会在有条件和无条件的情况下进行预测(生成图像)
  4. batch[“text_input_ids”] 是文本输入,通过一个 text_encoder 后得到文本特征 encoder_hidden_states
  for step, batch in enumerate(train_dataloader):
      load_data_time = time.perf_counter() - begin

        with torch.no_grad():
            latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    
        with torch.no_grad():
            image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embeds
        image_embeds_ = []
        for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
            if drop_image_embed == 1:
                image_embeds_.append(torch.zeros_like(image_embed))
            else:
                image_embeds_.append(image_embed)
        image_embeds = torch.stack(image_embeds_)
    
        with torch.no_grad():
            encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] # pooled_prompt_embeds?

2.Linear 和 LN(LayerNorm)

以 SD1.5 + IP-Adapter 的训练代码为例:

下方代码为 /path/IP-Adapter/tutorial_train.py 中 main() 函数内,调用了定义好的 ImageProjModel 类

#ip-adapter
    image_proj_model = ImageProjModel(
        cross_attention_dim=unet.config.cross_attention_dim,
        clip_embeddings_dim=image_encoder.config.projection_dim,
        clip_extra_context_tokens=4,
    )

下方代码为 /path/IP-Adapter/ip_adapter/ip_adapter.py 被调用的 ImageProjModel 类,在构造函数 __init__ 中可以看到有前文提到的 Linear 和 LayerNorm。

class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens

总结

本文详解了IP-Adapter 训练源码中的输入部分,下篇则详解核心部分,针对图像输入的 Cross-Attention。


  1. Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/775803.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

生产力工具|viso常用常见科学素材包

一、科学插图素材网站 一图胜千言,想要使自己的论文或重要汇报更加引人入胜?不妨考虑利用各类示意图和科学插图来辅助研究工作。特别是对于新手或者繁忙的科研人员而言,利用免费的在线科学插图素材库,能够极大地节省时间和精力。 …

20.5.【C语言】求长度的两种方式

1.sizeof 用于测数据类型的长度的函数(详细见第3篇) 2.strlen 其计算长度时只有遇到\0才会停止,并且\0不会计算在内 如char arr[]{a,1,b}; printf("%d\n",strlen(arr)); 结果是个随机数!strlen读内存中的数据&…

3D生成模型TripoSR完美搭建流程,包含所有问题解决方案!

最近需要使用3D生成模型,无意中看到了TripoSR,觉得效果还行,于是打算在Linux系统上部署一下,结果遇到很多坑,在这里写一下详细的部署流程和部署过程中遇到的问题。 下面是TripoSR的源码地址。 GitHub - VAST-AI-Research/TripoSRContribute to VAST-AI-Research/TripoSR…

已经安装deveco-studio-4.1.3.500的基础上安装deveco-studio-3.1.0.501

目录标题 1、执行exe文件后安装即可2、双击devecostudio64_3.1.0.501.exe2.1、安装Note (注意和4.1的Note放不同目录)2.2、安装ohpm (注意和4.1版本的ohpm放不同目录)2.3、安装SDK (注意和4.1版本的SDK放不同目录) 1、执行exe文件后安装即可 2、双击devecostudio64_3.1.0.501.e…

linux主机(A)通过私钥登录linux主机(B)

1.登录B主机,先在B主机执行 ssh-keygen 2.设置id_rsa的权限 chmod 600 id_rsa 3.将生成的id_rsa.pub导入到authorized_keys ssh-copy-id -i ./id_rsa.pub root127.0.0.1 4.将id_rsa复制到A主机 scp id_rsa_123 root1.1.1.A:/home/ 5.登录到A主机使用私钥登录 因…

C++左值/右值/左值引用/右值引用

1)C入门级小知识,分享给将要学习或者正在学习C开发的同学。 2)内容属于原创,若转载,请说明出处。 3)提供相关问题有偿答疑和支持。 左值和右值的概念: 早期的c语言中关于左值和右值的定义&a…

使用中国大陆镜像源安装最新版的 docker Deamon

在一个智算项目交付过程中,出现了新建集群中的全部 docker server V19 进程消失、仅剩 docker server 的 unix-socket 存活的现象。 为了验证是否是BD产品研发提供的产品deploy语句缺陷,需要在本地环境上部署一个简单的 docker Deamon 环境。尴尬的是&a…

强化学习后的数学原理:随机近似与梯度下降

概述 这节课的作用: 本节课大纲如下: Motivating examples 先回顾一下 mean estimation : 为什么总数反复提到这个 mean estimation,就是因为 RL 当中有非常多的 expectation,后面就会知道除了 state value 这些定义…

传统视觉Transformer的替代者:交叉注意力Transformer(CAT)

传统视觉Transformer的替代者:交叉注意力Transformer(CAT) 在深度学习的世界里,Transformer架构以其在自然语言处理(NLP)领域的卓越表现而闻名。然而,当它进入计算机视觉(CV&#x…

Hilbert编码 思路和scala 代码

需求: 使用Hilbert 曲线对遥感影像瓦片数据进行编码,获取某个区域的编码值即可 Hilbert 曲线编码方式 思路 大致可以对四个方向的数据进行归类 左下左上右上右下 这个也对应着编码的顺序 思考在不同Hilbert深度(阶)情况下的…

AutoX.js某音自动评论(一个函数,5秒完成)

背景 某音自动化评论,步骤简单,对版本兼容性要高(不用节点id定位) 思路 通过Intent直接跳转到视频*利用文字,描述(正则)等匹配输入框,发表评论 效果 某音自动化评论 已经封装成一…

专题二:Spring源码编译

目录 下载源码 配置Gradle 配置环境变量 配置setting文件 配置Spring源码 配置文件调整 问题解决 完整配置 gradel.properties build.gradle settiings.gradel 在专题一: Spring生态初探中我们从整体模块对Spring有个整体的印象,现在正式从最…

C盘扩容/扩大C盘的12个有效操作方法

对于许多计算机用户来说,C盘空间可能会成为一个问题,尤其是那些将计算机广泛用于工作、游戏和多媒体目的的用户。如果您发现C驱动器上的空间不足,则需要对其进行扩展以提高系统的整体性能。在这篇文章中,我们将探讨C盘扩展的12种操…

【计算智能】遗传算法(二):基本遗传算法在优化问题中的应用【实验】

前言 本系列文章架构概览: 本文将介绍基本遗传算法在解决优化问题中的应用,通过实验展示其基本原理和实现过程:选取一个简单的二次函数作为优化目标,并利用基本遗传算法寻找其在指定范围内的最大值。 2. 基本遗传算法(SGA&#x…

Laravel 谨慎使用Storage::append()

在 driver 为 local 时,Storage::append()在高并发下,会存在丢失数据问题,文件被覆写,而非尾部添加,如果明确是本地文件操作,像日志写入,建议使用 Illuminate\Filesystem\Filesystem或者php原生…

跨境干货|最新注册Google账号方法分享

谷歌账号对做跨境外贸业务的人来说是刚需,目前来说大部分的海外社媒平台、工具都可以用谷歌账号来注册。但是仍然有很多朋友并不知道如何注册这个谷歌账号,今天就来给大家分享2个注册谷歌账号的方法,一个是手机号注册,一个是如何跳…

品牌营销新趋势:独立站如何与TikTok达人互动共赢

在当今数字化营销时代,独立站与TikTok达人的互动营销已成为一种高效的推广方式。通过精心设计的互动环节,独立站不仅能够增强用户参与感,还能显著提升带货效果。本文Nox聚星将和大家探讨独立站如何与TikTok达人进行高效互动,吸引用…

如何有效管理你的Facebook时间线?

Facebook作为全球最大的社交平台之一,每天都有大量的信息和内容在用户的时间线上展示。有效管理你的Facebook时间线,不仅可以提升用户体验,还能够帮助你更好地控制信息流和社交互动。本文将探讨多种方法和技巧,帮助你有效管理个人…

开发者评测|操作系统智能助手OS Copilot

操作系统智能助手OS Copilot 文章目录 操作系统智能助手OS CopilotOS Copilot 是什么优势功能 操作步骤创建实验重置密码创建Access Key配置安全组安装 os-copilot环境变量配置功能评测命令行模式多轮交互模式 OS Copilot 产品体验评测反馈OS Copilot 产品功能评测反馈 参考文档…

C++基础(七):类和对象(中-2)

上一篇博客学的默认成员函数是类和对象的最重要的内容,相信大家已经掌握了吧,这一篇博客接着继续剩下的内容,加油! 目录 一、const成员(理解) 1.0 引入 1.1 概念 1.2 总结 1.2.1 对象调用成员函数 …