Diffusion扩散模型学习3——Stable Diffusion结构解析-以图像生成图像(图生图,img2img)为例

Diffusion扩散模型学习3——Stable Diffusion结构解析-以图像生成图像(图生图,img2img)为例

  • 学习前言
  • 源码下载地址
  • 网络构建
    • 一、什么是Stable Diffusion(SD)
    • 二、Stable Diffusion的组成
    • 三、img2img生成流程
      • 1、输入图片编码
      • 2、文本编码
      • 3、采样流程
        • a、生成初始噪声
        • b、对噪声进行N次采样
        • c、单次采样解析
          • I、预测噪声
          • II、施加噪声
        • d、预测噪声过程中的网络结构解析
          • I、apply_model方法解析
          • II、UNetModel模型解析
      • 4、隐空间解码生成图片
  • 图像到图像预测过程代码

学习前言

用了很久的Stable Diffusion,但从来没有好好解析过它内部的结构,写个博客记录一下,嘿嘿。
在这里插入图片描述

源码下载地址

https://github.com/bubbliiiing/stable-diffusion

喜欢的可以点个star噢。

网络构建

一、什么是Stable Diffusion(SD)

Stable Diffusion是比较新的一个扩散模型,翻译过来是稳定扩散,虽然名字叫稳定扩散,但实际上换个seed生成的结果就完全不一样,非常不稳定哈。

Stable Diffusion最开始的应用应该是文本生成图像,即文生图,随着技术的发展Stable Diffusion不仅支持image2image图生图的生成,还支持ControlNet等各种控制方法来定制生成的图像。

Stable Diffusion基于扩散模型,所以不免包含不断去噪的过程,如果是图生图的话,还有不断加噪的过程,此时离不开DDPM那张老图,如下:
在这里插入图片描述
Stable Diffusion相比于DDPM,使用了DDIM采样器,使用了隐空间的扩散,另外使用了非常大的LAION-5B数据集进行预训练。

直接Finetune Stable Diffusion大多数同学应该是无法cover住成本的,不过Stable Diffusion有很多轻量Finetune的方案,比如Lora、Textual Inversion等,但这是后话。

本文主要是解析一下整个SD模型的结构组成,一次扩散,多次扩散的流程。

大模型、AIGC是当前行业的趋势,不会的话容易被淘汰,hh。

txt2img的原理如博文
Diffusion扩散模型学习2——Stable Diffusion结构解析-以文本生成图像(txt2img)为例
所示。

二、Stable Diffusion的组成

Stable Diffusion由四大部分组成。
1、Sampler采样器。
2、Variational Autoencoder (VAE) 变分自编码器。
3、UNet 主网络,噪声预测器。
4、CLIPEmbedder文本编码器。

每一部分都很重要,我们以图像生成图像为例进行解析。既然是图像生成图像,那么我们的输入有两个,一个是文本,另外一个是图片。

三、img2img生成流程

在这里插入图片描述
生成流程分为四个部分:
1、对图片进行VAE编码,根据denoise数值进行加噪声。
2、Prompt文本编码。
3、根据denoise数值进行若干次采样。
4、使用VAE进行解码。

相比于文生图,图生图的输入发生了变化,不再以Gaussian noise作为初始化,而是以加噪后的图像特征为初始化这样便以图像的方式为模型注入了信息。

详细来讲,如上图所示:

  • 第一步为对输入的图像利用VAE编码,获得输入图像的Latent特征;然后使用该Latent特征基于DDIM Sampler进行加噪,此时获得输入图片加噪后的特征。假设我们设置denoise数值为0.8,总步数为20步,那么第一步中,我们会对输入图片进行0.8x20次的加噪声,剩下4步不加,可理解为打乱了80%的特征,保留20%的特征。
  • 第二步是对输入的文本进行编码,获得文本特征;
  • 第三步是根据denoise数值对 第一步中获得的 加噪后的特征 进行若干次采样。还是以第一步中denoise数值为0.8为例,我们只加了0.8x20次噪声那么我们也只需要进行0.8x20次采样就可以恢复出图片了
  • 第四步是将采样后的图片利用VAE的Decoder进行恢复。
