Multi-Concept Customization of Text-to-Image Diffusion——【代码复现】

本文是发表于CVPR 2023上的一篇论文:[2212.04488] Multi-Concept Customization of Text-to-Image Diffusion (arxiv.org)

一、引言

本文主要做的工作是对stable-diffusion的预训练模型进行微调,需要的显存相对较多,论文中测试时是在两块GPU上微调,需要30GB的显存,不过他调的batchsize=8,因为我自己的算力有限,我把复现的时候把batchsize调成了2,然后在两块3090上跑的,至于最低要求多少还没测试,不过个人认为最低也要有一张3090。

在复现前,请自行安装好Python的环境,本文就不叙述了哈哈。

二、下载相关文件及搭建环境

1.下载项目及环境搭建

adobe-research/custom-diffusion: Custom Diffusion: Multi-Concept Customization of Text-to-Image Diffusion (CVPR 2023) (github.com)

上述链接是本文代码的链接,这篇文章的代码实际上是基于Stable-diffusion构建的,所以我的建议是可以先去复现一下stable-diffusion的代码,再来学习这篇文章以及代码。stable-diffusion的复现可以看我另外一篇文章:stable-diffusion复现笔记,当然如果你想直接上手,可以按照项目中readme来构建,这里我默认已经有装过stable-diffusion了哈,因为很多文件都是相同的,如果你是直接上手,有些文件比如sd-v1-4.ckpt的下载等问题,都可以去看我这篇stable-diffusion复现笔记。

git clone https://github.com/adobe-research/custom-diffusion.git
cd custom-diffusion
git clone https://github.com/CompVis/stable-diffusion.git
cd stable-diffusion
conda env create -f environment.yaml
conda activate ldm
pip install clip-retrieval tqdm 

上述是论文给出的环境搭建代码,如果你跟我一样已经做过stable-diffusion的安装,可以直接执行最后一行 pip install clip-retrieval tqdm 。

2.下载数据集

复现的时候我用的是官方给的数据集,下载地址:https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip

三、运行

复现的过程我主要采用以生成的图像作为正则化来实现,方便起见,主要还是按照官方给的示例来复现。

1.单一概念微调——生成的图像作为正则化

第一步:这里我们可以直接执行命令文件,<pretrained-model-path>是预训练模型的路径,如:/data/disk1/sxtang/models/sd-v1-4.ckpt

bash scripts/finetune_gen.sh "cat" data/cat gen_reg/samples_cat  cat finetune_addtoken.yaml <pretrained-model-path>

这个sh文件会执行两个脚本文件:sample.py、train.py。

先执行sample.py生成用于正则化的图像,一共是200张,然后再执行train.py文件对预训练的模型进行微调,如果一切顺利,命令行最后的输入应该如下:

生成的正则化图像的目录:

 

微调所得模型目录:

 

复现过程中我所遇到的问题:

(1).我是在RTX3090上进行采样生成图片的,但是如果按照代码中默认的参数去执行,我的显存是不够的(论文毕竟是在两块A100做的),然后我的解决方法是把参数调了一下,改成:

--n_samples 5  --n_iter 40 

这里主要还是根据自己的情况去调整,如果还是爆显存的话,可以把数值都调小点,然后多执行几次sample脚本也是可以的。

(2).之前也说了,代码默认的batchsize=4,我跑不了哈哈,所以调整一下batchsize的大小。

具体的,在configs/custom-diffusion/finetune_addtoken.yaml文件中更改:

(3).TypeError: CUDACallback.on_train_epoch_end() missing 1 required positional argument: 'outputs'问题。

这里主要是pytorch-lighting的版本问题,需要把这个outputs参数删掉,具体的,在train.py文件下的on_train_epoch_end函数中:

 

(4).pytorch_lightning.utilities.exceptions.MisconfigurationException: No `test_dataloader()` method defined to run `Trainer.test`.


 这里说什么没定义这个方法,解决的方法就是在运行的时候直接加上参数--no-test即可。

