使用Dreambooth LoRA微调SDXL 0.9

本文将介绍如何通过LoRA对Stable Diffusion XL 0.9进行Dreambooth微调。DreamBooth是一种仅使用几张图像(大约3-5张)来个性化文本到图像模型的方法。

本教程基于通过LoRA进行Unet微调,而不是进行全部的训练。LoRA是在LoRA: Low-Rank Adaptation of Large Language Models中引入的一种参数高效的微调技术。

本文基于diffusers包,至少需要0.18.2或更高版本。

基于GeForce RTX 4090 GPU (24GB)的本地实验,VRAM消耗如下:

  • 512分辨率- 11GB用于训练,19GB保存检查点
  • 1024分辨率- 17GB的训练,19GB时保存检查点

环境设置

建议创建一个新的虚拟环境,下面是我们需要使用的python包

Pytorch

 pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

Diffusers

 pip install git+https://github.com/huggingface/diffusers

LoRA进行SDXL 0.9 Dreambooth微调需要0.19.0.dev0及以上版本的diffusers

还有一些其他的依赖包:

 pip install invisible_watermark transformers accelerate safetensors

然后就是进行配置,在终端上执行如下命令配置accelerate:

 accelerate config

使用以下设置在单个GPU上进行混合精度的训练:

 ----------------------------------------------------------------------------------------------------------------------------
 In which compute environment are you running?
 This machine
 ----------------------------------------------------------------------------------------------------------------------------
 Which type of machine are you using?
 No distributed training
 
 Do you want to run your training on CPU only (even if a GPU is available)? [yes/NO]:
 no
 
 Do you wish to optimize your script with torch dynamo?[yes/NO]:
 no
 
 Do you want to use DeepSpeed? [yes/NO]: 
 no
 
 What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:
 all
 ----------------------------------------------------------------------------------------------------------------------------
 Do you wish to use FP16 or BF16 (mixed precision)?
 fp16

或者,使用以下命令使用默认值

 accelerate config default

数据集

我们这里将介绍Dreambooth微调所需的最简单配置。对于数据集的准备,只需收集一些相同主题或风格的图像,并将其放在一个目录中。

比如下面的文件夹结构:

 data/xxx.png
 data/xxy.png
 ...
 data/xxz.png
 data/yyz.png

这里要确保所有的训练图像都是相同的大小。如果大小不同,需要先调整大小。建议使用1024 * 1024作为图像分辨率。

我们这里使用dog示例数据集通过LoRA测试Dreambooth微调。这个数据集可以直接从网站下载,以下Python脚本可以将其下载到本地:

 from huggingface_hub import snapshot_download
 
 local_dir = "./data"
 snapshot_download(
     "diffusers/dog-example",
     local_dir=local_dir, repo_type="dataset",
     ignore_patterns=".gitattributes",
 )

微调训练

在官方库下载train_dreambooth_lora_sdxl.py训练脚本。将该文件放在工作目录中。

如果你使用的是旧版本的diffusers,它将由于版本不匹配而报告错误。但是你可以通过在脚本中找到check_min_version函数并注释它来轻松解决这个问题,如下所示:

 # check_min_version("0.19.0.dev0")

虽然可以用,但是还是建议使用官方的推荐版本。

如果全部设置正确,那么可以通过LoRA进行Dreambooth微调的训练命令:

 accelerate launch train_dreambooth_lora_sdxl.py \
   --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-0.9" \
   --instance_data_dir=data \
   --output_dir=output \
   --mixed_precision="fp16" \
   --instance_prompt="a photo of zwc dog" \
   --resolution=1024 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=4 \
   --learning_rate=1e-4 \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --checkpointing_steps=500 \
   --max_train_steps=1000 \
   --seed="0" \
   --checkpoints_total_limit=5

对于Windows用户,将所有\符号替换为^符号。因为转义符不同

简单的介绍一些参数:

  • instance_prompt:带有指定实例标识符的提示符。
  • resolution:输入图像的分辨率,训练/验证数据集中的所有图像都将调整为此大小。默认值是512,将其设置为1024,因为它是用于SDXL训练的分辨率。
  • train_batch_size:训练数据加载器的批处理大小(每个设备)。减少批处理大小,防止训练过程中出现内存不足错误。
  • num_train_steps:训练步数。建议设置为N × 100,其中N表示训练图像的个数。
  • checkpointing_steps:每X次更新时保存训练状态的检查点。默认为500。将其设置为更高的值以减少保存的检查点数量,因为模型需要保存到磁盘,所以频繁的保存会降低训练速度。
  • checkpoints_total_limit:限制保存的检查点的数量。将删除/删除旧的检查点。

在第一次运行是,程序会下载Stable Diffusion模型并将其保存在本地缓存文件夹中,如果网不好的话这里会很慢。在随后的运行中,它将重用相同的缓存数据。

请注意SDXL 0.9权重需要登录HuggingFace并接受许可。然后,通过HuggingFace -cli命令登录,并使用从HuggingFace设置中获取的API令牌。

