【diffusers 极速入门(二)】如何得到扩散去噪的中间结果?Pipeline callbacks 管道回调函数

本文是对 Hugging Face Diffusers 文档中关于回调函数的翻译与总结,:


管道回调函数

在管道的去噪循环中,可以使用callback_on_step_end参数添加自定义回调函数。该回调函数在每一步结束时执行,并修改管道属性和变量,以供下一步使用。这在动态调整某些管道属性或修改张量变量时非常有用。利用回调函数,你可以实现新的功能而无需修改底层代码。

目前,Diffusers 仅支持callback_on_step_end,如果你有其他执行点的回调需求,可以在 github 上提出功能请求。

官方回调函数

官方提供了一些可用于修改去噪循环的回调函数列表:

  • SDCFGCutoffCallback:在一定步数后禁用 CFG。对于 SD 1.5 pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。
  • SDXLCFGCutoffCallback:在一定步数后禁用 CFG。对于 SDXL pipelines 适用, 包括 text-to-image, image-to-image, inpaint, controlnet。
  • IPAdapterScaleCutoffCallback:在一定步数后禁用 IP Adapter。对所有支持 IP-Adapter 的 pipelines 适用。

要设置回调函数,可以指定cutoff_step_ratiocutoff_step_index参数。

  • cutoff_step_ratio:带有步长比的浮点数。
  • cutoff_step_index:一个整数,包含步数的确切编号。

示例代码

import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
from diffusers.callbacks import SDXLCFGCutoffCallback

callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)
# 也可以用 cutoff_step_index
# callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)


pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, use_karras_sigmas=True)

prompt = "a sports car at the road, best quality, high quality, high detail, 8k resolution"
generator = torch.Generator(device="cpu").manual_seed(2628670641)

out = pipeline(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
)

out.images[0].save("official_callback.png")

在这里插入图片描述

动态无分类器引导

动态无分类器引导(classifier-free guidance,CFG)允许在一定步数后禁用 CFG,从而节省计算成本。回调函数应包含以下参数:

  • pipeline:访问管道实例属性(如num_timesteps和guidance_scale)。
  • step_indextimestep:当前步骤索引和时间步。在达到num_timesteps的40%后,使用step_index关闭CFG。
  • callback_kwargs:包含在去噪循环中可以修改的张量变量。是一个dict,包含可以在去噪循环中修改的张量变量。
    • 它只包括callback_on_step_end_tensor_inputs参数中指定的变量,该参数被传递给管道的__call__方法。
    • 不同的管道可能使用不同的变量集,因此请检查管道的_callback_tensor_inputs属性以获取可以修改的变量列表。一些常见的变量包括latents和prompt_embeds。
    • 对于此函数,请在将guidance_scale设置为0.0后更改prompt_embeds的批处理大小,以使其正常工作。

示例回调函数:

def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
# adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipeline.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        prompt_embeds = prompt_embeds.chunk(2)[-1]

# update guidance_scale and prompt_embeds
        pipeline._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs

每步生成后显示图像(中间结果)

通过访问并转换潜在空间,可以在每步生成后显示图像。以下函数将 SDXL 的潜在空间(4 通道)转换为 RGB 张量(3 通道)。

  1. 使用以下函数将SDXL潜伏时间(4个通道)转换为RGB张量(3个通道)
def latents_to_rgb(latents):
    weights = (
        (60, -60, 25, -70),
        (60,  -5, 15, -50),
        (60,  10, -5, -35)
    )

    weights_tensor = torch.t(torch.tensor(weights, dtype=latents.dtype).to(latents.device))
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(latents.device)
    rgb_tensor = torch.einsum("...lxy,lr -> ...rxy", latents, weights_tensor) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)

    return Image.fromarray(image_array)
  1. 使用该函数在每步生成后解码并保存潜在空间为图像。
def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    image = latents_to_rgb(latents)
    image.save(f"{step}.png")
    return callback_kwargs
  1. decode_tensors函数传递给callback_on_step_end参数,以在每一步之后对张量进行解码。还需要在callback_on_step_end_tensor_inputs参数中指定要修改的内容,在本例中为 latents。
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
).to("cuda")

