扩散模型实战(九):使用CLIP模型引导和控制扩散模型

推荐阅读列表:

 扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

扩散模型实战(七):Diffusers蝴蝶图像生成实战

扩散模型实战(八):微调扩散模型

上篇文章中介绍了如何微调扩散模型,有时候微调的效果仍然不能满足需求,比如图片编辑,3D模型输出等都需要对生成的内容进行控制,本文将初步探索一下如何控制扩散模型的输出。

我们将使用在LSUM bedrooms数据集上训练并在WikiArt数据集上微调的模型,首先加载模型来查看一下模型的生成效果:

!pip install -qq diffusers datasets accelerate wandb open-clip-torch

 

import numpy as npimport torchimport torch.nn.functional as Fimport torchvisionfrom datasets import load_datasetfrom diffusers import DDIMScheduler, DDPMPipelinefrom matplotlib import pyplot as pltfrom PIL import Imagefrom torchvision import transformsfrom tqdm.auto import tqdm

 

device = (    "mps"    if torch.backends.mps.is_available()    else "cuda"    if torch.cuda.is_available()    else "cpu")

 

# 载入一个预训练过的管线pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device) # 使用DDIM调度器,仅用40步生成一些图片scheduler = DDIMScheduler.from_pretrained(pipeline_name)scheduler.set_timesteps(num_inference_steps=40) # 将随机噪声作为出发点x = torch.randn(8, 3, 256, 256).to(device) # 使用一个最简单的采样循环for i, t in tqdm(enumerate(scheduler.timesteps)):    model_input = scheduler.scale_model_input(x, t)    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]    x = scheduler.step(noise_pred, t, x).prev_sample # 查看生成结果,如图5-10所示grid = torchvision.utils.make_grid(x, nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

图片

       正如上图所示,模型可以生成一些图片,那么如何进行控制输出呢?下面我们以控制图片生成绿色风格为例介绍AIGC模型控制:

       思路是:定义一个均方误差损失函数,让生成的图片像素值尽量接近目标颜色;

 

def color_loss(images, target_color=(0.1, 0.9, 0.5)):    """给定一个RGB值,返回一个损失值,用于衡量图片的像素值与目标颜色相差多少; 这里的目标颜色是一种浅蓝绿色,对应的RGB值为(0.1, 0.9, 0.5)"""    target = (        torch.tensor(target_color).to(images.device) * 2 - 1    )  # 首先对target_color进行归一化,使它的取值区间为(-1, 1)     target = target[        None, :, None, None    ]  # 将所生成目标张量的形状改为(b, c, h, w),以适配输入图像images的 # 张量形状    error = torch.abs(        images - target    ).mean()  # 计算图片的像素值以及目标颜色的均方误差    return error

接下来,需要修改采样循环操作,具体操作步骤如下:

  1. 创建输入图像X,并设置requires_grad设置为True;

  2. 计算“去噪”后的图像X0;

  3. 将“去噪”后的图像X0传递给损失函数;

  4. 计算损失函数对输入图像X的梯度;

  5. 在使用调度器之前,先用计算出来的梯度修改输入图像X,使输入图像X朝着减少损失值的方向改进

实现上述步骤有两种方法:

方法一:从UNet中获取噪声预测,并将输入图像X的requires_grad属性设置为True,这样可以充分利用内存(因为不需要通过扩散模型追踪梯度),但是这会导致梯度的精度降低;

方法二:先将输入图像X的requires_grad属性设置为True,然后传递给UNet并计算“去噪”后的图像X0;

下面分别看一下这两种方法的效果:

 

# 第一种方法 # guidance_loss_scale用于决定引导的强度有多大guidance_loss_scale = 40  # 可设定为5~100的任意数字 x = torch.randn(8, 3, 256, 256).to(device) for i, t in tqdm(enumerate(scheduler.timesteps)):     # 准备模型输入    model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     # 设置x.requires_grad为True    x = x.detach().requires_grad_()     # 得到“去噪”后的图像    x0 = scheduler.step(noise_pred, t, x).pred_original_sample     # 计算损失值    loss = color_loss(x0) * guidance_loss_scale    if i % 10 == 0:        print(i, "loss:", loss.item())     # 获取梯度    cond_grad = -torch.autograd.grad(loss, x)[0]     # 使用梯度更新x    x = x.detach() + cond_grad     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample# 查看结果grid = torchvision.utils.make_grid(x, nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出0 loss: 29.3701839447021510 loss: 12.11665058135986320 loss: 11.64170455932617230 loss: 11.78276252746582