默认情况下,每个checkpointing_steps脚本只保存一次LoRA权重和一些检查点文件。

最后我们的结果如下:

 |- output
 |  |- checkpoint-500
 |  |- checkpoint-1000
 |  |- checkpoint-1500
 |  |- checkpoint-2000
 |- data
 |- train_dreambooth_lora_sdxl.py

上面的每个checkpoint文件夹包含以下文件:

  • optimizer.bin
  • pytorch_lora_weights.bin
  • random_states_0.pkl
  • scaler.pt
  • scheduler.bin

pytorch_lora_weights.bin文件可以直接用于推理。

推理

创建一个名为inference.py的新Python文件:

 from diffusers import DiffusionPipeline
 import torch
 #初始化,加载所需的LoRA权重
 pipe = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-0.9",
     torch_dtype=torch.float16,
     variant="fp16",
     use_safetensors=True
 )
 # load LoRA weight
 pipe.unet.load_attn_procs("data/checkpoint-2000/pytorch_lora_weights.bin", use_safetensors=False)
 pipe.enable_model_cpu_offload()
 
 refiner = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-refiner-0.9",
     torch_dtype=torch.float16,
     variant="fp16",
     use_safetensors=True
 )
 refiner.enable_model_cpu_offload()
 
 #推理和保存文件
 seed = 12345
 n_steps = 50
 prompt = "a photo of zwc dog in a bucket"
 
 generator = torch.Generator(device="cuda").manual_seed(seed)
 latent_image = pipe(prompt=prompt, num_inference_steps=n_steps, generator=generator, output_type="latent").images[0]
 image = refiner(prompt=prompt, num_inference_steps=n_steps, generator=generator, image=latent_image).images[0]
 image.save("image.jpg")

然后我们可以执行如下命令:

 python inference.py

结果展示

以下是我做的一个快速测试,使用16张具有各种情绪的chibi 人物图像作为训练数据集。

分辨率1024 × 1024 、duoduo 作为实例提示

大约花了4个小时的训练,下面的输出示例:

总结

使用我们上面的代码可以随意使用不同的数据集和训练配置进行实验,以获得所需的结果。

本文首先简要介绍了Dreambooth和LoRA背后的基本概念。然后介绍了通过pip install进行安装的过程。还探讨了数据集的准备。然后整理了训练命令,并对一些常用的训练参数进行了详细的说明。并使用代码加载新训练的LoRA权重,根据输入提示生成相应的图像。最后展示了一个在本地进行的训练的简单实验。

本文使用的主要库:

https://avoid.overfit.cn/post/0423804f782b4cb9a74f1ae6a6f99b34

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/37641.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2023届网络安全岗秋招面试题及面试经验分享

Hello,各位小伙伴,我作为一名网络安全工程师曾经在秋招中斩获🔟个offer🌼,并在国内知名互联网公司任职过的职场老油条,希望可以将我的面试的网络安全大厂面试题和好运分享给大家~ 转眼2023年秋招已经到了金…

Python应用实例(一)外星人入侵(八)

外星人入侵(八) 1.添加Play按钮1.1 创建Button类1.2 在屏幕上绘制按钮1.3 开始游戏1.4 重置游戏1.5 将play按钮切换到非活动状态1.6 隐藏鼠标光标 我们添加一个Play按钮,用于根据需要启动游戏以及在游戏结束后重启游戏,还会修改这…

剖析C语言字符串函数(超全)

目录 前言: 一、strlen函数 功能: 参数和返回值: 注意事项: 返回值是无符号的易错点: strlen函数的模拟实现 1、计数器算法 2、递归算法 3、指针减去指针 二、strcpy函数 功能: 参数和返回值 …

124、仿真-基于51单片机智能电表系统设计(Proteus仿真+程序+原理图+配套资料等)

方案选择 单片机的选择 方案一:STM32系列单片机控制,该型号单片机为LQFP44封装,内部资源足够用于本次设计。STM32F103系列芯片最高工作频率可达72MHZ,在存储器的01等等待周期仿真时可达到1.25Mip/MHZ(Dhrystone2.1)。内部128k字节…

盖子的c++小课堂——第十八讲:栈

目录 前言 栈的定义 栈,是什么? 例1-弹夹 问题 例2-停车场 问题 栈的概念 空栈 进栈、出栈 特点 例题 车厢调度 如何操作 数组模拟栈 入栈 出栈 栈的基本操作 判断空栈 求栈的元素数量 读栈顶元素 总结 前言 OK呀,说到做…

银河麒麟服务器v10 sp1 部署 redis 及redis gui 客户端工具

上一篇:银河麒麟服务器v10 sp1 redis开机自动启动_csdn_aspnet的博客-CSDN博客 本文介绍另一种redis安装方式及客户端工具安装。 Redis 是一种内存数据模型存储,可用作数据库、缓冲区和消息传递中继。它是开源的(BSD 许可)。字符…

大模型基础:理论与技术的演进概述

