Keras深度学习框架第二十讲:使用KerasCV中的Stable Diffusion进行高性能图像生成

1、绪论

1.1 概念

为便于后文讨论,首先进行相关概念的陈述。

  • Stable Diffusion:Stable Diffusion 是一个在图像生成领域广泛使用的技术,尤其是用于文本到图像的转换。它基于扩散模型(Diffusion
    Models),这是一种深度生成模型,通过逐步去除图像中的噪声来生成新图像。Stable Diffusion
    是一种特定的实现,以其稳定性和高质量的图像生成能力而闻名。

  • KerasCV:Keras 是一个高级神经网络API,能够运行在 TensorFlow、Theano 或 CNTK 之上。KerasCV 是 Keras 的一个扩展库,专注于计算机视觉任务。虽然目前并没有一个官方的 KerasCV 库被广泛接受,但这样的命名通常意味着一个专注于计算机视觉的 Keras 扩展。

  • 高性能图像生成:这指的是使用高效且强大的计算方法来生成高质量的图像。在深度学习和计算机视觉中,高性能通常意味着模型能够在较短时间内处理大量数据,同时保持或提高生成的图像质量。

1.2 本文探讨的范围

本文将展示如何使用KerasCV实现的稳定性.ai的文本到图像模型Stable Diffusion,根据文本提示生成新颖的图像。

Stable Diffusion是一个功能强大、开源的文本到图像生成模型。尽管存在多个开源实现,可以让您轻松地从文本提示创建图像,但KerasCV的实现提供了一些独特优势。这些包括XLA编译和混合精度支持,这两者结合在一起实现了最先进的生成速度。

本文将探索KerasCV的Stable Diffusion实现,展示如何使用这些强大的性能提升,并探索它们提供的性能优势。

注意:如果要在torch后端运行,请在所有地方将jit_compile设置为False。Stable Diffusion的XLA编译目前不适用于torch。

1.3 Stable Diffusion的优缺点

Stable Diffusion 是一种强大的开源文本到图像生成模型,它使用扩散模型(Diffusion Models)来逐步去除图像中的噪声,从而生成新的图像。以下是 Stable Diffusion 的一些主要优点和缺点:

优点:

  • 高质量输出:Stable Diffusion 能够根据文本描述生成高分辨率、逼真的图像,满足各种设计需求。
  • 可控的生成过程:通过调整扩散参数和逆向过程,可以控制生成过程的速度和效果,使用户能够根据自己的需求和偏好生成图像。
  • 广泛的适用性:该模型适用于各种领域,如艺术创作、产品设计、虚拟现实等,为这些领域提供了更多的创意和可能性。
  • 可扩展性:随着深度学习技术的不断进步,Stable Diffusion 的性能和效果有望进一步提升。
  • 高稳定性:Stable Diffusion 通过引入新的稳定性系数来控制模型的稳定性,避免了其他扩散模型中出现的不稳定性问题。
  • 较快的训练速度:通过使用更小的批次大小和更少的步骤来训练模型,Stable Diffusion 提高了训练速度。

缺点:

  • 计算资源需求高:Stable Diffusion 的训练和生成过程需要大量的计算资源和时间,尤其是在生成高分辨率图像时。
  • 超参数调优:Stable Diffusion 的性能和效果可能受到许多超参数的影响,如扩散参数、模型架构等,因此需要进行细致的调优。
  • 牺牲多样性:由于引入了稳定性系数,Stable Diffusion 可能会在一定程度上牺牲生成样本的多样性。
  • 生成样本速度可能较慢:虽然训练速度有所提升,但在生成样本时,Stable Diffusion 的速度可能会相对较慢。

总的来说,Stable Diffusion 是一种功能强大的文本到图像生成模型,具有广泛的应用前景。然而,它也面临一些挑战,如计算资源需求高和需要细致的超参数调优。随着技术的不断发展,这些问题有望得到解决。

1.4 设置

安装

!pip install -q --upgrade keras-cv
!pip install -q --upgrade keras  # Upgrade to Keras 3.

导入

import time
import keras_cv
import keras
import matplotlib.pyplot as plt

