本地部署推理TextDiffuser-2:释放语言模型用于文本渲染的力量

系列文章目录

文章目录

  • 系列文章目录
  • 一、模型下载和环境配置
  • 二、模型训练
    • (一)训练布局规划器
    • (二)训练扩散模型
  • 三、模型推理
    • (一)准备训练好的模型checkpoint
    • (二)全参数推理
    • (三)LoRA微调推理
  • 四、遇到的错误
    • (一)importerror,缺少某些库
    • (二)报错:libGL.so.1: cannot open shared object file: No such file or directory
    • (三)各种奇奇怪怪的错误(本质上是diffusers版本不对)
    • (四)各种库的版本不兼容
    • (五)RuntimeError: expected scalar type float Float bu found Half


一、模型下载和环境配置

  1. 将textdiffuser-2模型仓库克隆到本地
git clone https://github.com/microsoft/unilm/
cd unilm/textdiffuser-2
  1. 创建并激活虚拟环境,在textdiffuser-2目录下安装需要的软件包
conda create -n textdiffuser2 python=3.8
conda activate textdiffuser2
pip install -r requirements.txt
  1. 安装与系统版本和cuda版本相匹配的torch、torchvision、xformers (我的环境下cuda是12.2的,其他版本需要自己去官网查询)
    在这里插入图片描述
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple xformers
  1. 如果想用FastChat训练布局规划器,还需要安装flash-attention:

先将flash-attention模型仓库克隆下来

git clone https://github.com/Dao-AILab/flash-attention.git

然后安装对应的软件包

pip install packaging
pip uninstall -y ninja && pip install ninja
conda install -c nvidia cuda
pip install flash-attn --no-build-isolation
  1. 为了训练文本修复任务,还需要安装 differs 包
pip install https://github.com/JingyeChen/diffusers_td2.git

二、模型训练

(一)训练布局规划器

  1. 需要先下载lmsys/vicuna-7b-v1.5模型和FastChat模型。

模型下载方式: 采用git远程clone下来,具体方式可以参考之前的内容:huggingface学习 | 云服务器使用git-lfs下载huggingface上的模型文件;

  1. 进行训练
CUDA_VISIBLE_DEVICES=4,5 torchrun --nproc_per_node=2 --master_port=50008 FastChat-main/fastchat/train/train_mem.py \
    --model_name_or_path vicuna-7b-v1.5  \
    --data_path data/layout_planner_data_5k.json \
    --bf16 True \
    --output_dir experiment_result \
    --num_train_epochs 6 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 16 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 500 \
    --save_total_limit 5 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --lazy_preprocess True

(二)训练扩散模型

  1. 需要先准备需要训练的扩散模型:stable-diffusion-v1-5模型
  2. 对于全参数训练:
accelerate launch train_textdiffuser2_t2i_full.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --train_batch_size=18 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --mixed_precision="fp16" \
    --num_train_epochs=6 \
    --learning_rate=1e-5 \
    --max_grad_norm=1 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --output_dir="diffusion_experiment_result" \
    --enable_xformers_memory_efficient_attention \
    --dataloader_num_workers=8 \
    --index_file_path='/path/to/train_dataset_index.txt' \
    --dataset_path='/path/to/laion-ocr-select/' \
    --granularity=128 \
    --coord_mode="ltrb" \
    --max_length=77 \
    --resume_from_checkpoint="latest"
  1. 对于 LoRA 训练:
accelerate launch train_textdiffuser2_t2i_lora.py \
    --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
    --train_batch_size=18 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing \
    --mixed_precision="fp16" \
    --num_train_epochs=6 \
    --learning_rate=1e-4 \
    --text_encoder_learning_rate=1e-5 \
    --lr_scheduler="constant" \
    --output_dir="diffusion_experiment_result" \
    --enable_xformers_memory_efficient_attention \
    --dataloader_num_workers=8 \
    --index_file_path='/path/to/train_dataset_index.txt' \
    --dataset_path='/path/to/laion-ocr-select/' \
    --granularity=128 \
    --coord_mode="ltrb" \
    --max_length=77 \
    --resume_from_checkpoint="latest"

三、模型推理

(一)准备训练好的模型checkpoint

  1. 下载官网提供的模型checkpoint:layout planner、diffusion model (full parameter fine-tuning) 和diffusion model (lora fine-tuning)

  2. 准备stable-diffusion-v1-5模型

(二)全参数推理

CUDA_VISIBLE_DEVICES=4 accelerate launch inference_textdiffuser2_t2i_full.py \
  --pretrained_model_name_or_path="./stable-diffusion-v1-5" \
  --mixed_precision="fp16" \
  --output_dir="inference_results_1" \
  --enable_xformers_memory_efficient_attention \
  --resume_from_checkpoint="./textdiffuser2-full-ft" \
  --granularity=128 \
  --max_length=77 \
  --coord_mode="ltrb" \
  --cfg=7.5 \
  --sample_steps=20 \
  --seed=43555 \
  --m1_model_path="./textdiffuser2_layout_planner" \
  --input_format='prompt' \
  --input_prompt='a hotdog with mustard and other toppings on it'

推理结果:
在这里插入图片描述

(三)LoRA微调推理

CUDA_VISIBLE_DEVICES=4 accelerate launch inference_textdiffuser2_t2i_lora.py \
  --pretrained_model_name_or_path="./stable-diffusion-v1-5" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --mixed_precision="fp16" \
  --output_dir="inference_results_2" \
  --enable_xformers_memory_efficient_attention \
  --resume_from_checkpoint="./textdiffuser2-lora-ft" \
  --granularity=128 \
  --coord_mode="ltrb" \
  --cfg=7.5 \
  --sample_steps=50 \
  --seed=43555 \
  --m1_model_path="./textdiffuser2_layout_planner" \
  --input_format='prompt' \
  --input_prompt='a stamp of u.s.a'

运行结果:
在这里插入图片描述

四、遇到的错误

(一)importerror,缺少某些库

在运行过程中出现了各种各样的importerror,于是就是缺少哪个库就下载那个库:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python
pip install protobuf

(二)报错:libGL.so.1: cannot open shared object file: No such file or directory

pip uninstall opencv-python
pip install opencv-python-headless

(三)各种奇奇怪怪的错误(本质上是diffusers版本不对)

  • RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
  • The deprecation tuple (‘LoRAXFormersAttnProcessor’, ‘0.26.0’, 'Make sure use XFormersAttnProcessor instead by settingLoRA layers to `self.
pip install diffusers==0.24.0 -i https://pypi.mirrors.ustc.edu.cn/simple/

(四)各种库的版本不兼容

由于作者在官网上提供了实验中使用的软件包列表可供参考,所以我直接将textdiffuser-2的assets文件夹下的refere_requirements.txt文件中的库一次性安装下来:

cd assets
pip install -r reference_requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple/

在这里插入图片描述

(五)RuntimeError: expected scalar type float Float bu found Half

这个错误是因为安装的diffusers包里有个文件需要用官网提供的新文件进行替换
可以先根据错误提示找到diffusers库包中attention_processor.py所在的位置,然后用assets文件夹下attention_processor.py进行替换即可解决问题。

在这里插入图片描述

参考:libGL.so.1: cannot open shared object file: No such file or directory

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/447309.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

PaddlePaddle框架安装

提示:可在python环境中进行安装,避免环境污染,创建命令conda create -n xxx_name python3.9,激活conda activate xxx_name 第一步:查看计算机平台版本 在窗口输入查看命令,查看CUDA的版本 nvidia-smi 二、根据以下条件…

pycharm连接远程服务器解决远程服务器文件不同步问题

一、pycharm连接远程服务器 1.前提条件 linux已经安装好虚拟环境,并且虚拟环境已经有python的相关环境 conda create -n name python3.10 2.连接服务器 1.进入pycharm的Settings 2.在Project里找Python Interpreter 3.点击⚙,再点击add 4.选择SSH…

【C++从练气到飞升】02---初识类与对象

🎈个人主页:库库的里昂 ✨收录专栏:C从练气到飞升 🎉鸟欲高飞先振翅,人求上进先读书。 目录 ⛳️推荐 一、面向过程和面向对象初步认识 二、类的引用 1. C语言版 2. C版 三、类的定义 类的两种定义方式&#xff…

机器学习-04-分类算法-01决策树

总结 本系列是机器学习课程的系列课程,主要介绍机器学习中分类算法,本篇为分类算法开篇与决策树部分。 本门课程的目标 完成一个特定行业的算法应用全过程: 懂业务会选择合适的算法数据处理算法训练算法调优算法融合 算法评估持续调优工程…

如何在Win系统本地部署Jupyter Notbook交互笔记并结合内网穿透实现公网远程使用

文章目录 1.前言2.Jupyter Notebook的安装2.1 Jupyter Notebook下载安装2.2 Jupyter Notebook的配置2.3 Cpolar下载安装 3.Cpolar端口设置3.1 Cpolar云端设置3.2.Cpolar本地设置 4.公网访问测试5.结语 1.前言 在数据分析工作中,使用最多的无疑就是各种函数、图表、…

【前端系列】CSS 常见的选择器

CSS 常见的选择器 CSS(层叠样式表)是一种用于描述网页样式的标记语言,它定义了网页中各个元素的外观和布局。在 CSS 中,选择器是一种用于选择要应用样式的 HTML 元素的模式。选择器允许开发人员根据元素的类型、属性、关系等来选…

【计算机视觉】图像处理算法(其他篇)

来源:《OpenCV3编程入门》,怀念毛星云大佬🕯️ 说明:本系列重点关注各种图像处理算法的原理、作用和对比 漫水填充 漫水填充法是一种用特定的颜色填充连通区域,通过设置可连通像素的上下限以及连通方式来达到不同的填…

【零基础学习02】嵌入式linux驱动中原子操作基本实现

大家好,为了进一步提升大家对实验的认识程度,每个控制实验将加入详细控制思路与流程,欢迎交流学习。 今天给大家分享一下,linux系统里面并发与竞争具体实现,操作硬件为I.MX6ULL开发板。 第一:Linux系统并发与竞争简介 linux是一个多任务操作系统,存在多个任务操作同一个…

【SpringBoot框架篇】36.整合Tess4J搭建提供图片文字识别的Web服务

文章目录 简介文件下载引入依赖main函数中使用基于Springboot搭建OCR Web服务配置traineddata路径枚举用到的语种类型定义接口响应的json数据格式封装OCR服务引擎编写web提供服务的接口启动服务并且测试html demo扩展 项目配套代码 简介 Tess4J是一个基于Tesseract OCR引擎的J…

mysql的索引、事务、分库分表问题

1.了解MySQL的索引吗?它为什么使用Btree作为底层,而不是其他呢? 这里我们要谈的是其他数据结构的缺点,然后说说Btree的优点,也就看你对MySQL的Btree与其他数据结构熟不熟悉。 Hash (1)Hash 索引…

第五十五回 吴用使时迁偷甲 汤隆赚徐宁上山-以102flowers数据集为例训练ResNet50模型

汤隆说我家世代打造兵器,破连环马需要用钩镰枪,我会打造。还需要我的姑表哥,金枪手徐宁,因为他会钩镰枪法。汤隆和吴用商议了请徐宁上山的法子,派石迁去偷徐宁的宝贝雁翎锁子甲。 石迁到了东京,打听到徐宁住…

R语言:多值提取到点

ArcGIS中有相关工具实现多值提取到点的功能&#xff0c;在这里&#xff0c;我将使用R语言进行操作&#xff1a; library(dplyr) library(readxl) library(sf) library(raster)setwd("D:/Datasets") Bio <- stack(paste0("D:/Datasets/Data/worldclim2_1km/…

HTML静态网页成品作业(HTML+CSS)——家乡漳州介绍设计制作(1个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有1个页面。 二、作品演示 三、代…

爬虫入门到精通_框架篇16(Scrapy框架基本使用_名人名言的抓取

1 目标站点分析 抓取网站&#xff1a;http://quotes.toscrape.com/ 主要显示了一些名人名言&#xff0c;以及作者、标签等等信息&#xff1a; 点击next&#xff0c;page变为2&#xff1a; 2 流程框架 抓取第一页&#xff1a;请求第一页的URL并得到源代码&#xff0c;进行下…

学校Java的第七天

目录 一、什么是数组 二、作用 三、如何使用数组 1、声明数组变量 2、创建数组 示例&#xff1a; 3、数组的使用 示例&#xff1a; 4、数组的遍历 for循环示例&#xff08;不知道for循环的可以查看我之前发的文章&#xff09; for-each循环&#xff08;也就是增强for…

使用JDBC操作数据库

意志、工作和等待是成功的金字塔的基石。 Will, work and wait are the pyramidal cornerstones for success. 文章目录 JDBC简介&#xff1a;JDBC访问数据库步骤StatementPreparedStatement JDBC简介&#xff1a; 在Java应用程序中&#xff0c;JDBC&#xff08;Java Database…

【实战项目】网络编程:在Linux环境下基于opencv和socket的人脸识别系统--C++实现

&#x1f31e;前言 这里我们会实现一个项目&#xff1a;在linux操作系统下基于OpenCV和Socket的人脸识别系统。 目录 &#x1f31e;前言 &#x1f31e;一、项目介绍 &#x1f31e;二、项目分工 &#x1f31e;三、项目难题 &#x1f31e;四、实现细节 &#x1f33c;4.1 关…

腾讯云服务器99元一年厉害了,老用户可以买,续费也是99元

良心腾讯云推出99元一年服务器&#xff0c;新用户和老用户均可以购买&#xff0c;续费不涨价&#xff0c;续费也是99元&#xff0c;配置为轻量2核2G4M、50GB SSD盘、300GB月流量、4M带宽&#xff1a;优惠价格99元一年&#xff0c;续费99元&#xff0c;官方活动页面 txybk.com/g…

java继承,接口,抽象类

目录 目录 1 继承的含义 2 继承的好处 3使类与类之间产生了关系。 看这里继承-------我的理解 代码部分 接口 代码 抽象类 代码 各位友友们大家好呀&#x1f60a;&#xff01; 今天让我们继续回顾java&#xff0c;看看java中的抽象类以及接口继承是什么&#x1f914…

LeetCode Python - 46.全排列

目录 题目答案运行结果 题目 给定一个不含重复数字的数组 nums &#xff0c;返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1&#xff1a; 输入&#xff1a;nums [1,2,3] 输出&#xff1a;[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]] 示例 2&#x…