第二步:更新权重

执行下面的命令即可实现,这里<folder-name> 就是你微调后的那个模型的文件夹,比如:2024-01-13T14-11-49_cat-sdv4,这一步我在执行过程中没有遇到什么问题。

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 1

第三步:运行

## sample
python sample.py --prompt "<new1> cat playing with a ball" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

这个new1就是个占位符,无需更改;<folder-name>和上述的含义一样,最后这个“000004.ckpt”是你想要用的权重文件名称。 最后--ckpt <pretrained-model-path> 就是预训练的模型路径。

如果一切顺利的话,就会出图啦!

图片存放的位置以及我生成的图片如下:

 

2.多概念微调——生成的图像作为正则化

官方的readme中只给出了基于真实图像的代码,所以自己实现了一下生成图像正则化。

第一步:生成正则化图像。

上面我们已经生成的cat的正则化图像,这里还需要wooden_pot的正则化图像,所以我们需要先采样生成图像,我这里用的命令如下:

python -u sample.py \
        --n_samples 5 \
        --n_iter 40 \
        --scale 6 \
        --ddim_steps 50  \
        --ckpt  /data/disk1/sxtang/models/sd-v1-4.ckpt  \  #预训练模型的路径
        --ddim_eta 1. \
        --outdir "gen_reg/samples_wooden_pot" \   # 输出图像的路径
        --prompt "photo of a wooden_pot" 

 第二步:微调,这里我稍微改了一下那个项目中给出的基于真实图像实现的.sh文件

#!/usr/bin/env bash
#### command to run with retrieved images as regularization
# 1st arg: target caption1
# 2nd arg: path to target images1
# 3rd arg: path where retrieved images1 are saved
# 4rth arg: target caption2
# 5th arg: path to target images2
# 6th arg: path where retrieved images2 are saved
# 7th arg: name of the experiment
# 8th arg: config name
# 9th arg: pretrained model path

ARRAY=()

for i in "$@"
do
    echo $i
    ARRAY+=("${i}")
done


python -u  train.py \
        --base configs/custom-diffusion/${ARRAY[7]}  \
        -t --gpus 6,7 \
        --resume-from-checkpoint-custom  ${ARRAY[8]} \
        --caption "<new1> ${ARRAY[0]}" \
        --datapath ${ARRAY[1]} \
        --reg_datapath "${ARRAY[2]}/samples" \
        --reg_caption "${ARRAY[0]}" \
        --caption2 "<new2> ${ARRAY[3]}" \
        --datapath2 ${ARRAY[4]} \
        --reg_datapath2 "${ARRAY[5]}/samples" \
        --reg_caption2 "${ARRAY[3]}" \
        --modifier_token "<new1>+<new2>" \
        --name "${ARRAY[6]}-sdv4"

 执行命令:

bash scripts/finetune_joint_gen.sh "wooden pot" data/wooden_pot gen_reg/samples_wooden_pot \
                                    "cat" data/cat gen_reg/samples_cat  \
                                    wooden_pot+cat finetune_joint.yaml /data/disk1/sxtang/models/sd-v1-4.ckpt

注:如果需要调整如batchsize等参数,这里是在finetune_joint.yaml文件中更改。

如果一切顺利,出现如下界面,就代表着微调成功啦:

后面两步和单个概念那边一样,这里不过多叙述。

第二步:更新权重

## save updated model weights
python src/get_deltas.py --path logs/<folder-name> --newtoken 2

 第三步:运行

## sample
python sample.py --prompt "the <new2> cat sculpture in the style of a <new1> wooden pot" --delta_ckpt logs/<folder-name>/checkpoints/delta_epoch\=000004.ckpt --ckpt <pretrained-model-path>

下面是我测试所生成的图像:

四、最后

这篇文章和Dreambooth等有着异曲同工之妙,都是为了实现个性化的图像生成,当然论文中还有比如通过diffusers实现等功能,如果感兴趣可以自己去试试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/317248.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

npmvue详解