2、Stable Diffusion的使用方法

2.1 构建模型

使用类似如下的代码构建模型

model = keras_cv.models.StableDiffusion(
    img_width=512, img_height=512, jit_compile=False
)

2.2 编辑文本提示

按照以下代码的方式给模型输入文本提示

images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)

在这里插入图片描述
下面,我们看一个更复杂的示例

images = model.text_to_image(
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(images)

在这里插入图片描述

3、工作原理

Stable Diffusion 并不是基于魔法的,它是一种“潜在扩散模型”。让我们深入了解一下这意味着什么。

程序员可能对超分辨率有所了解:可以通过训练一个深度学习模型来对输入图像进行去噪,从而将其转换为更高分辨率的版本。深度学习模型并不是通过神奇地恢复从嘈杂的低分辨率输入中缺失的信息来实现这一点的——相反,该模型使用其训练数据分布来“想象”出最可能基于输入的视觉细节。要了解更多关于超分辨率的信息,你可以查看以下 Keras.io 教程:

  • 使用高效的子像素卷积神经网络的图像超分辨率
  • 用于单图像超分辨率的增强深度残差网络

当程序员将这个想法推向极致时,可能会开始思考——如果我们只在纯噪声上运行这样的模型会怎样?模型将“对噪声进行去噪”并开始生成一个全新的图像。通过多次重复这个过程,程序员可以将一小块噪声转变为越来越清晰和高分辨率的人造图片。

这2020 年在《High-Resolution Image Synthesis with Latent Diffusion Models》中提出的潜在扩散的关键思想。要深入了解扩散模型,程序员可以查看 Keras.io 上的教程《Denoising Diffusion Implicit Models》。

要将潜在扩散模型转化为文本到图像系统,程序员还需要添加一个关键功能:通过提示关键字控制生成的视觉内容。这是通过“条件化”(conditioning)实现的,这是一种经典的深度学习技术,包括将表示文本的向量与噪声块连接起来,然后在包含 {图像: 描述} 对的数据集上训练模型。

这就产生了 Stable Diffusion 架构。Stable Diffusion 由三部分组成:

  • 文本编码器,将你的提示转换为潜在向量。
  • 扩散模型,反复“去噪”一个 64x64 的潜在图像块。
  • 解码器,将最终的 64x64 潜在块转换为更高分辨率的 512x512 图像。

首先,程序员的文本提示被文本编码器投影到潜在向量空间,这只是一个预训练并固定的语言模型。然后,该提示向量与随机生成的噪声块连接起来,通过一系列的“步骤”被扩散模型反复“去噪”(你运行的步骤越多,图像就越清晰、越漂亮——默认值是 50 步)。

最后,64x64 的潜在图像通过解码器以高分辨率正确渲染。

总的来说,这是一个相当简单的系统——Keras 实现包含在四个文件中,总共不到 500 行代码:

  • text_encoder.py: 87 行代码
  • diffusion_model.py: 181 行代码
  • decoder.py: 86 行代码
  • stable_diffusion.py: 106 行代码

但是,一旦程序员在数十亿张图片及其描述上进行训练,这个相对简单的系统就开始看起来像魔法一样。就像费曼(Feynman)对宇宙的描述:“它并不复杂,只是有很多!”

3.1 KerasCV的Stable Diffusion模型的优点

在多个公开的Stable Diffusion实现中,为什么你应该选择keras_cv.models.StableDiffusion呢?

除了易于使用的API之外,KerasCV的Stable Diffusion模型还带有一些强大的优势,包括:

  • 图模式执行(Graph mode execution):这种方式可以更高效地执行计算图,因为它允许优化和并行化计算。
  • 通过jit_compile=True启用XLA编译:XLA(Accelerated Linear Algebra)是TensorFlow的一个特性,它可以将多个计算步骤融合为一个优化的操作,从而提高执行速度。
  • 支持混合精度计算:混合精度计算允许模型在较低精度的数据类型上运行,如FP16(半精度浮点数),这通常可以提高速度并减少内存消耗,同时保持模型精度。

当这些特性组合在一起时,KerasCV的Stable Diffusion模型运行速度比原始实现快得多。为了进行比较,我们运行了基准测试,比较了HuggingFace的diffusers实现的Stable Diffusion与KerasCV实现的运行时。两个实现的任务都是为每张图像生成3张图像,每张图像的步数为50。在这个基准测试中,我们使用了Tesla T4 GPU。

我们所有的基准测试都是开源的,可以在GitHub上找到,并且可以在Colab上重新运行以重现结果。基准测试的结果如下表所示:

GPU ModelImplementationRuntime (s)
Tesla T4KerasCV (Warm Start)28.97
Tesla T4diffusers (Warm Start)41.33
Tesla V100KerasCV (Warm Start)12.45
Tesla V100diffusers (Warm Start)12.72

在Tesla T4上,执行时间提高了30%!尽管在V100上的提升幅度较小,但我们通常期望在所有的NVIDIA GPU上,基准测试的结果都将一致地偏向于KerasCV。

为了完整性,我们报告了冷启动和热启动的生成时间。冷启动执行时间包括模型创建和编译的一次性成本,因此在生产环境中(在那里你会多次重用相同的模型实例)它是可以忽略的。尽管如此,以下是冷启动的数据:

GPUModelRuntime (Cold Start)
Tesla T4KerasCV83.47s
Tesla T4diffusers46.27s
Tesla V100KerasCV76.43s
Tesla V100diffusers13.90s

注意:每个优化在硬件设置之间的性能提升可能会有很大的差异。

3.2 模型的基准测试方法

为了后续的讨论,我们先来对未优化模型进行基准测试:

benchmark_result = []
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.

在这里插入图片描述

3.3混合精度

“混合精度”是指在计算时使用float16精度,同时将权重存储在float32格式中。这是为了利用现代NVIDIA GPU上float16操作比float32操作具有显著更快内核支持的事实。

在Keras中启用混合精度计算只需调用:

keras.mixed_precision.set_global_policy("mixed_float16")

之后,我们就能够非常简单的使用float16了:

model = keras_cv.models.StableDiffusion(jit_compile=False)

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)
By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
Compute dtype: float16
Variable dtype: float32

