DALL·E 2 文生图模型实践指南

前言:本篇博客记录使用dalle2模型进行推断时借鉴的相关资料和DEBUG流程。

相关博客:超详细!DALL · E 文生图模型实践指南


在这里插入图片描述

目录

  • 1. 环境搭建和预训练模型准备
    • 环境搭建
    • 预训练模型下载
  • 2. 代码
  • 3. BUG&DEBUG
    • URLError
    • CUDA error
    • RuntimeError
    • PydanticUserError


1. 环境搭建和预训练模型准备

本文使用的代码仓库为:https://github.com/lucidrains/DALLE2-pytorch

环境搭建

pip install dalle2-pytorch

预训练模型下载

地址:https://huggingface.co/laion/DALLE2-PyTorch

2. 代码

DALLE2 for inference 完整推断流程如下(from @cest_andre in Issues#282):

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig


prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.
).cpu()

print(images.shape)

for img in images:
    img = ToPILImage()(img)
    img.show()

3. BUG&DEBUG

URLError

报错信息如下:

Traceback (most recent call last):
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1301, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1250, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1010, in _send_output
    self.send(msg)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 950, in send
    self.connect()
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1424, in connect
    self.sock = self._context.wrap_socket(self.sock,
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 11, in <module>
    prior = prior_config.create().cuda()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 185, in create
    clip = self.clip.create()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 122, in create
    return OpenAIClipAdapter(self.model)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 313, in __init__
    openai_clip, preprocess = clip.load(name)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 122, in load
    model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 59, in _download
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 525, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 542, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1393, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1353, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>

我使用的是https://github.com/lucidrains/DALLE2-pytorch这个网址。

找到 /root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py 中对应的位置,我这里是第1349行,修改方式也在下面代码中一并给出。

try:
    h.request(req.get_method(), req.selector, req.data, headers,
              encode_chunked=req.has_header('Transfer-encoding'))
    time.sleep(0.5)  # 添加的一行
except OSError as err: # timeout error
    raise URLError(err)

CUDA error

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

解决方案:版本不匹配,更换与系统cuda相匹配的pytorch版本。比如我的cuda版本是12.0,可以使用如下命令安装pytorch:

pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html

RuntimeError

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 14, in <module>
    prior.load_state_dict(prior_model_state, strict=True)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DiffusionPrior:
        Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". 
        Unexpected key(s) in state_dict: "net.null_text_embed". 

解决方案1️⃣:load_state_dict()函数中的 strict=True 改为 strict=False,如下:

...
prior.load_state_dict(prior_model_state, strict=False)

decoder.load_state_dict(decoder_model_state, strict=False)
...

但这种方法可能会导致模型的性能下降,而且会生成马赛克图像,这显然不是我们想要的结果。

在这里插入图片描述

解决方案2️⃣:参考Issues中cest-andre的答案。

步骤(1)降低dalle2_pytorch版本至1.1.0

pip install dalle2-pytorch==1.1.0

步骤(2):版本降低后,要修复dalle2_pytorch.py文件中一个小bug:将第2940行改为如下代码:

images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

PydanticUserError

降低dalle2_pytorch版本后,运行程序报错如下:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 8, in <module>
    from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 34, in <module>
    class TrainSplitConfig(BaseModel):
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 40, in TrainSplitConfig
    def validate_all(cls, fields):
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 222, in root_validator
    return root_validator()(*__args)  # type: ignore
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/pydantic/deprecated/class_validators.py", line 228, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`. Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.

解决方案:参考Issues中JasbirCodeSpace的答案,降低Pydantic的版本:

pip install pydantic==1.10.6

到这里,模型就可以完成推断过程啦~嘻嘻!以下是prompt为a red car时生成的图像:

在这里插入图片描述

后记:感谢前人铺路!🌹


参考链接

  1. https://github.com/lucidrains/DALLE2-pytorch/issues/282
  2. python requests请求报错ConnectionError: (‘Connection aborted.‘, error(104, ‘Connection reset by peer‘))_铁朵斯提的博客-CSDN博客
  3. GPU版本pytorch(Cuda12.1)清华源快速安装一步一步教!小白教学~_清华源安装torch-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/143365.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

算法笔记-第五章-分数的四则运算

分数的四则运算 分数约分分数加法分数减法分数乘法分数除法分数的输出 分数约分 #include <cstdio> #include <algorithm> using namespace std; struct Fraction {//用结构体表示分子和分母int up, down; }; int gcd(int a, int b) {//求出最大公约数if (b 0) {r…

编程实例与解释,编程入门自学指南,中文编程工具下载

编程实例与解释&#xff0c;编程入门自学指南&#xff0c;中文编程工具下载。 编程实例可以点击最下方卡片了解&#xff0c;编程入门自学可以先从容易的入手&#xff0c;比如中文编程&#xff0c;先学习编程思路&#xff0c;然后再学习其他语言会比较轻松。中文编程工具可以下…

浅谈JavaScript闭包,小白的JS学习之路!

前言 在JavaScript中&#xff0c;闭包是一种强大而灵活的特性&#xff0c;它不仅允许变量私有化&#xff0c;而且提供了一种在函数执行完毕后仍然保持对外部作用域变量引用的机制。本文将深入讨论JavaScript闭包的概念、优点、缺点以及如何避免潜在的内存泄漏问题。 调用栈与…

Shiro快速入门之二

一、前言 Shiro快速入门之一 介绍了Shiro三大核心组件&#xff0c;四大核心功能&#xff0c;以及一个简单的Test Demo&#xff0c;接下来两篇我会用一个比较完整的例子来讲述Shiro的认证及授权是怎么做的&#xff0c;本篇侧重于介绍认证的过程 二、Shiro认证例子 1、例子概述…

H5三网魔幻手游【众神传奇】win服务端+GM授权后台+架设教程

搭建资源下载地址&#xff1a;H5三网魔幻手游【众神传奇】win服务端GM授权后台架设教程-海盗空间

锂价疲软,市场需求持续低迷,赣锋锂业在短期内将继续面临痛苦

来源&#xff1a;猛兽财经 作者&#xff1a;猛兽财经 公司介绍 在赣锋锂业(01772)(002460)网站上&#xff0c;赣锋锂业称自己是一家拥有“中国第一”和“世界最大锂金属生产商”的“锂化合物产能”公司。 根据其2022财年年度报告&#xff0c;该公司最近一个财年的收入有84%和1…

日历应用程序 BusyCal mac中文版软件特点

BusyCal mac是一款日历应用程序&#xff0c;它可以帮助用户轻松地管理日程安排、事件提醒、会议安排等。BusyCal 支持 macOS 和 iOS 平台&#xff0c;并且可以与 iCloud、Google 日历、Exchange 等多种日历服务进行同步。 BusyCal mac软件特点 强大的日历功能&#xff1a;Busy…

Rust编程中的共享状态并发执行

1.共享状态并发 虽然消息传递是一个很好的处理并发的方式&#xff0c;但并不是唯一一个。另一种方式是让多个线程拥有相同的共享数据。在学习Go语言编程过程中大家应该听到过一句口号:"不要通过共享内存来通讯"。 在某种程度上&#xff0c;任何编程语言中的信道都类…

活动通知邀请函H5页面制作源码系统+动感的背景音乐 自定义你想要的页面 源码完全开源可二开 带完整搭建教程

在现代社交活动中&#xff0c;一份精美、个性化的活动邀请函不仅能够展现主办方的品味和诚意&#xff0c;还可以吸引更多的参与者。然而&#xff0c;制作一份精美的活动邀请函需要专业的设计和技术支持&#xff0c;这对于很多非专业人士来说是一个难题。此外&#xff0c;传统的…

mysql之主从复制和读写分离

一、主从复制 1、定义 主mysql上的数据&#xff08;新增或修改库、表里的数据&#xff09;都会同步到从mysql上 2、mysql的主从复制模式&#xff08;面试题&#xff09; &#xff08;1&#xff09;异步复制&#xff08;常用&#xff09;&#xff1a;默认的复制模式。客户端…

网络运维Day15

文章目录 Prometheus简介环境准备配置模板机环境部署阿里镜像源实验环境准备 部署prometheus服务查看及测试 Promethues 被监控端Grafana简介部署 Grafana 服务器修改数据源 监控数据库安装部署Mariadb安装导出器修改 Prometheus服务端配置Grafana配置 总结 Prometheus简介 Pr…

需要买哪些网络设备才能过等保?求解!

随着等保2.0的落地执行&#xff0c;越来越多的企业需要过等保。但不少企业都是第一次过等保&#xff0c;对于等保政策不是很了解&#xff0c;有小伙伴问&#xff0c;需要买哪些网络设备才能过等保&#xff1f;这里我们小编就给大家来简单回答一下&#xff0c;仅供参考哈&#x…

Q learning

Q learning Q Learning是强化学习算法中的一个经典算法。在一个决策过程中&#xff0c;我们不知道完整的计算模型&#xff0c;所以需要我们去不停的尝试。 算法流程 整体流程如下&#xff1a; Q-table 初始化 第一步是创建 Q-table&#xff0c;作为跟踪每个状态下的每个动作…

电商平台api接口对接电商数据平台,获取商品详情页面实时信息须知

随着互联网的发展和普及&#xff0c;电商平台已成为人们日常生活中不可或缺的一部分。而为了保证电商平台的正常运行&#xff0c;平台与开发者之间需要进行数据交互&#xff0c;这便涉及到了电商平台API接口对接的问题。本文将详细介绍电商平台API接口对接的须知事项。 一、了解…

双算法SSL证书

国际算法的优势与挑战 1. RSA算法 RSA算法是一种基于大素数分解的非对称加密算法&#xff0c;长期以来一直是SSL证书的主流选择之一。然而&#xff0c;随着计算能力的提高&#xff0c;RSA算法的密钥长度需要不断增加&#xff0c;以维持足够的安全性。 2. ECC算法 椭圆曲线密…

迅软DSE答疑专业解析:内网遭受攻击的威胁到底有多大

当今数字化时代&#xff0c;企业数据安全已演变为企业生存和发展的至关重要因素。随着信息技术的迅猛发展&#xff0c;企业内网不仅是承载核心数据和信息的关键平台&#xff0c;也成为黑客和恶意软件攻击的主要目标。因此&#xff0c;确保企业数据安全和内网安全已成为企业管理…

ADC内部电源监控

文章目录 前言一、分析芯片手册1、43.5 ADC internal supply monitoring2、11.1.1.1 Chip Control register (CHIPCTL)1&#xff09;SIM_CHIPCTL[ADC_SUPPLY]2&#xff09;SIM_CHIPCTL[ADC_SUPPLYEN] 3、44.4.2 ADC Status and Control Register 1 (SC1A - aSC1P) 二、EB配置1、…

@CacheInvalidate(name = “xxx“, key = “#results.![a+b]“,multi = true)是什么意思

@CacheInvalidate 注解是 JetCache 框架提供的注解,它是由阿里巴巴开源的组织 Alibaba Group 开发和维护的。JetCache 是一款基于注解的缓存框架,提供了丰富的缓存功能和灵活的配置选项,可用于增强应用程序的性能和可扩展性。JetCache 支持多种缓存后端,包括内存缓存、Redi…

Ubuntu22.04源码安装ROS-noetic(ROS1非ROS2),编译运行VINS-MONO

1. Ubuntu22.04源码编译安装ROS-noetic 由于22.04默认安装ROS2&#xff0c;但很多仓库都是基于ROS1的&#xff0c;不想重装系统&#xff0c;参考这两个博客安装了ROS-noetic&#xff1a; 博客1. https://blog.csdn.net/Drknown/article/details/128701624博客2. https://zhua…

(七)Spring源码解析:Spring事务

对于事务来说&#xff0c;是我们平时在基于业务逻辑编码过程中不可或缺的一部分&#xff0c;它对于保证业务及数据逻辑原子性立下了汗马功劳。那么&#xff0c;我们基于Spring的声明式事务&#xff0c;可以方便我们对事务逻辑代码进行编写&#xff0c;那么在开篇的第一部分&…