系列文章目录
- 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
- 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
- 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
- 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构、训练数据和纹理迁移进行了更详细的介绍。
- 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
- 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
- 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
- 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
- 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
文章目录
- 系列文章目录
- 前言
- 一、输入处理
- 二、过 Unet
- 三、Unet 中被替换的 CA
前言
这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb
中最基础的以图生图(Image Variations)为例:
SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:
- 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
- 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
- Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。
其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
*图中的 bs 代表 batch size
一、输入处理
IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py
文件的 IPAdapter 类的 generate() 函数中。
- 输入1: image prompt
- 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
- 以及训练好的 image_proj_model(ImageProjModel)
- 输入1对应的输出1有:
- image_prompt_embeds
- uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
-
输入2: text prompt、negative_prompt(默认的
['monochrome, lowres, bad anatomy, worst quality, low quality']
)- text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
- encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
- 检查是否超过 clip 的长度
- 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
- negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
- text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
-
输入2 对应的输出2有:
- prompt_embeds_
- negative_prompt_embeds_
-
输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上
torch.cat
后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。
with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
二、过 Unet
- 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
- 得到进一步的 prompt_embeds 和 negative_prompt_embeds
- prompt_embeds 和 negative_prompt_embeds 做
torch.cat
是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
三、Unet 中被替换的 CA
该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。