这样,上面构建的模型现在使用混合精度计算;在计算时使用float16操作的速度优势,同时以float32精度存储变量。

# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()

在这里插入图片描述

3.4XLA编译

TensorFlow和JAX内置了XLA(加速线性代数)编译器。keras_cv.models.StableDiffusion原生支持jit_compile参数。将此参数设置为True可以启用XLA编译,从而显著提高速度。

# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)
# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)
y using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE
 50/50 ━━━━━━━━━━━━━━━━━━━━ 48s 209ms/step

在这里插入图片描述
XLA对标结果

start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
50/50 ━━━━━━━━━━━━━━━━━━━━ 11s 210ms/step
With XLA: 10.63 seconds

在这里插入图片描述
在目标GPU上,XLA实现了2倍的速度提升。

4、总结

在构建高性能的稳定扩散推理流程时,关键在于利用现代计算硬件和软件优化技术来最大化模型的推理速度和效率。以下是构建此类流程时需要考虑的关键点:

  • 硬件选择:选择支持高效浮点运算和并行处理的现代GPU硬件,如NVIDIA的GPU系列。这些GPU通常具有强大的计算能力和优化的硬件加速功能,能够显著提升模型的推理速度。

  • 混合精度计算:利用混合精度计算技术,即在计算过程中使用float16精度,并将权重存储在float32格式中。这种策略能够减少内存占用和计算量,同时保持足够的数值稳定性。在NVIDIA GPU上,float16操作通常比float32操作更快,因此可以显著提高推理速度。

  • XLA编译:启用TensorFlow或JAX中的XLA编译器,将计算图编译成优化的机器代码。XLA编译器能够自动进行各种优化,如循环展开、内联函数和向量化等,从而进一步提高推理性能。

  • 模型优化:对稳定扩散模型进行必要的优化,如剪枝、量化或模型压缩等。这些技术可以减少模型的复杂性和参数量,从而加快推理速度并减少内存占用。

  • 批处理:在可能的情况下,使用批处理来一次性处理多个输入数据。这可以充分利用GPU的并行处理能力,提高推理吞吐量。

  • 数据预处理和后处理:优化数据预处理和后处理流程,确保它们与推理流程相匹配并尽可能高效。例如,使用适当的数据加载器和缓存机制来加速数据读取。

  • 软件优化:利用最新的深度学习框架和库(如TensorFlow、PyTorch或JAX),这些库通常包含了许多针对高性能计算的优化功能。同时,确保你的代码是高效且并发友好的,以便充分利用多核CPU和GPU资源。

  • 监控与调优:在推理流程运行过程中进行性能监控,并使用性能分析工具来识别瓶颈和潜在优化点。根据监控结果进行调整和优化,以实现最佳性能。