大模型基础:理论与技术的演进概述 人工智能发展历程 人工智能发展历程可以概括为以下几个主要阶段: 起源阶段(1956-1980年代),这一时期被称为人工智能的“黄金时代”, 达特茅斯会议首次提出人工智能概念, 开发出传统人工智能系统, 如ELIZA、深蓝等。知…

Java设计模式之行为型-命令模式(UML类图+案例分析)

目录 一、基础概念 二、UML类图 三、角色设计 四、案例分析 1、基本实现 2、点餐案例 五、总结 一、基础概念 1、将一个请求封装为一个对象,使您可以用不同的请求对客户进行参数化。 2、对请求排队或记录请求日志,以及支持可撤销的操作。 3、…

JAVA动态代理

动态代理是在运行时动态生成类字节码,并加载到 JVM 中 你通过Proxy 类的 newProxyInstance() 创建的代理对象在调用方法的时候,实际会调用到实现InvocationHandler 接口的类的 invoke()方法. 运行时的动作由invoke()方法决定控制。 其中运用了反射的相…

(vue)整个页面添加背景视频

(vue)整个页面添加背景视频 App.vue <template><div id"app" :class"[platform]"><video src"./assets/images/top/bg-video-711.mp4" autoplay muted loop class"bg"></video><router-view /></di…

关于Java的网络编程

网络的一些了解 网络通信协议 链路层&#xff1a;链路层是用于定义物理传输通道&#xff0c;通常是对某些网络连接设备的驱动协议&#xff0c;例如针对光纤、网线提供的驱动。网络层&#xff1a;网络层是整个TCP/IP协议的核心&#xff0c;它主要用于将传输的数据进行分组&…

你的隐私被泄漏了吗

近日&#xff0c;某高校毕业生在校期间窃取学校内网数据&#xff0c;收集全校学生个人隐私信息的新闻引发了人们对互联网生活中个人信息安全问题的再度关注。在大数据时代&#xff0c;算法分发带来了隐私侵犯&#xff0c;在享受消费生活等便捷权利的同时&#xff0c;似乎又有不…

按关键词全网采集

简数采集器支持按关键词全网采集&#xff0c;只需输入对应关键词&#xff0c;即可在全网采集相关数据&#xff0c;类似搜索引擎&#xff0c;无需用户配置采集规则。 简数采集器按关键词泛采集可用于舆情监控、市场研究分析等。 使用方法如下&#xff1a; 目录 1. 创建关键词…

MySQL为什么采用B+树作为索引底层数据结构?

索引就像一本书的目录&#xff0c;通过索引可以快速找到我们想要找的内容。那么什么样的数据结构可以用来实现索引呢&#xff1f;我们可能会想到&#xff1a;二叉查找树&#xff0c;平衡搜索树&#xff0c;或者是B树等等一系列的数据结构&#xff0c;那么为什么MySQL最终选择了…

【框架篇】对象注入的三种实现方式

对象注入的实现 一&#xff0c;实现方式的使用 对象注入也可被称为对象装配&#xff0c;是把Bean对象获取出来放到某个类中。 对象注入的实现方式有3种&#xff0c;分别为属性注入&#xff0c;Setter注入和构造方法注入。 为了更好地理解对象注入的实现方式&#xff0c;搞个…

Spring管理事务知识

目录 1.什么是事务 2.事务的特性ACID 3.Spring 管理事务的方式 4.Spring管理事务的体现&#xff1a;JDBCTemplate 5.声明式事务的属性有哪些 6.声明式事务属性---只读 7.声明式事务属性---超时 8.声明式事务属性---回滚策略 9.声明式事务属性---事务隔离级别 10.声明…

1、Kubernetes 概述和架构

目录 一、基本介绍 二、kubernetes功能和架构 2.1、 概述 2.2 、功能 &#xff08;1&#xff09;自动装箱 &#xff08;2&#xff09;自我修复(自愈能力) &#xff08;3&#xff09;水平扩展 &#xff08;4&#xff09;服务发现 &#xff08;5&#xff09;滚动更新 &a…

【Vue】给 elementUI 中的 this.$confirm、this.$alert、 this.$prompt添加按钮的加载效果

文章目录 主要使用 beforeClose 方法实现 loading 的效果beforeClose MessageBox 关闭前的回调&#xff0c;会暂停实例的关闭 function(action, instance, done)1. action 的值为confirm, cancel或close。 2. instance 为 MessageBox 实例&#xff0c;可以通过它访问实例上的属…

C语言中定义和声明的区别

声明(declaration)与定义(definition) 为了使不同的文件都可以访问同一个变量&#xff0c;C会区 分变量的声明和定义。 变量的定义会为这个变量分配存储空间&#xff0c;并且 可能 会为其指定一个初始化的值&#xff0c; 一个变量的定义有且 仅有一处。 定义实际上是一种特殊…

【网络】HTTPS协议原理

目录 “加密”相关概念 为什么要加密 常见加密方式 对称加密 非对称加密 HTTPS工作过程探究 方案1-只使用对称加密 方案2-只使用非对称加密 方案3-客户端和服务端双方都使用非对称加密 方案4-非对称加密 对称加密 上述方案问题分析 方案5-证书认证 非对称加密对…