图片

 

# 第二种方法:在模型预测前设置好x.requires_gradguidance_loss_scale = 40x = torch.randn(4, 3, 256, 256).to(device) for i, t in tqdm(enumerate(scheduler.timesteps)):     # 首先设置好requires_grad    x = x.detach().requires_grad_()    model_input = scheduler.scale_model_input(x, t)     # 预测    noise_pred = image_pipe.unet(model_input, t)["sample"]     # 得到“去噪”后的图像    x0 = scheduler.step(noise_pred, t, x).pred_original_sample     # 计算损失值    loss = color_loss(x0) * guidance_loss_scale    if i % 10 == 0:        print(i, "loss:", loss.item())     # 获取梯度    cond_grad = -torch.autograd.grad(loss, x)[0]     # 根据梯度修改x    x = x.detach() + cond_grad     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample  grid = torchvision.utils.make_grid(x, nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出0 loss: 27.6226882934570310 loss: 16.84250640869140620 loss: 15.5464210510253930 loss: 15.545379638671875

图片

       从上图看出,第二种方法效果略差,但是第二种方法的输出更接近训练模型所使用的数据,也可以通过修改guidance_loss_scale参数来增强颜色的迁移效果。

CLIP控制图像生成

       虽然上述方式可以引导和控制图像生成某种颜色,但现在LLM更主流的方式是通过Prompt(仅仅打几行字描述需求)来得到自己想要的图像,那么CLIP是一个不错的选择。CLIP是有OpenAI开发的图文匹配大模型,由于这个过程是可微分的,所以可以将其作为损失函数来引导扩散模型。

使用CLIP控制图像生成的基本流程如下

  1. 使用CLIP模型对Prompt表示为512embedding向量;

  2. 在扩散模型的生成过程中需要多次执行如下步骤:

    1)生成多个“去噪”图像;

    2)对生成的每个“去噪”图像用CLIP模型进行embedding,并对Prompt embedding和图像的embedding进行对比;

    3)计算Prompt和“去噪”后图像的梯度,使用这个梯度先更新输入图像X,然后再使用调度器更新X;

    加载CLIP模型

 

import open_clip clip_model, _, preprocess = open_clip.create_model_and_transforms(    "ViT-B-32", pretrained="openai")clip_model.to(device) # 图像变换:用于修改图像尺寸和增广数据,同时归一化数据,以使数据能够适配CLIP模型 tfms = torchvision.transforms.Compose(    [        torchvision.transforms.RandomResizedCrop(224),# 随机裁剪        torchvision.transforms.RandomAffine(5),       # 随机扭曲图片        torchvision.transforms.RandomHorizontalFlip(),# 随机左右镜像, # 你也可以使用其他增广方法        torchvision.transforms.Normalize(            mean=(0.48145466, 0.4578275, 0.40821073),            std=(0.26862954, 0.26130258, 0.27577711),        ),    ]) # 定义一个损失函数,用于获取图片的特征,然后与提示文字的特征进行对比def clip_loss(image, text_features):    image_features = clip_model.encode_image(        tfms(image)    )  # 注意施加上面定义好的变换    input_normed = torch.nn.functional.normalize(image_features.       unsqueeze(1), dim=2)    embed_normed = torch.nn.functional.normalize(text_features.       unsqueeze(0), dim=2)    dists = (        input_normed.sub(embed_normed).norm(dim=2).div(2).           arcsin().pow(2).mul(2)    )  # 使用Squared Great Circle Distance计算距离    return dists.mean()

      下面是引导模型生成图像的过程,步骤与上述类似,只需要把color_loss()替换成CLIP的损失函数

 

prompt = "Red Rose (still life), red flower painting" # 读者可以探索一下这些超参数的影响guidance_scale = 8n_cuts = 4 # 这里使用稍微多一些的步数scheduler.set_timesteps(50) # 使用CLIP从提示文字中提取特征text = open_clip.tokenize([prompt]).to(device)with torch.no_grad(), torch.cuda.amp.autocast():    text_features = clip_model.encode_text(text) x = torch.randn(4, 3, 256, 256).to(    device)  for i, t in tqdm(enumerate(scheduler.timesteps)):     model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     cond_grad = 0     for cut in range(n_cuts):         # 设置输入图像的requires_grad属性为True        x = x.detach().requires_grad_()         # 获得“去噪”后的图像        x0 = scheduler.step(noise_pred, t, x).pred_original_sample         # 计算损失值        loss = clip_loss(x0, text_features) * guidance_scale         # 获取梯度并使用n_cuts进行平均        cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts     if i % 25 == 0:        print("Step:", i, ", Guidance loss:", loss.item())     # 根据这个梯度更新x    alpha_bar = scheduler.alphas_cumprod[i]    x = (        x.detach() + cond_grad * alpha_bar.sqrt()    )  # 注意这里的缩放因子     # 使用调度器更新x    x = scheduler.step(noise_pred, t, x).prev_sample  grid = torchvision.utils.make_grid(x.detach(), nrow=4)im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5Image.fromarray(np.array(im * 255).astype(np.uint8))

 