1、npm是node.js的一个软件依赖包管理工具 2、当前目录下面一般会有一个package.json文件 3、npm install 会去按照package.json文件中的依赖去下载依赖包 默认会下载到当前目录中的node_modules文件夹下&#xff0c;-g会进行全局安装 4、package.json文件中有两种依赖关系 …

Mybatis实现映射,一次查询和嵌套查询

1.实现映射 Mybatis的最大魅力就在于它的语句映射。实现映射一般有一下三种方法&#xff1a; 当我们在数据库的列名和java中的属性名完全相同时&#xff0c;mybatis会自动映射并将查询结果封装。 对于由多个单词组成的名字时&#xff08;例如studentgender&#xff09;&…

MATLAB | 龙年大吉,使用MATLAB绘制会动的中国风神龙

hey各位好久不见&#xff0c;龙年到了&#xff0c;这期画一期配色非常中国风的龙&#xff0c;这个造型的龙参考了某些html绘制龙的视频&#xff0c;但是由于html版全网都是也不咋给代码和代码出处&#xff0c;因此自己写了个MATLAB版本&#xff1a; 可以看到还是非常酷炫的&…

Linux 内核学习 2 - 用户程序如何被塞进内核进行调度?

Shell是系统的用户界面&#xff0c;提供了用户与内核进行交互操作的一种接口。它接收用户输入的命令并把它送入内核去执行。 fork里copy了父进程的信息&#xff0c;并激活task放到运行队列&#xff0c;当系统发生调度并获得执行机会时开始执行&#xff0c;但这时还不是hello程序…

【RHEL】Vivado调用VCS+Verdi联合仿真报错解决

问题描述 在使用VCS Verdi仿真Vivado工程时&#xff0c;点击行为仿真按钮进度条窗口消失后&#xff0c;Verdi窗口并未出现&#xff0c;查看消息报错如下&#xff1a; vcs: line 34205: 119837 Segmentation fault (core dumped) ${TOOL_HOME}/bin/cfs_ident_exec -f ${X…

vulnhub靶场之DC-7

一.环境搭建 1.靶场描述 DC-7 is another purposely built vulnerable lab with the intent of gaining experience in the world of penetration testing. While this isnt an overly technical challenge, it isnt exactly easy. While its kind of a logical progression …

7个向量数据库对比:Milvus、Pinecone、Vespa、Weaviate、Vald、GSI 和 Qdrant

本文简要总结了当今市场上正在积极开发的7个向量数据库&#xff0c;Milvus、Pinecone、Vespa、Weaviate、Vald、GSI 和 Qdrant 的详细比较。 我们已经接近在搜索引擎体验的基础层面上涉及机器学习&#xff1a;在多维多模态空间中编码对象。这与传统的关键字查找不同&#xff08…

Android Studio个性化修改

Android Studio原始界面看着也太无趣了叭&#xff0c;话不多说跟步骤走就可以。 1.更改Android Studio主题及背景 1.背景修改 File->Settings->Plugins&#xff0c;搜索Sexy Editor 重启后&#xff0c;左侧边栏出现Other Settings选项&#xff0c;点击SexyEditor进行背…

K8S后渗透横向节点与持久化隐蔽方式探索

前言 通常在红蓝对抗中&#xff0c;我们可能会通过各种方法如弱口令、sql注入、web应用漏洞导致的RCE等方法获得服务器的权限&#xff1b;在当前云原生迅猛发展的时代&#xff0c;这台服务器很可能是一个容器&#xff0c;在后续的后渗透由传统的提权变为容器逃逸&#xff0c;内…

在程序中链接静态库 和 动态库

9. 链接库 在编写程序的过程中&#xff0c;可能会用到一些系统提供的动态库或者自己制作出的动态库 或者静态库文件&#xff0c;cmake中也为我们提供了相关的加载动态库的命令hehedalinux:~/Linux/loveDBTeacher-v3$ tree . ├── CMakeLists.txt ├── include │ └── …

Java合并两个有序链表

