【深度学习】StableDiffusion的组件解析,运行一些基础组件效果

文章目录

  • 前言
  • vae
  • clip
  • UNet
  • unet训练
  • 帮助、问询

前言

看了篇文:

https://zhuanlan.zhihu.com/p/617134893

运行一些组件试试效果。

vae

代码:

import torch
from diffusers import AutoencoderKL
import numpy as np
from PIL import Image

# 加载模型: autoencoder可以通过SD权重指定subfolder来单独加载
autoencoder = AutoencoderKL.from_pretrained(
    "/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="vae")
autoencoder.to("cuda", dtype=torch.float16)

# 读取图像并预处理
raw_image = Image.open("girl.png").convert("RGB").resize((512, 512))
image = np.array(raw_image).astype(np.float32) / 127.5 - 1.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)

# 压缩图像为latent并重建
with torch.inference_mode():
    # latentx形状是 (B, C, H, W) 的张量  (1,4,64,64)
    latentx = autoencoder.encode(image.to("cuda", dtype=torch.float16)).latent_dist.sample()  # 压缩
    # 保存 latent 为 PNG 图像
    latent = latentx.permute(0, 2, 3, 1)  # 将 latent 重新排列为 (B, H, W, C) 格式
    latent = latent.cpu().numpy()  # 将 latent 转换为 NumPy 数组
    latent = (latent * 127.5 + 127.5).astype('uint8')  # 将值缩放到 [0, 255] 范围内,并转换为 uint8 类型
    latent = latent.squeeze(0)  # 去掉批次维度

    latent_image = Image.fromarray(latent)  # 将 NumPy 数组转换为 PIL Image
    latent_image.save("latent.png")  # 保存为 PNG 图像
    # shape
    print(latentx.shape)
    rec_image = autoencoder.decode(latentx).sample
    rec_image = (rec_image / 2 + 0.5).clamp(0, 1)
    rec_image = rec_image.cpu().permute(0, 2, 3, 1).numpy()
    rec_image = (rec_image * 255).round().astype("uint8")
    rec_image = Image.fromarray(rec_image[0])

# save
rec_image.save("demo.png")

原图:

在这里插入图片描述

encoder之后:

在这里插入图片描述
decoder之后:

在这里插入图片描述

clip

在自然语言处理任务中,tokenizer和text_encoder是两个重要的组件,用于将文本转换为模型可以理解的数值表示形式。

  1. Tokenizer:

Tokenizer的作用是将一个文本序列(如句子或段落)分割成一系列的token(通常是单词或子词)。它将文本映射到一个词表(vocabulary),每个token对应词表中的一个索引(index)。在您的代码示例中,tokenizer将输入的文本"a photograph of an astronaut riding a horse"转换为对应的token id序列。

  1. Text Encoder:

Text Encoder(文本编码器)是一个预训练的语言模型,它可以将token id序列编码为对应的向量表示(embeddings)。这些向量表示捕获了单词及其上下文的语义信息。

在您的代码中:

  • tokenizer(prompt) 将文本"a photograph of an astronaut riding a horse"转换为对应的token id序列(如[1, 27, 38, 61, …]等)。
  • text_encoder(text_input_ids) 将这个token id序列输入到文本编码器模型中,得到对应的向量表示text_embeddings

最终得到的text_embeddings的形状是 torch.Size([1, 77, 768])。其中:

  • 1 表示批次大小(batch size),对于单个输入就是1。
  • 77 表示输入token的数量(由于padding,长度被填充到模型的最大长度)。
  • 768 是每个token的向量表示的维度。

通过这种编码方式,原始的文本被转换为了具有丰富语义信息的数值向量表示,方便被深度学习模型进一步处理。这种编码过程是自然语言处理中的常见做法。

from transformers import CLIPTextModel, CLIPTokenizer

text_encoder = CLIPTextModel.from_pretrained(
    "/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="text_encoder").to(
    "cuda")
# text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
tokenizer = CLIPTokenizer.from_pretrained(
    "/ssd/xiedong/src_data/eff_train/Stable-diffusion/majicmixRealistic_v7_diffusers", subfolder="tokenizer")
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# 对输入的text进行tokenize,得到对应的token ids
prompt = "a photograph of an astronaut riding a horse"
text_input_ids = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt"
).input_ids

print(text_input_ids.shape)  # torch.Size([1, 77]
# 将token ids送入text model得到77x768的特征
text_embeddings = text_encoder(text_input_ids.to("cuda"))[0]
print(text_embeddings.shape)  # torch.Size([1, 77, 768])

UNet

稳定扩散(Stable Diffusion,SD)是一种基于扩散模型的生成式人工智能模型,用于创建图像。它的核心是一个860万参数的UNet结构,负责将文本提示转化为图像。

UNet结构主要由编码器(Encoder)和解码器(Decoder)两部分组成,两部分通过跳跃连接(Skip Connection)相连。编码器用于捕获输入的潜在表征(Latent Representation),解码器则根据这些表征生成最终图像。

具体而言,编码器包含:

  • 3个CrossAttnDownBlock2D模块,每个模块会对输入进行2倍下采样
  • 1个DownBlock2D模块,不进行下采样

中间是一个UNetMidBlock2DCrossAttn模块,用于连接编码器和解码器。

解码器包含:

  • 1个UpBlock2D模块
  • 3个CrossAttnUpBlock2D模块

编码器和解码器的模块数量和结构是对应的,通过跳跃连接融合不同尺度的特征。

在这里插入图片描述

CrossAttnDownBlock2D 和 CrossAttnUpBlock2D 是 Stable Diffusion 中 UNet 架构的关键模块,用于将文本条件(text condition)融入到图像生成的过程中。它们利用了自注意力(Self-Attention)机制,将文本嵌入与图像特征进行交叉注意力(Cross-Attention)操作。

以 CrossAttnDownBlock2D 为例,其核心是一个 CrossAttention 模块,该模块的工作原理如下:

  1. 将文本条件(如"一只可爱的小狗")编码为文本嵌入(text embeddings),记为 E t e x t \mathbf{E}_{text} Etext

  2. 从 UNet 的中间层获取图像特征,记为 F i m a g e \mathbf{F}_{image} Fimage

  3. 在 CrossAttention 中,将 F i m a g e \mathbf{F}_{image} Fimage 作为 Query,而 E t e x t \mathbf{E}_{text} Etext 作为 Key 和 Value,计算注意力权重:

    Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

    其中 Q = F i m a g e Q=\mathbf{F}_{image} Q=Fimage, K = E t e x t K=\mathbf{E}_{text} K=Etext, V = E t e x t V=\mathbf{E}_{text} V=Etext, d k d_k dk 是缩放因子。

  4. 将注意力权重与 F i m a g e \mathbf{F}_{image} Fimage 相加,得到融合了文本条件的新特征表示。

这个过程可以用以下伪代码表示:

class CrossAttention(nn.Module):
    def forward(self, x, context):
        """
        x: 图像特征 [batch, channels, height, width]
        context: 文本嵌入 [batch, text_len, text_dim]
        """
        q = x.view(x.size(0), -1, x.size(-1))  # [batch, channels*height*width, dim]
        k = context.permute(0, 2, 1)  # [batch, text_dim, text_len]
        v = context.permute(0, 2, 1)  # [batch, text_dim, text_len]
        
        attn = torch.bmm(q, k)  # [batch, channels*height*width, text_len]
        attn = attn / sqrt(k.size(-1))
        attn = softmax(attn, dim=-1)
        
        x = torch.bmm(attn, v)  # [batch, channels*height*width, text_dim]
        x = x.view(x.size(0), -1, x.size(-1))  # [batch, channels, height, width]
        
        return x

CrossAttnUpBlock2D 的工作方式与 CrossAttnDownBlock2D 类似,只是它位于 UNet 的解码器部分。通过这种交叉注意力机制,Stable Diffusion 能够将文本条件与图像特征进行融合,从而根据提示生成所需的图像。

CrossAttention 模块是 Stable Diffusion 中实现文本到图像生成的关键部分,它利用注意力机制将文本信息注入到图像特征中,使生成的图像能够匹配给定的文本描述。

unet训练

SD训练过程原理:

SD模型的训练过程可以概括为以下几个步骤:

  1. 将图像编码到潜在空间,获得潜在向量表示。
  2. 使用CLIP文本编码器获取文本的嵌入向量表示。
  3. 在潜在空间中采样噪声,并将其添加到潜在向量中,进行扩散过程。
  4. U-Net模型接收扩散后的噪声潜在向量和文本嵌入,预测原始的噪声向量。
  5. 计算预测噪声和实际噪声之间的均方误差作为损失函数。
  6. 通过反向传播优化U-Net模型参数,使其能够更好地预测噪声。

整个过程可以用以下公式表示:

L = E x 0 , ϵ , t [ ∥ ϵ − ϵ θ ( x t , t , y ) ∥ 2 ] \mathcal{L} = \mathbb{E}_{x_0, \epsilon, t} \left[\left\| \epsilon - \epsilon_\theta(x_t, t, y) \right\|^2\right] L=Ex0,ϵ,t[ϵϵθ(xt,t,y)2]

其中:

  • x 0 x_0 x0是原始的潜在向量
  • ϵ \epsilon ϵ是添加到潜在向量上的噪声
  • t t t是随机采样的时间步长
  • x t x_t xt是扩散后的噪声潜在向量
  • y y y是文本嵌入向量
  • ϵ θ \epsilon_\theta ϵθ是U-Net模型预测的噪声

通过最小化这个损失函数,U-Net模型可以学习到从噪声潜在向量和文本嵌入中预测原始噪声的能力。在生成图像时,可以通过从随机噪声开始,逐步去噪,并根据文本嵌入对每一步的去噪过程进行条件控制,最终得到符合文本描述的图像。

Classifier-Free Guidance (CFG)是一种在训练和生成过程中提高文本条件控制的技术。它的核心思想是在训练时,以一定概率随机丢弃文本嵌入,这样模型就可以同时学习条件预测和无条件预测。在生成时,将条件预测和无条件预测的结果进行加权融合,从而增强文本条件的影响。CFG可以用以下公式表示:

ϵ θ ( C F G ) = ϵ θ ( x t , t , y ) + s ⋅ ( ϵ θ ( x t , t , ∅ ) − ϵ θ ( x t , t , y ) ) \epsilon_\theta^{(CFG)} = \epsilon_\theta(x_t, t, y) + s \cdot (\epsilon_\theta(x_t, t, \emptyset) - \epsilon_\theta(x_t, t, y)) ϵθ(CFG)=ϵθ(xt,t,y)+s(ϵθ(xt,t,)ϵθ(xt,t,y))

其中:

  • ϵ θ ( x t , t , y ) \epsilon_\theta(x_t, t, y) ϵθ(xt,t,y)是条件预测的噪声
  • ϵ θ ( x t , t , ∅ ) \epsilon_\theta(x_t, t, \emptyset) ϵθ(xt,t,)是无条件预测的噪声
  • s s s是一个scale参数,用于控制条件和无条件预测的权重

通过CFG,可以有效地提高生成图像与输入文本的一致性。

帮助、问询

https://docs.qq.com/sheet/DUEdqZ2lmbmR6UVdU?tab=BB08J2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/524868.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Redis 知识储备】读写分离/主从分离架构 -- 分布系统的演进(4)

读写分离/主从分离架构 简介出现原因架构工作原理技术案例架构优缺点 简介 将数据库读写操作分散到不同的节点上, 数据库服务器搭建主从集群, 一主一从, 一主多从都可以, 数据库主机负责写操作, 从机只负责读操作 出现原因 数据库成为瓶颈, 而互联网应用一般读多写少, 数据库…

zdpdjango_argonadmin Django后台管理系统中的常见功能开发

效果预览 首先&#xff0c;看一下这个项目最开始的样子&#xff1a; 左侧优化 将左侧优化为下面的样子&#xff1a; 代码位置&#xff1a; 代码如下&#xff1a; {% load i18n static admin_argon %}<aside class"sidenav bg-white navbar navbar-vertical na…

SpringCloud Alibaba Sentinel 创建流控规则

一、前言 接下来是开展一系列的 SpringCloud 的学习之旅&#xff0c;从传统的模块之间调用&#xff0c;一步步的升级为 SpringCloud 模块之间的调用&#xff0c;此篇文章为第十四篇&#xff0c;即介绍 SpringCloud Alibaba Sentinel 创建流控规则。 二、基本介绍 我们在 senti…

Golang | Leetcode Golang题解之第16题最接近的三数之和

题目&#xff1a; 题解&#xff1a; func threeSumClosest(nums []int, target int) int {sort.Ints(nums)var (n len(nums)best math.MaxInt32)// 根据差值的绝对值来更新答案update : func(cur int) {if abs(cur - target) < abs(best - target) {best cur}}// 枚举 a…

2024/4/1—力扣—最小高度树

代码实现&#xff1a; /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ struct TreeNode* buildTree(int *nums, int l, int r) {if (l > r) {return NULL; // 递归出口}struct…

加州大学欧文分校英语基础语法专项课程01:Word Forms and Simple Present Tense 学习笔记

Word Forms and Simple Present Tense Course Certificate 本文是学习Coursera上 Word Forms and Simple Present Tense 这门课程的学习笔记。 文章目录 Word Forms and Simple Present TenseWeek 01: Introduction & BE VerbLearning Objectives Word FormsWord Forms (P…

云原生安全当前的挑战与解决办法

云原生安全作为一种新兴的安全理念&#xff0c;不仅解决云计算普及带来的安全问题&#xff0c;更强调以原生的思维构建云上安全建设、部署与应用&#xff0c;推动安全与云计算深度融合。所以现在云原生安全在云安全领域越来受到重视&#xff0c;云安全厂商在这块的投入也是越来…

工业网络自动化控制赛项分析

时间过去很久了,我突然想起来这篇文章还没写… 设备 它实际上是一个药盒装盖然后再进行一个归类码垛 左侧是供料,主要将盒子推出然后传送带送至中间工作站 中间工作站进行对料盒进行钢珠装填 再通过图像处理,判断大小,然后将数据传送到云服务器,最后通过伺服电机进行分类 …

飞书文档如何在不同账号间迁移

今天由于个人需要新建了一个飞书账号&#xff0c;遇到个需求就是需要把老帐号里面的文档迁移到新的账号里面。在网上搜了一通&#xff0c;发现关于此的内容似乎不多&#xff0c;只好自己动手解决&#xff0c;记录一下过程以便分享&#xff0c;主要有以下几个步骤。 1. 添加新账…

蓝桥杯 历届真题 双向排序【第十二届】【省赛】【C组】

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 改了半天只有60分&#xff0c;还是超时&#xff0c;还不知道怎么写&#xff0c;后面再看吧┭┮﹏┭┮ #include<bits/stdc.h> …

Java | Leetcode Java题解之第16题最接近的三数之和

题目&#xff1a; 题解&#xff1a; class Solution {public int threeSumClosest(int[] nums, int target) {Arrays.sort(nums);int n nums.length;int best 10000000;// 枚举 afor (int i 0; i < n; i) {// 保证和上一次枚举的元素不相等if (i > 0 && nums…

YoloV8改进策略:Neck改进改进|ELA

摘要 本文使用最新的ELA注意力机制改进YoloV8&#xff0c;实现涨点&#xff01;改进方式简单易用&#xff0c;涨点明显&#xff01;欢迎大家使用。 大家在订阅专栏后&#xff0c;记着加QQ群啊&#xff01;有些改进方法确实有难度&#xff0c;大家在改进的过程中遇到问题&#…

【QT+QGIS跨平台编译】063:【qca-softstore+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

点击查看专栏目录 文章目录 一、qca-softstore介绍二、QCA下载三、文件分析四、pro文件五、编译实践5.1 windows下编译5.2 linux下编译5.3 macos下编译一、qca-softstore介绍 QCA-Softstore 是一个软件证书存储插件,它是为 QCA 框架设计的。这个插件提供了一个简单的持久化证书…

SAP 资产管理中如何调整折旧(摊销)金额

在资产管理的日常中可能涉及资产折旧金额的调整&#xff08;或者需要调增折旧&#xff0c;或者调减折旧额&#xff09;。这是需要使用到事务代码ABAA或者ABMA。在SAP中&#xff0c;ABAA和ABMA是两个不同的事务代码&#xff0c;它们都与固定资产折旧相关&#xff0c;但用途和处理…

分表?分库?分库分表?实践详谈 ShardingSphere-JDBC

如果有不是很了解ShardingSphere的可以先看一下这个文章&#xff1a; 《ShardingSphere JDBC?Sharding JDBC&#xff1f;》基本小白脱坑问题 阿丹&#xff1a; 在很多开发场景下面&#xff0c;很多的技术难题都是出自于&#xff0c;大数据量级或者并发的场景下面的。这里就出…

【二分查找】Leetcode 搜索插入位置

题目解析 35. 搜索插入位置 这道题就是寻找target的目标位置&#xff0c;如果nums中包含target直接返回索引&#xff1b;如果不包含&#xff0c;需要返回target存放的合适位置 注意这道题有一个细节地方需要注意&#xff1a;如果现在target没有在nums中出现&#xff0c;并且目…

lua学习笔记9(字典的学习)

print("********************字典的学习***********************") a{["凌少"]"傻逼",["我"]"天才",["age"]24,["daihao"]114514,["8848"]20000} --访问单个变量 print(a["凌少"])…

华为海思2024春招数字芯片岗机试题(共9套)

huawei海思2024春招数字芯片岗机试题(共9套&#xff09;&#xff08;WX:didadidadidida313&#xff0c;加我备注&#xff1a;CSDN huawei数字题目&#xff0c;谢绝白嫖哈&#xff09; 题目包含数字集成电路、System Verilog、Verilog2001、半导体制造技术、高级ASIC芯片综合、…

finebi6.0中我的分析中...中加自己的菜单

js的两个扩展点是&#xff1a; BI.config("bi.result_wrapper", function (e) {return e.showMerge !0, e}),BI.config("bi.analysis.admin_list", function (e) {return e.showMergeUser !0, e}) 对应的组件在conf.min.js中的 bi.search_sort 点击事件…

pdf、docx、markdown、txt提取文档内容,可以应用于rag文档解析

返回的是文档解析分段内容组成的列表&#xff0c;分段内容默认chunk_size: int 250, chunk_overlap: int 50&#xff0c;250字分段&#xff0c;50分段处保留后面一段的前50字拼接即窗口包含下下一段前面50个字划分 from typing import Union, Listimport jieba import recla…