with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1、输入图片编码

在这里插入图片描述

在图生图中,我们首先要指定一张参考的图像,然后在这个参考图像上开始工作:
1、利用VAE编码器对这张参考图像进行编码,使其进入隐空间,只有进入了隐空间,网络才知道这个图像是什么
2、然后使用该Latent特征基于DDIM Sampler进行加噪,此时获得输入图片加噪后的特征。加噪的逻辑如下:

  • denoise可认为是重建的比例,1代表全部重建,0代表不重建;
  • 假设我们设置denoise数值为0.8,总步数为20步;我们会对输入图片进行0.8x20次的加噪声,剩下4步不加,可理解为80%的特征,保留20%的特征;不过就算加完20步噪声原始输入图片的信息还是有一点保留的,不是完全不保留。

此时我们便获得在隐空间加噪后的图像,后续会在这个 隐空间加噪后的图像 的基础上进行采样。

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

2、文本编码

在这里插入图片描述
文本编码的思路比较简单,直接使用CLIP的文本编码器进行编码就可以了,在代码中定义了一个FrozenCLIPEmbedder类别,使用了transformers库的CLIPTokenizer和CLIPTextModel。

在前传过程中,我们对输入进来的文本首先利用CLIPTokenizer进行编码,然后使用CLIPTextModel进行特征提取,通过FrozenCLIPEmbedder,我们可以获得一个[batch_size, 77, 768]的特征向量。

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # 定义文本的tokenizer和transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # 冻结模型参数
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 对输入的图片进行分词并编码,padding直接padding到77的长度。
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # 拿出input_ids然后传入transformer进行特征提取。
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # 取出所有的token
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

3、采样流程

在这里插入图片描述

a、生成初始噪声

在图生图中,我们的初始噪声获取于参考图片,所以参考第一步就可以获得图生图的噪声

b、对噪声进行N次采样

既然Stable Diffusion是一个不断扩散的过程,那么少不了不断的去噪声,那么怎么去噪声便是一个问题。

在上一步中,我们已经获得了一个图生图的噪声,它是一个符合正态分布的向量,我们便从它开始去噪声。

我们会对ddim_timesteps的时间步取反,因为我们现在是去噪声而非加噪声,然后对其进行一个循环,由于我们此时不再是txt2img中的采样流程,我们使用sampler的另外一个方法decode,循环的代码如下:

@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
           use_original_steps=False):

    # 使用ddim的时间步
    # 这里内容看起来很多,但是其实很少,本质上就是取了self.ddim_timesteps,然后把它reversed一下
    timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
    timesteps = timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]
    print(f"Running DDIM Sampling with {total_steps} timesteps")

    iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
    x_dec = x_latent
    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
        # 进行单次采样
        x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
    return x_dec

c、单次采样解析

I、预测噪声

在进行单词采样前,需要首先判断是否有neg prompt,如果有,我们需要同时处理neg prompt,否则仅仅需要处理pos prompt。实际使用的时候一般都有neg prompt(效果会好一些),所以默认进入对应的处理过程。

在处理neg prompt时,我们对输入进来的隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt。torch.cat默认堆叠维度为0,所以是在batch_size维度进行堆叠,二者不会互相影响。然后我们将pos prompt和neg prompt堆叠到一个batch中,也是在batch_size维度堆叠。

# 首先判断是否由neg prompt,unconditional_conditioning是由neg prompt获得的
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    # 一般都是有neg prompt的,所以进入到这里
    # 在这里我们对隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt
    # torch.cat默认堆叠维度为0,所以是在bs维度进行堆叠,二者不会互相影响
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # 然后我们将pos prompt和neg prompt堆叠到一个batch中
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

在这里插入图片描述
堆叠完后,我们将隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割。

因为我们在堆叠时,neg prompt放在了前面。因此分割好后,前半部分e_t_uncond属于利用neg prompt得到的,后半部分e_t属于利用pos prompt得到的,我们本质上应该扩大pos prompt的影响,远离neg prompt的影响。因此,我们使用e_t-e_t_uncond计算二者的距离,使用scale扩大二者的距离。在e_t_uncond基础上,得到最后的隐向量。

# 堆叠完后,隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

在这里插入图片描述
此时获得的e_t就是通过隐向量和prompt共同获得的预测噪声啦。

II、施加噪声

获得噪声就OK了吗?显然不是的,我们还要将获得的新噪声,按照一定的比例添加到原来的原始噪声上。

这个地方我们最好结合ddim中的公式来看,我们需要获得 α ˉ t \bar{\alpha}_t αˉt α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1 σ t \sigma_t σt 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1αˉt
在这里插入图片描述
在这里插入图片描述
代码中,我们其实已经预先计算好了这些参数。我们只需要直接取出即可,下方的a_t也就是公式中括号外的 α ˉ t \bar{\alpha}_t αˉt,a_prev 就是公式中的 α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1,sigma_t就是公式中的 σ t \sigma_t σt,sqrt_one_minus_at就是公式中的 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 1αˉt

# 根据采样器选择参数
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# 根据步数选择参数,
# 这里的index就是上面循环中的total_steps - i - 1
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

其实这一步我们只是把公式需要用到的系数全都拿了出来,方便后面的加减乘除。然后我们便在代码中实现上述的公式。

# current prediction for x_0
# 公式中的最左边
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# 公式的中间
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# 公式最右边
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# 输出添加完公式的结果
return x_prev, pred_x0

在这里插入图片描述

d、预测噪声过程中的网络结构解析

I、apply_model方法解析

在3.a的预测噪声过程中,我们使用了model.apply_model方法进行噪声的预测,这个方法具体做了什么被隐掉了,我们看看具体做的工作。

apply_model方法在ldm.models.diffusion.ddpm.py文件中。在apply_model中,我们将x_noisy传入self.model中预测噪声。

x_recon = self.model(x_noisy, t, **cond)

在这里插入图片描述
self.model是一个预先构建好的类,定义在ldm.models.diffusion.ddpm.py文件的1416行,内部包含Stable Diffusion的Unet网络,self.model的功能有点类似于包装器,根据模型选择的特征融合方式,进行文本与上文生成的噪声的融合。

c_concat代表使用堆叠的方式进行融合,c_crossattn代表使用attention的方式融合。

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # stable diffusion的unet网络
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

在这里插入图片描述
代码中的self.diffusion_model便是Stable Diffusion的Unet网络,网络结构位于ldm.modules.diffusionmodules.openaimodel.py文件中的UNetModel类。

II、UNetModel模型解析

UNetModel主要做的工作是结合时间步t和文本Embedding计算这一时刻的噪声。尽管UNet的思路非常简单,但是在StableDiffusion中,UNetModel由ResBlock和Transformer模块组成,整体来讲相比于普通的UNet复杂一些。

Prompt通过Frozen CLIP Text Encoder获得Text Embedding,Timesteps通过全连接(MLP)获得Timesteps Embedding;

ResBlock用于结合时间步Timesteps Embedding,Transformer模块用于结合文本Text Embedding。

我在这里放一张大图,同学们可以看到内部shape的变化。
请添加图片描述

Unet代码如下所示:

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # 用于计算当前采样时间t的embedding
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
        # 定义输入模块的第一个卷积
        # TimestepEmbedSequential也可以看作一个包装器,根据层的种类进行时间或者文本的融合。
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # 对channel_mult进行循环,channel_mult一共有四个值,代表unet四个部分通道的扩张比例
        # [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # 每个部分循环两次
            # 添加一个ResBlock和一个AttentionBlock
            for _ in range(num_res_blocks):
                # 先添加一个ResBlock
                # 用于对输入的噪声进行通道数的调整,并且融合t的特征
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch便是上述ResBlock的输出通道数
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # 使用了SpatialTransformer自注意力,加强全局特征,融合文本的特征
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # 如果不是四个部分中的最后一个部分,那么都要进行下采样。
            if level != len(channel_mult) - 1:
                out_ch = ch
                # 在此处进行下采样
                # 一般直接使用Downsample模块
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # 为下一阶段定义参数。
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # 定义中间层
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # 定义Unet上采样过程
        self.output_blocks = nn.ModuleList([])
        # 循环把channel_mult反了过来
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # 上采样时每个部分循环三次
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # 首先添加ResBlock层
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # 然后进行SpatialTransformer自注意力
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # 如果不是channel_mult循环的第一个
                # 且
                # 是num_res_blocks循环的最后一次,则进行上采样
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # 最后在输出部分进行一次卷积
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # 用于计算当前采样时间t的embedding
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # 对输入模块进行循环,进行下采样并且融合时间特征与文本特征。
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # 中间模块的特征提取
        h = self.middle_block(h, emb, context)

        # 上采样模块的特征提取
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # 输出模块
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

4、隐空间解码生成图片

在这里插入图片描述
通过上述步骤,已经可以多次采样获得结果,然后我们便可以通过隐空间解码生成图片。

隐空间解码生成图片的过程非常简单,将上文多次采样后的结果,使用decode_first_stage方法即可生成图片。

在decode_first_stage方法中,网络调用VAE对获取到的64x64x3的隐向量进行解码,获得512x512x3的图片。

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
	# 一般无需分割输入,所以直接将x_noisy传入self.model中,在下面else进行
    if hasattr(self, "split_input_params"):
    	......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)

图像到图像预测过程代码

整体预测代码如下:

import os
import random

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything

from ldm_hacked import *

# ----------------------- #
#   使用的参数
# ----------------------- #
# config的地址
config_path = "model_data/sd_v15.yaml"
# 模型的地址
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"
# fp16,可以加速与节省显存
sd_fp16     = True
vae_fp16    = True

# ----------------------- #
#   生成图片的参数
# ----------------------- #
# 生成的图像大小为input_shape,对于img2img会进行Centter Crop
input_shape = [512, 512]
# 一次生成几张图像
num_samples = 1
# 采样的步数
ddim_steps  = 20
# 采样的种子,为-1的话则随机。
seed        = 12345
# eta
eta         = 0
# denoise强度,for img2img
denoise_strength = 1.0

# ----------------------- #
#   提示词相关参数
# ----------------------- #
# 提示词
prompt      = "a cute cat, with yellow leaf, trees"
# 正面提示词
a_prompt    = "best quality, extremely detailed"
# 负面提示词
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# 正负扩大倍数
scale       = 9
# img2img使用,如果不想img2img这设置为None。
image_path  = None

# ----------------------- #
#   保存路径
# ----------------------- #
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   创建模型
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)
if sd_fp16:
    model = model.half()

if image_path is not None:
    img = Image.open(image_path)
    img = crop_and_resize(img, input_shape[0], input_shape[1])

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)
    
    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()
        if vae_fp16:
            img = img.half()
            model.first_stage_model = model.first_stage_model.half()
        else:
            model.first_stage_model = model.first_stage_model.float()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
        z_enc = z_enc.half() if sd_fp16 else z_enc.float()

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())

    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   保存图片
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/54629.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何获取微信公众号关注主页地址

1&#xff0c;首先。在公众号后台发布一篇文章&#xff0c;&#xff08;文章也可以关注公众号&#xff09; 2&#xff0c;浏览器打开文章地址 。在页面找到_biz码 3&#xff0c;https://mp.weixin.qq.com/mp/profile_ext?actionhome&__bizxxxxx&scene110#wechat_redi…

减轻 PWM 的滤波要求

经典脉宽调制器 (PWM) 发出 H 个连续逻辑高电平&#xff08;1&#xff09;&#xff0c;后跟 L 个连续逻辑低电平&#xff08;0&#xff09;的重复序列。每个高电平和低电平持续一个时钟周期 T 1/F (Hz)。结果的占空比可定义为 H/N&#xff0c;其中 N HL 时钟周期。N 通常是 2…

谷粒商城第六天-商品服务之分类管理下的获取三级分类树形列表

目录 一、总述 1.1 前端思路 1.2 后端思路 二、前端部分 2.1 在网页中建好目录及菜单 2.1.1 建好商品目录 2.1.2 建好分类管理菜单 ​编辑 2.2 编写组件 2.2.1 先完成组件文件的创建 2.2.2 编写组件 2.2.2.1 显示三级分类树形列表 三、后端部分 3.1 编写商品分类…

matlab编程实践16、17

捕食者与猎物模型 人口增长 在人口增长或衰减的最简单模型中&#xff0c;增长速度或衰减速度与人口本身的数目成正比。增加或减少人口规模会导致出生和死亡数量成比例地增加或减少。在数学上&#xff0c;可以由以下微分方程描述。 可以得出&#xff1a;&#xff0c;其中。 该简…

2023-08-01 LeetCode每日一题(英雄的力量)

2023-08-01每日一题 一、题目编号 2681. 英雄的力量二、题目链接 点击跳转到题目位置 三、题目描述 给你一个下标从 0 开始的整数数组 nums &#xff0c;它表示英雄的能力值。如果我们选出一部分英雄&#xff0c;这组英雄的 力量 定义为&#xff1a; i0 &#xff0c;i1 &…

Redis - 三大缓存问题(穿透、击穿、雪崩)

缓存穿透 概念&#xff1a; 查询一个数据库中也不存在的数据&#xff0c;数据库查询不到数据也就不会写入缓存&#xff0c;就会导致一直查询数据库 解决方法&#xff1a; 1. 缓存空数据 如果数据库也查询不到&#xff0c;就把空结果进行缓存 缺点是 - 消耗内存 2. 使用布…

ModuleNotFoundError: No module named ‘_sqlite3‘

前言 遇到报错信息如下&#xff1a; ModuleNotFoundError: No module named _sqlite3解决方式 参考解决方式&#xff1a; https://blog.csdn.net/jaket5219999/article/details/53512071 find / -name _sqlite*.socp /usr/lib64/python3.6/lib-dynload/_sqlite3.cpython-36…

Go语言性能优化建议与pprof性能调优详解——结合博客项目实战

文章目录 性能优化建议Benchmark的使用slice优化预分配内存大内存未释放 map优化字符串处理优化结构体优化atomic包小结 pprof性能调优采集性能数据服务型应用go tool pprof命令项目调优分析修改main.go安装go-wrk命令行交互界面图形化火焰图 性能优化建议 简介&#xff1a; …

从0到1开发go-tcp框架【1-搭建server、封装连接与业务绑定、实现基础Router、抽取全局配置文件】

从0到1开发go-tcp框架【1-搭建server、封装连接与业务绑定、实现基础Router】 本期主要完成对Server的搭建、封装连接与业务绑定、实现基础Router&#xff08;处理业务的部分&#xff09;、抽取框架的全局配置文件 从配置文件中读取数据&#xff08;服务器监听端口、监听IP等&a…

记一次phpmyadmin巧妙利用

声明&#xff1a;文中涉及到的技术和工具&#xff0c;仅供学习使用&#xff0c;禁止从事任何非法活动&#xff0c;如因此造成的直接或间接损失&#xff0c;均由使用者自行承担责任。 点点关注不迷路&#xff0c;每周不定时持续分享各种干货。 原文链接&#xff1a;众亦信安&a…

Spring中最简单的过滤器和监听器

1. 过滤器概念引入 Filter也称之为过滤器&#xff0c;它是Servlet技术中最实用的技术&#xff0c;Web开发人员通过Filter技术&#xff0c;对web服务器管理的所有web资源&#xff1a;例如Jsp, Servlet, 静态图片文件或静态 html 文件等进行拦截&#xff0c;从而实现一些特殊的功…

在Windows 10和11中恢复已删除的照片

可以在Windows 10或11上恢复已删除的照片吗&#xff1f; 随着技术的发展&#xff0c;越来越多的用户习惯在电子设备上存储照片。如果这些照片被删除&#xff0c;可能会给用户带来重大损失。当照片丢失时&#xff0c;您可能会想是否可以恢复已删除的照片&#xff1f; …

LabVIEW 开发在不确定路况下自动速度辅助系统

LabVIEW 开发在不确定路况下自动速度辅助系统 智能驾驶辅助系统是汽车行业最先进的升级和尖端技术&#xff0c;智能交通系统依靠智能驾驶辅助系统在公共交通部门工作。该智能驾驶辅助系统技术包括自适应巡航控制&#xff0c;防抱死制动系统&#xff0c;安全气囊展开&#xff0…

腾讯云从业者认证考试考——云服务器

文章目录 云服务器的产品概览腾讯云服务器的优势腾讯云服务器选型腾讯云服务器计费方案 云服务器的产品概览 腾讯云服务器的产品&#xff1f; CVM云服务器&#xff08;Cloud Virtual Machine&#xff0c;CVM&#xff09;提供安全可靠的弹性计算服务。 可以在云端获取和启用 CV…

根据前序和中序遍历序列构造二叉树 (递归+迭代两种方法实现)

给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 输入: preorder [3,9,20,15,7], inorder [9,3,15,20,7] 输出: [3,9,20,null,null,15,7]源代码如下…

《吐血整理》进阶系列教程-拿捏Fiddler抓包教程(13)-Fiddler请求和响应断点调试

1.简介 Fiddler有个强大的功能&#xff0c;可以修改发送到服务器的数据包&#xff0c;但是修改前需要拦截&#xff0c;即设置断点。设置断点后&#xff0c;开始拦截接下来所有网页&#xff0c;直到取消断点。这个功能可以在数据包发送之前&#xff0c;修改请求参数&#xff1b…

逻辑回归变量系数可为负数吗?应该如何解释?

之前很多学员来问逻辑回归变量系数是否都应该为正数&#xff0c;如果出现负的变量系数该怎么办&#xff1f;是否需要重新建模&#xff1f;这些学员都是在网上搜索时&#xff0c;被错误信息误导。网上信息可以随意转载&#xff0c;且无人审核对错。我见过最多情况时很多文章正确…

第4章 案例研究:JavaScript图片库

案例 html部分 <h1 id"title">图片1</h1> <ul><li><!-- onclick绑定点击事件&#xff0c;this为触发dom&#xff0c;return false阻止默认行为 --><a onclick"show_img(this); return false" title"图片1" h…

命令模式-请求发送者与接收者解耦

去小餐馆吃饭的时候&#xff0c;顾客直接跟厨师说想要吃什么菜&#xff0c;然后厨师再开始炒菜。去大点的餐馆吃饭时&#xff0c;我们是跟服务员说想吃什么菜&#xff0c;然后服务员把这信息传到厨房&#xff0c;厨师根据这些订单信息炒菜。为什么大餐馆不省去这个步骤&#xf…

装饰器模式(Decorator)

装饰器模式是一种结构型设计模式&#xff0c;用来动态地给一个对象增加一些额外的职责。就增加对象功能来说&#xff0c;装饰器模式比生成子类实现更为灵活。装饰器模式的别名为包装器(Wrapper)&#xff0c;与适配器模式的别名相同&#xff0c;但它们适用于不同的场合。 Decor…