思路&#xff1a; 创建一个临时的节点&#xff0c;命名傀儡节点&#xff0c;可以理解成临时的头节点&#xff0c;newHead&#xff0c;list1和list2的两两元素比较&#xff0c;小的连接newHead&#xff08;升序&#xff09;newHead的路径&#xff08;蓝色&#xff09;就是连接后…

MySQL 基于 GTID 主从复制

GTID 定义 GTID 是 MySQL 事务标识&#xff0c;为每一个提交的事务都生成一个标识&#xff0c;并且是全局唯一的&#xff0c;这个特性是从 MySQL5.6 引进的。 组成 GTID 是由 UUID TID&#xff0c;UUID 是MySQL的唯一标识&#xff0c;每个MySQL实例之间都是不同的。TID是代表…

Servlet-执行流程生命周期

一、思考 在上一篇文章Servlet基本概念中&#xff0c;我们抛出了一个问题&#xff1a;我们定义一个类实现了Servlet接口后&#xff0c;是谁创建了这个类的对象呢&#xff0c;又是谁调用了类中的service方法呢&#xff1f;本篇我们将介绍Servlet的执行流程。 二、执行流程 根…

Dreamweaver CS 操作

服务器 在Windows 10中添加IIS 可以将自己的电脑设置为服务器&#xff0c;在Windows 10中添加IIS的步骤如下&#xff1a; 在开始按钮上点击右键&#xff0c;选择“控制面板”。从控制面板选择“程序”。然后选择“启用或关闭Windows功能”。在弹出的对话框中&#xff0c;找到…

【华为OD机试真题2023CD卷 JAVAJS】查找一个有向网络的头节点和尾节点

华为OD2023(C&D卷)机试题库全覆盖,刷题指南点这里 查找一个有向网络的头节点和尾节点 知识点图DFS搜索 时间限制:1s 空间限制:256MB 限定语言:不限 题目描述: 给定一个有向图,图中可能包含有环,图使用二维矩阵表示,每一行的第一列表示起始节点,第二列表示终止节…

Error: Failed to download template from registry: fetch failed

第一次构建Nuxt项目时&#xff0c;出现在这样的错误&#xff01;&#xff01;&#xff01; 如果你也是这样得错误&#xff0c;修改hosts也没用。我试了 是因为你的npm安装了其他镜像源&#xff0c; 这个时候你就需要手动下载了&#xff1a; web端访问&#xff1a; https://ra…

x-cmd pkg | csview - 美观且高性能的 csv 数据查看工具

目录 介绍首次用户功能特点类似工具与竞品进一步阅读 介绍 csview 是一个用于在命令行中查看 CSV 文件的工具&#xff0c;采用 Rust 语言编写的&#xff0c;支持中日韩/表情符号。它允许用户在终端中以表格形式查看 CSV 数据&#xff0c;可以对数据进行排序、过滤、搜索等操作…

牛仔行头目标检测竞赛 李沐老师Kaggle(有代码有注释)

一、比赛地址 CowBoy Outfits Detection | Kaggle 二、数据介绍 这次比赛的数据集一共有4879个image&#xff0c;需要自己划分训练集和验证集。annotation的格式是json&#xff0c;因此如果用yolo系列的算法会涉及到json2txt这一步。当然如果不用yolo系列的就不要这段代码了。…

PyTorch Tutorial

本文作为博客“Transformer - Attention is all you need 论文阅读”的补充内容&#xff0c;阅读的内容来自于 https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#recommended-preparation 建议的准备流程。 Deep Learning with PyTorch: …

java实现局域网内视频投屏播放(五)视频搜索和投屏自动切换下一个

代码链接 这次对ui做了一些调整&#xff0c;整体分成了5个类别分别为 搜索设备播放任务已下载视频列表视频搜索下载任务列表 视频搜索 搜索 点击搜索后&#xff0c;会从执行所有VideoResolver实现类的search方法&#xff0c;将搜索到的结果汇总到一起&#xff0c;根据视频的…