# 输出Step: 0 , Guidance loss: 7.418107986450195Step: 25 , Guidance loss: 7.085518836975098

图片

       上述生成的图像虽然不够完美,但可以调整一些超参数,比如梯度缩放因子alpha_bar.sqrt(),虽然理论上存在所谓的正确的缩放这些梯度方法,但在实践中仍需要实验来检验,下面介绍一些常用的方案:

 

plt.plot([1 for a in scheduler.alphas_cumprod], label="no scaling")plt.plot([a for a in scheduler.alphas_cumprod], label="alpha_bar")plt.plot([a.sqrt() for a in scheduler.alphas_cumprod],     label="alpha_bar.sqrt()")plt.plot(    [(1 - a).sqrt() for a in scheduler.alphas_cumprod], label="(1-     alpha_bar).sqrt()")plt.legend()

图片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/155331.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Dubbo协议详解

前言特点应用场景Dubbo协议示例Dubbo协议的不足拓展 前言 Dubbo协议是一种高性能、轻量级的开源RPC框架,主要设计目的是解决分布式系统中服务调用的一些常见问题,例如服务负载均衡、服务注册中心、服务的远程调用等。它支持多种语言,例如Jav…

LeetCode(26)判断子序列【双指针】【简单】

目录 1.题目2.答案3.提交结果截图 链接: 判断子序列 1.题目 给定字符串 s 和 t ,判断 s 是否为 t 的子序列。 字符串的一个子序列是原始字符串删除一些(也可以不删除)字符而不改变剩余字符相对位置形成的新字符串。(…

centos搭建docker镜像Harbor仓库的简明方法

在kubernetes集群中如果要部署springcloud这样的应用,就必须有一个自建的docker镜像中心仓库。 它的目的有两点: 1. 镜像拉取速度快 2. 开发好维护 而Harbor是一个非常好用的docker本地仓库 所以本篇文章来讲讲如何在部署Harbor仓库 首先系统版本最…

此芯科技加入绿色计算产业联盟,参编绿色计算产业发展白皮书

近日,此芯科技正式加入绿色计算产业联盟(Green Computing Consortium,简称GCC),以Arm架构通用智能CPU芯片及高能效的Arm PC计算解决方案加速构建软硬协同的绿色计算生态体系,推动绿色计算产业加速发展。 继…

Linux_系统信息_uname查看内核版本、内核建立时间、处理器类型、顺便得到操作系统位数等

1、uname --help 使用uname --help查看uname命令的帮助信息 2、uname -a 通过上面的help就知道-a选项显示全部内容时的含义了。 内核名是Linux主机名是lubancat,如果想看主机名可以使用命令hostname;内核版本是Linux 4.19.232,建立时间为2…

全栈工程师必须要掌握的前端Html技能

作为一名全栈工程师,在日常的工作中,可能更侧重于后端开发,如:C#,Java,SQL ,Python等,对前端的知识则不太精通。在一些比较完善的公司或者项目中,一般会搭配前端工程师&a…

【二分法】

二分法可以在有序排列中&#xff0c;通过不断对半切割数据&#xff0c;提高数据查找效率。 lst [1,4,6,7,45,66,345,767,788,999] n 66 left 0 right len(lst)-1 while left < right: #边界&#xff0c;当右边比左边还小的时候退出循环 mid (left right)//2 …

JVM虚拟机-虚拟机执行子系统-第6章 类文件结构

各种不同平台的Java虚拟机&#xff0c;以及所有平台都统一支持的程序存储格式——字节码&#xff08;Byte Code&#xff09;是构成平台无关性的基石 Class类文件的结构 字节码指令&#xff1a;操作码 操作数 任何一个Class文件都对应着唯一的一个类或接口的定义信息 Class文件…

C/C++输出整数部分 2021年12月电子学会青少年软件编程(C/C++)等级考试一级真题答案解析

目录 C/C输出整数部分 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、程序说明 五、运行结果 六、考点分析 C/C输出整数部分 2021年12月 C/C编程等级考试一级编程题 一、题目要求 1、编程实现 输入一个双精度浮点数f&#xff0c; 输出其整…

C++打怪升级(十一)- STL之list

~~~~ 前言1. list是什么2. list接口函数的使用1. 构造相关默认构造n个val构造迭代器范围构造拷贝构造 2 赋值运算符重载函数2 析构函数3 迭代器相关begin 和 endrbegin 和rend 4 容量相关emptysize 5 元素访问相关frontback 6 修改相关push_backpop_backpush_frontpop_frontins…

【OJ比赛日历】快周末了,不来一场比赛吗? #11.18-11.24 #15场

CompHub[1] 实时聚合多平台的数据类(Kaggle、天池…)和OJ类(Leetcode、牛客…&#xff09;比赛。本账号会推送最新的比赛消息&#xff0c;欢迎关注&#xff01; 以下信息仅供参考&#xff0c;以比赛官网为准 目录 2023-11-18&#xff08;周六&#xff09; #5场比赛2023-11-19…

受电诱骗快充取电芯片XSP08:PD+QC+华为+三星多种协议9V12V15V20V

目前市面上很多家的快充充电器&#xff0c;都有自己的私有快充协议&#xff0c;如PD协议、QC协议、华为快充协议、三星快充协议、OPPO快充协议等待&#xff0c;为了让它们都能输出快充电压&#xff0c;就需要在受电端也增加快充协议取电芯片XSP08&#xff0c;它可以和充电器通讯…

我记不住的getopt_long的那些参数和返回值

前言&#xff1a;最近在学习面向Linux系统进行C语言的编程&#xff0c;通过查询man手册和查看网络上的各种文章来获取一点点的知识&#xff0c;重点是看完手册还是一脸懵逼&#xff0c;搞不懂手册里面再说啥&#xff0c;而本篇文章将记录一下学习getopt_long的那些参数和返回值…

IDEA无法查看源码是.class,而不是.java解决方案?

问题&#xff1a;在idea中&#xff0c;ctrl鼠标左键进入源码&#xff0c;但是有时候会出现无法查看反编译的源码&#xff0c;如图&#xff01; 而我们需要的是方法1: mvn dependency:resolve -Dclassifiersources 注意&#xff1a;需要该模块的目录下&#xff0c;不是该文件目…

计算机毕业设计选题推荐-幼儿园管理微信小程序/安卓APP-项目实战

✨作者主页&#xff1a;IT毕设梦工厂✨ 个人简介&#xff1a;曾从事计算机专业培训教学&#xff0c;擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

LeetCode(24)文本左右对齐【数组/字符串】【困难】

目录 1.题目2.答案3.提交结果截图 链接&#xff1a; 文本左右对齐 1.题目 给定一个单词数组 words 和一个长度 maxWidth &#xff0c;重新排版单词&#xff0c;使其成为每行恰好有 maxWidth 个字符&#xff0c;且左右两端对齐的文本。 你应该使用 “贪心算法” 来放置给定的单…

spring中的DI

【知识要点】 控制反转&#xff08;IOC&#xff09;将对象的创建权限交给第三方模块完成&#xff0c;第三方模块需要将创建好的对象&#xff0c;以某种合适的方式交给引用对象去使用&#xff0c;这个过程称为依赖注入&#xff08;DI&#xff09;。如&#xff1a;A对象如果需要…

flutter web 中嵌入一个html

介绍 flutter web 支持使用 HtmlElementView嵌入html import dart:html; import dart:ui as ui; import package:flutter/cupertino.dart;class WebWidget extends StatelessWidget {const WebWidget({super.key});overrideWidget build(BuildContext context) {DivElement fr…

flutter绘制弧形进度条

绘制&#xff1a; import package:jade/utils/JadeColors.dart; import package:flutter/material.dart; import dart:math as math; import package:flutter_screenutil/flutter_screenutil.dart;class ArcProgressBar extends StatefulWidget{const ArcProgressBar({Key key…

Linux系统进程——进程的退出、子进程退出的收集、孤儿进程

进程退出 进程退出主要分为两种&#xff1a;正常退出、异常退出 正常退出 正常退出分为以下几种&#xff1a; 1.main函数调用return 2.进程调用exit(),标准c库 3.进程调用 _exit() 或者 _Exit() &#xff0c;属于系统调用 4.进程最后一个线程返回 5.最后一个线程调用pthrea…