image = pipeline(
    prompt="A croissant shaped like a cute bear.",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

在这里插入图片描述

详细内容请参见Hugging Face Diffusers 官方文档。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/724586.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024青海三支一扶招1910人7月6日笔试

📢2024年青海省三支一扶计划招募1910人公告已发布! 小🀄️帮大家整理好了考试关键时间点: ★ 报名时间:6月20日至6月25日 ★ 报名网站:青海省人事考试信息网(www.qhpta.com) ★ 网上…

每日一练:攻防世界:miao~

给了一张jpg图片 没发现什么特别,放到winhex中查看也没思路。 放到kali里面foremost分离文件试试,结果分离出个wav音频文件 直接放到 audycity看看频谱图 发现字符串,但是没有其他信息。可能是密钥之类的。到这里我就卡住了,看…

vue3面试题八股集合——2024

vue3比vue2有什么优势? 性能更好,打包体积更小,更好的ts支持,更好的代码组织,更好的逻辑抽离,更多的新功能 描述Vu3生命周期 Options API的生命周期: beforeCreate: 在实例初始化之后、数据观…

喜讯!昂辉科技通过2024年度重点产业链企业(第一批)认定

日前,合肥市推进战略性新兴产业发展工作委员会办公室公布了 2024年度重点产业链企业(第一批)新入库名单(集成电路、新型显示、网络与信息安全、城市安全、空天信息、新能源汽车和智能网联汽车、生物医药、新材料、高端装备、节能环…

【PyQt5】一文向您详细介绍 QHBoxLayout() 的作用

【PyQt5】一文向您详细介绍 QHBoxLayout() 的作用 下滑即可查看博客内容 🌈 欢迎莅临我的个人主页 👈这里是我静心耕耘深度学习领域、真诚分享知识与智慧的小天地!🎇 🎓 博主简介:985高校的普通本硕&a…

【IPython的使用技巧】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

GLSB是什么?带你深入了解GLSB核心功能

伴随互联网的快速发展,大型企业等组织单位通过建设多数据中心,以提升用户体验。然而想要在多个数据中心实现流量的智能管理,提高网站的可靠性和可用性,则需要全局服务器负载均衡技术——GLSB的助力。GLSB是什么?它又有…

算法金 | 再见!!!梯度下降(多图)

大侠幸会,在下全网同名「算法金」 0 基础转 AI 上岸,多个算法赛 Top 「日更万日,让更多人享受智能乐趣」 接前天 李沐:用随机梯度下降来优化人生! 今天把达叔 6 脉神剑给佩奇了,上 吴恩达:机器…

解决MacOS docker 拉取镜像慢的问题

docker官网:https://docker.p2hp.com/get-started/index.html 下载完成之后,拉取镜像速度慢,问题如下: 解决方法 配置阿里源:https://cr.console.aliyun.com/cn-hangzhou/instances/mirrors在docker desktop里面设置…

【C++入门(4)】引用、内联函数、auto

一、引用与类型转换 我们看下面这个例子。 用 int & 给 double 类型的变量起别名,编译器报错: int main() {double b 3.14;int a b;int& x b;return 0; } 用 const int & 给 double 类型的变量起别名,成功: in…

Spark日志有哪些?

spark.log:记录作业运行日志,包括Spark框架内部日志和用户通过日志接口输出的日志。 executor 启动结束日志: job,stage,task提交结束日志: pmap.log:周期性地截取Driver或Executor的pmap和…

element--el-table添加合计后固定列x轴滚动条无法滚动问题

效果图 改变固定列滚轮高度问题 解决文章 解决方案 使用到的参数 pointer-events 属性用来控制一个元素能否响应鼠标操作,常用的关键字有 auto 和 none pointer-events: none; 让一个元素忽略鼠标操作 pointer-events: auto; 还原浏览器设定的默认行为 代码演示 添…

C++11(1)

这一节介绍一些C11个人认为比较常用的部分 文章目录 1.{}列表初始化2.initializer_list3.auto、decltype、nullptr关键字4.范围for5.左值引用、右值引用、万能引用(完美转发)6.lambda表达式7.新的类功能8.可变参数模板9.包装器 1.{}列表初始化 C98中,标准允许使用花…

Wireshark v4 修改版安装教程(免费开源的网络嗅探抓包工具)

前言 Wireshark(前称Ethereal)是一款免费开源的网络嗅探抓包工具,世界上最流行的网络协议分析器!网络封包分析软件的功能是撷取网络封包,并尽可能显示出最为详细的网络封包资料。Wireshark网络抓包工具使用WinPCAP作为…

【ARM Cache 及 MMU 系列文章 6.5 -- 如何进行 Cache miss 统计?】

请阅读【ARM Cache 及 MMU/MPU 系列文章专栏导读】 及【嵌入式开发学习必备专栏】 文章目录 ARM Cache Miss 统计Cache 多层架构简介Cache 未命中的类型Cache 未命中统计Cache miss 统计代码实现Cache Miss 统计意义ARM Cache Miss 统计 在ARMv8/v9架构中,缓存未命中(Cache …

使用MAT定位线上OOM问题

目录 1.什么是OOM? 2.发生的可能原因 3.常见类型的OOM 4.如何定位问题? 4.1 获取dump文件 4.2 MAT分析 「Leak Suspects」泄露嫌疑 「Histogram」直方图 「dominator tree」支配树 「thread overview」线程视图 目录 1.什么是OOM? 2.发生的可能原因 …

完整迁移方案+工具:Citrix替换,无感迁移!

随着用户的替换进程进入到演进的阶段,用户面临的重大挑战包括: (1)大量数据的迁移需要精确规划,以避免数据丢失或损坏; (2)迁移效率低下,不仅会增加迁移成本,…

每日复盘-202406019

今日关注: 20240619 六日涨幅最大: ------1--------300868--------- 杰美特 五日涨幅最大: ------1--------300462--------- 华铭智能 四日涨幅最大: ------1--------300462--------- 华铭智能 三日涨幅最大: ------1--------300462--------- 华铭智能 二日涨幅最大…

可信计算和数字水印技术

可信计算 可信计算可信计算基础概述可信计算关键技术要素可信性认证可信计算优劣 数字水印技术数字版权保护技术 可信计算 可信计算基础概述 可信计算(Trusted Computing,TC):在计算和网络通信系统中广泛使用的、基于硬件安全模块…