综上所述,构建高性能的稳定扩散推理流程需要综合考虑硬件选择、计算精度、编译优化、模型优化、批处理、数据预处理和后处理以及软件优化等多个方面。通过结合这些技术和策略,可以显著提高模型的推理速度和效率,从而在处理大规模图像和视频数据时获得更好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/647150.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leecode 637 二叉树的层平均值

leetcode 二叉树相关-层序遍历专题 二叉树的层序遍历一般来说,我们是利用队列来实现的,先把根节点入队,然后在出队后将其对应的子节点入队,然后往复此种操作。相比于二叉树的遍历递归,层序遍历比较简单,有…

AlexNet论文解析—ImageNet Classification with Deep Convolutional Neural Networks

AlexNet论文解析—ImageNet Classification with Deep Convolutional Neural Networks 2012 研究背景 认识数据集:ImageNet的大规模图像识别挑战赛 LSVRC-2012:ImageNet Large Scale Visual Recoanition Challenge 类别训练数据测试数据图片格式Mnist1…

word 全文中 英文字体 和 样式的字体 莫名奇妙地 被改成 “等线”

word全文中英文字体和样式的字体莫名奇妙地被改成“等线” sm word又抽风了,改完论文保存后打开突然发现全文字体都不对劲,吓得冷汗直冒:虽然我用git管理了论文版本,但是只有比较大的修改我才上传了,刚刚修了几个小时…

Excel必知必会

文章目录 基础概念数据格式选择区域内指定格式数据多行筛选数据转换数据格式固定首行和首列在滚动时一直显示指定列数据符合预期批量填充公式 函数VLOOKUP函数 基础概念 数据格式 文本,数值(默认值0),(逻辑值&#x…

OceanBase的存储架构与传统LSM-Tree架构的异同|OceanBase数据转储合并技术解读(二)

前篇博文将OceanBase的存储架构巧妙地与自然界中的“水生态”进行了类比,今日我们转变视角,聚焦在与拥有相同LSM-Tree架构的其他产品的比较,深入探讨OceanBase相较于它们所展现出的独特性能。 众所周知,OceanBase数据库的存储引擎…

HQL面试题练习 —— 合并数据

题目来源:京东 目录 1 题目2 建表语句3 题解 1 题目 已知有数据 A 如下,请分别根据 A 生成 B 和 C。 数据A ------------ | id | name | ------------ | 1 | aa | | 2 | aa | | 3 | aa | | 4 | d | | 5 | c | | 6 | aa…

mac操作系统下,docker登录nexus私库,提示不支持https协议的错误

一、背景 我们使用nexus搭建了一个Docker Registry私有仓库,在Mac操作系统,在推送本地镜像到私库前,要求我们登录私库,报错如下: docker login 192.168.5.6:8086 -u username -p passwordWARNING! Using --password …

Python轻松玩转excel操作指导

目录 一、一图概览 二、表格操作 三、内容操作 四、单元格操作 五、Pandas实现表格操作 六、常见场景示例 一、一图概览 ​ ​本文主要对openpyxl库的常用表格操作进行了梳理,熟练的运用后可极大地提升工作效率。 二、表格操作 #创建一个表格sheet.xlsx #…

[论文阅读笔记31]Mamba (Selective Structured State Space Model) 及其应用

最近想学一下Mamba模型,奈何看了很多视频还是感觉一知半解,因此做一篇笔记,顺便介绍一下Mamba结构作为CV backbone和时间序列预测领域的应用。 论文1. Mamba: Linear-Time Sequence Modeling with Selective State Spaces 0. Abstract 现有…

linux查看是否被入侵(一)

1、查看当前系统状态 [rootbastion-IDC ~]#top #一般挖矿等病毒点用CPU比较大 2、查看当前登录用户(w\who) 3、检查系统日志 检查系统错误登陆日志,统计IP重试次数 [rootbastion-IDC ~]# lastb 4、查看近期用户登录情况 [rootkvm01 ~]# last -n 5 #-n 5 表示…

【同构字符串】python

思路: 先记录同一个值出现的次数,再将字典中的值取出,比较2个列表即可 代码: class Solution:def isIsomorphic(self, s: str, t: str) -> bool:dit1dict()dit2dict()for i in range(len(s)):if s[i] not in dit1:dit1[s[i…

入门五(项目介绍及登录和发布需求)

软件缺陷判定标准 项目中缺陷的管理流程 使用Excel对于缺陷进行管理 使用工具管理缺陷 一、项目背景 传智作为一个IT教育机构,拥有自己开发且实际运营的产品; 将开发和运营的技术作为授课的内容,对于学员而言学到的都是一手的真实案例和…

vue data中的return

vue 的data return 是干啥的呢,vue中页面中绑定的变量都要放在data的return中,可以赋值,值可在script中改,修改引用就用this了 如果不使用return包裹的数据会在项目的全局中可见,会造成变量污染; 使用retu…

Mixed-precision计算原理(FP32+FP16)

原文: https://lightning.ai/pages/community/tutorial/accelerating-large-language-models-with-mixed-precision-techniques/ This approach allows for efficient training while maintaining the accuracy and stability of the neural network. In more det…

Android BACK键和HOME键应用差异详解

文章目录 1、应用层分析1.1 BACK键功能实现 1.2 HOME键功能实现 1.3 BACK键与HOME键的区别 2、系统层分析2.1 BACK键的处理2.2 HOME键的处理2.3 代码分析BACK键HOME键BACK键的系统代码分析HOME键的系统代码分析BACK键HOME键 3、优缺点分析3.1 BACK键3.2 HOME键 4、项目中的使用…

3小时-入门短视频创作:短视频创作入门必修(15节视频课)

课程目录 1、先导课.mp4 2、建立视听思维.mp4 3、口语化.mp4 4、具象化.mp4 5、建立选题思维.mp4 6、2个小白好上手的选题技巧.mp4 7、建立开场思维.mp4 8、3个口播视频方能开场套路.mp4 9、建立脚本结构思维.mp4 10、爆款口指的3大结构.mp4 11、建立标题思维.mp4 …

【顶刊新文】nature plants|植物高度作为高山碳固存和生态系统对变暖响应的指标

文章简介 论文名称:Plant height as an indicator for alpine carbon sequestration and ecosystem response to warming(植物高度作为高山碳固存和生态系统对变暖响应的指标) 第一作者及单位:Quan Quan(中国科学院地…

服务器被黑?快速检测和识别系统中的恶意进程

在管理和维护服务器时,检测和识别系统中的恶意进程是非常重要的。本文将详细介绍几种常用方法和工具,帮助您有效地检测和处理恶意进程,确保系统的安全性。 方法一:使用系统监控工具 1.1. 使用 ps 命令 ps 命令可以列出系统中所有正在运行的进程。使用以下命令查看特定用户…

推荐丨快速申请免费域名证书

背景: 域名是一个IP地址上的“面具” 。一个域名的目的是便于记忆和沟通的一组服务器的地址(网站,电子邮件,FTP等)。 通俗的说,域名就相当于一个家庭的门牌号码,别人通过这个号码可以很容易的找到你。 域名不仅便于记…

安装mamba时报错bare_metal_version

原因:缺少cuda118的环境版本,直接安装 nvidia/label/cuda-11.8.0 可解决,代码如下: conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc