最新的强大的文生视频模型Pyramid Flow 论文阅读及复现

《PYRAMIDAL FLOW MATCHING FOR EFFICIENT VIDEO GENERATIVE MODELING》

论文地址:2410.05954icon-default.png?t=O83Ahttps://arxiv.org/pdf/2410.05954

项目地址:

jy0205/Pyramid-Flow: 用于高效视频生成建模的金字塔流匹配代码icon-default.png?t=O83Ahttps://github.com/jy0205/Pyramid-Flow

论文提出了一种新的视频生成模型,通过金字塔流匹配算法(Pyramidal Flow Matching),有效降低了视频生成的计算复杂度。该方法通过在不同分辨率的金字塔阶段之间进行流匹配,实现了从噪声到数据的生成过程,并通过单一的Diffusion Transformer(DiT)进行端到端优化。

摘要详述

论文提出了一种高效的视频生成建模方法,称为金字塔流匹配,旨在通过降低计算复杂度来优化视频生成过程。该方法避免了直接在全分辨率下进行训练,而是将视频生成过程分解为多个在不同分辨率下运行的金字塔阶段,仅在最终阶段达到全分辨率。这种方法的主要优势包括:

  1. 连续性:不同金字塔阶段的生成轨迹相互链接,后续阶段继续从前一阶段生成,避免了每个阶段从纯噪声重新生成的需要。

  2. 统一模型:与为每个图像金字塔使用独立模型不同,金字塔流匹配算法将它们集成到一个统一的模型中,通过端到端优化实现更优雅的实现,并大幅加快训练速度。

在全分辨率下,在非常嘈杂的潜在值上花费大量计算。(b) 我们的方法利用流动匹配的 f 灵活性在不同分辨率的潜在变量之间进行插值。这允许以更好的计算效率同时生成和解压缩视觉内容。请注意,黑色箭头表示降噪轨迹,蓝色箭头表示时间条件。

方法详述

空间金字塔图示。(a) 金字塔流分为多个阶段,每个阶段都从像素化和嘈杂的起点到无像素化和更清晰的结果。(b) 在推理过程中,我们在跨阶段的跳跃点添加校正噪声,以确保概率的连续性

1. 金字塔流匹配 (Pyramidal Flow Matching)

论文提出了一个新颖的视频生成框架,称为金字塔流匹配,它通过将视频生成轨迹重新解释为不同尺度的压缩表示的金字塔阶段来解决视频生成中的高时空复杂性问题。具体来说,该方法只在最终阶段以全分辨率运行,而在早期阶段则在更低分辨率下运行,从而减少冗余计算。

  • 流的构建:在金字塔流中,每个阶段都从带有噪声的像素化(压缩)潜在表示开始,到无像素化(解压缩)且更清晰的潜在表示结束。通过这种方式,只有最后一个阶段在全分辨率下执行,而大多数阶段在更低分辨率下执行,减少了计算量。

  • 统一训练:为了统一不同阶段的建模,论文通过在不同噪声水平和分辨率之间进行插值来构建概率路径。这允许从低分辨率的噪声潜在表示生成更清晰、细节更丰富的高分辨率结果。

  • 推理中的重噪声:在推理过程中,需要在不同分辨率的金字塔阶段之间的跳跃点仔细处理,以确保概率路径的连续性。为此,论文提出了一种添加校正高斯噪声的方法,以匹配不同阶段之间的分布。

2. 空间金字塔 (Spatial Pyramid)

  • 流的分段:空间金字塔流被分为多个阶段,每个阶段从像素化且带噪声的起点到无像素化且更清晰的结果。每个阶段的流遵循类似的公式,插值在像素化(压缩)和更带噪声的潜在表示与无像素化(解压缩)且更清晰的潜在表示之间。

  • 训练和推理:在训练阶段,通过插值不同分辨率的潜在表示来构建金字塔流。在推理阶段,每个阶段的输出通过添加校正高斯噪声重新噪声化,以维持连续性。

3. 时间金字塔 (Temporal Pyramid)

  • 视频的时间复杂性:视频因其时间长度而呈现显著的挑战。现有的全序列扩散方法同时生成所有视频帧,限制了固定长度的生成。与之相对,自回归视频生成范式支持在推理期间灵活长度的生成。

  • 压缩历史条件:考虑到全分辨率历史条件中的高冗余,论文提出使用压缩的、低分辨率的历史进行自回归视频生成。这显著减少了视频生成预训练的计算和内存开销。

图 3:时间金字塔图示。(a) 在每个金字塔阶段,生成都以压缩的、低分辨率的历史记录为条件,以提高自回归模型的训练效率,如行所示。(b) 设计了一种兼容的位置编码方案,该方案在空间金字塔中外推,但在时间金字塔中插值,以允许条件的空间对齐

实验

 

 复现

复现了两种模式,一种为web ui 一种为推理,

1、下载及环境安装

git clone https://github.com/jy0205/Pyramid-Flow
cd Pyramid-Flow

# create env using conda
conda create -n pyramid python==3.8.10
conda activate pyramid
pip install -r requirements.txt

其实环境不一定一模一样,我用的之前的环境,但是diffusion 和transformer最好和requirement一样,如果出现找不到pyramid模块之类的报错,检查版本。

2、下载权重

新建一个py文件

from huggingface_hub import snapshot_download

model_path = 'PATH'   # The local directory to save downloaded checkpoint
snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')

修改model

3、web UI

调整app 内 model_path 变量,为上一步的model_path。注意目前使用的pyramid_flux而不是pyramid_mmdit

python app.py

根据弹出的页面,简单移动,大家可以自行尝试调参。

4、本地推理

新建test.py ,复制下面代码,修改model_path,本地gpu 内存24g,如果内存小,跑再下面那个代码,卸载到cpu 的版本.两个都是384p版本,差不多本地要两分钟左右生成5s。A800需要50s左右。3090生成768p极慢,不推荐。

import torch
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import load_image, export_to_video

torch.cuda.set_device(0)


model_dtype, torch_dtype = 'bf16', torch.bfloat16   # Use bf16 (not support fp16 yet)

model_path = ***


model = PyramidDiTForVideoGeneration(
    model_path,                # Pass the base model path
    model_name="pyramid_flux"  ,     # set to pyramid_flux or pyramid_mmdit
    model_dtype=model_dtype,  # Use bf16
    model_variant='diffusion_transformer_384p',  # Pass the variant directory name
    cpu_offloading=True,  # Pass the CPU offloading flag
)

model.vae.enable_tiling()
model.vae.to("cuda")
model.dit.to("cuda")
model.text_encoder.to("cuda")
from tqdm import tqdm
# if you're not using sequential offloading bellow uncomment the lines above ^
# model.enable_sequential_cpu_offload()
import json
prompts = []
with open(r"D:\T2V\KandinskyVideo-main\Qwen-Audio-main\prompts_dict_new.json", 'r', encoding='utf-8') as f:
    datas = json.load(f)
    for timestamp, data in datas.items():
        prompts.append(data)

print(prompts)

for i, prompt in tqdm(enumerate(prompts)):
    width = 640
    height = 384
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
        frames = model.generate(
            prompt=prompt,
            num_inference_steps=[20, 20, 20],
            video_num_inference_steps=[10, 10, 10],
            height=height,     
            width=width,
            temp=16,                    # temp=16: 5s, temp=31: 10s
            guidance_scale=7.0,         # The guidance for the first frame, set it to 7 for 384p variant
            video_guidance_scale=5.0,   # The guidance for the other video latent
            output_type="pil",
            save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
        )

    export_to_video(frames, f"./demo/2/{i}.mp4", fps=24)

    
    # prompt = "A wide shot of the sunflower field at sunset. The sky is now a deep orange and pink, with the sun setting behind the horizon. The sunflower petals are still swaying in the breeze, but the children have disappeared. A single butterfly lands on a sunflower, its wings shimmering in the warm light. The air is filled with the sound of crickets chirping."
    # # used for 384p model variant
    # width = 640
    # height = 384
# # used for 768p model variant
# # width = 1280
# # height = 768

cpu版本
 

import torch
from PIL import Image
from pyramid_dit import PyramidDiTForVideoGeneration
from diffusers.utils import load_image, export_to_video

torch.cuda.set_device(0)
model_dtype, torch_dtype = 'bf16', torch.bfloat16   # Use bf16 (not support fp16 yet)

model = PyramidDiTForVideoGeneration(
    'PATH',                                         # The downloaded checkpoint dir
    model_name="pyramid_flux",
    model_dtype=model_dtype,
    model_variant='diffusion_transformer_384p',
)

model.vae.enable_tiling()
# model.vae.to("cuda")
# model.dit.to("cuda")
# model.text_encoder.to("cuda")

# if you're not using sequential offloading bellow uncomment the lines above ^
model.enable_sequential_cpu_offload()

prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"

# used for 384p model variant
width = 640
height = 384

# used for 768p model variant
# width = 1280
# height = 768

with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
    frames = model.generate(
        prompt=prompt,
        num_inference_steps=[20, 20, 20],
        video_num_inference_steps=[10, 10, 10],
        height=height,     
        width=width,
        temp=16,                    # temp=16: 5s, temp=31: 10s
        guidance_scale=7.0,         # The guidance for the first frame, set it to 7 for 384p variant
        video_guidance_scale=5.0,   # The guidance for the other video latent
        output_type="pil",
        save_memory=True,           # If you have enough GPU memory, set it to `False` to improve vae decoding speed
    )

export_to_video(frames, "./text_to_video_sample.mp4", fps=24)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/942425.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RSICV国产芯片之CHV208

1. 芯片选型分析的对比维度 分析或者对标应用的芯片替代思路 1.1 内核/主频/存储空间支持 内核能力/指令集支持(考虑工具链兼容性); 主频:对比计算能力是否满足基本要求 存储:内存--数据搬移空间决定数据运算的…

7. petalinux 根文件系统配置(package group)

根文件系统配置(Petalinux package group) 当使能某个软件包组的时候,依赖的包也会相应被使能,解决依赖问题,在配置页面的help选项可以查看需要安装的包 每个软件包组的功能: packagegroup-petalinux-audio包含与音…

基于Spring Boot的个人健康管理系统

一、系统背景与意义 随着现代生活节奏的加快和人们健康意识的日益增强,个人健康管理成为了人们关注的焦点。然而,传统的健康管理方式往往依赖于纸质记录、定期体检等手段,不仅效率低下,而且难以实现对健康数据的持续跟踪和深入分…

上手教程:使用Terraform打造弹性VPC架构

最近Akamai发布的虚拟专用云(VPC)功能提供了一种隔离的网络,让云资源可以用私密的方式进行通信。 关于Akamai VPC功能,最棒的地方在于它有着极高的灵活性。用户可以通过Cloud Manager、开发人员工具(如CLI&#xff09…

要查询 `user` 表中 `we_chat_subscribe` 和 `we_chat_union_id` 列不为空的用户数量

文章目录 1、we_chat_subscribe2、we_chat_union_id 1、we_chat_subscribe 要查询 user 表中 we_chat_subscribe 列不为空的用户数量,你可以使用以下 SQL 查询语句: SELECT COUNT(*) FROM user WHERE we_chat_subscribe IS NOT NULL;解释: …

【论文复现】进行不同视角图像的拼接

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ 进行不同视角图像的拼接 背景描述算法简介SIFT算法原理代码原理代码部署核心代码拼接结果其他的图片如何进行拼接? 修改内容&…

xxl-job 简单的入门到实战

本文是参考官方文档自己实践一次,纯享版,大致也是作者边写博客边去跟着官方文档实现 一、前期准备 1、官网地址 GitHub地址: GitHub - xuxueli/xxl-job: A distributed task scheduling framework.(分布式任务调度平台XXL-JOB&…

数字后端培训项目Floorplan常见问题系列专题续集1

今天继续给大家分享下数字IC后端设计实现floorplan阶段常见问题系列专题。这些问题都是来自于咱们社区IC后端训练营学员提问的问题库。目前这部分问题库已经积累了4年了,后面会陆续分享这方面的问题。 希望对大家的数字后端学习和工作有所帮助。 数字后端项目Floor…

江苏捷科云:可视化平台助力制造企业智能化管理

公司简介 江苏捷科云信息科技有限公司(以下简称“捷科”)是一家专注于云平台、云储存、云管理等产品领域的创新型企业,集研发、生产和销售于一体,致力于在网络技术领域打造尖端品牌。在推动制造业企业数字化转型的进程中&#xf…

【视觉惯性SLAM:对极几何】

对极几何(Epipolar Geometry)介绍 对极几何是立体视觉中的核心内容之一,它描述了两个相机在观察同一个三维场景时,成像平面之间的几何关系。对极几何能够约束图像中对应点的位置关系,是双目立体匹配、三维重建、以及位…

从Condition开始,回顾AQS

Synchronized和Reentrantlock的挂起逻辑 synchronized中有两个核心的结构 EntryList cxq:等待拿锁的线程存储位置Waitset:被执行wait方法的线程存储位置 流转: 线程获取锁资源失败,扔到EntryList cxq线程持有锁资源&#x…

umi : 无法加载文件 D:\software\nodejs\node_global\umi.ps1,因为在此系统上禁止运行脚本。

问题详情 2、解决方法 1.使用命令 get-ExecutionPolicy查看 显示Restricted:限制 所以要给权限 2. 使用命令:Set-ExecutionPolicy -Scope CurrentUser 3. 会提示为参数提供值 4. 输入: RemoteSigned 具体如下图所示,成功解决。 报…

Redis篇--常见问题篇4--大Key(Big Key,什么是大Key,影响及使用建议)

1、概述 大Key:通常是指值(Value)的长度非常大,实际上键(Key)长度很大也算。通常来说,键本身不会很长,占用的内存较少,因此判断一个键是否为bigKey主要看它对应的值的大…

02、并发编程的三大特性

并发编程有三大特性分别是,原子性,可见性,有序性。会产生这些特性的根本原因是现在的服务器都是多CPU多核心数的,每个CPU都有自己单独的一套缓存和pc系统,而且程序在运行时按照JMM的规范,它们是需要先把数据…

基于Java+Jsp Servlet Mysql实现的Java Web在线商城项目系统设计与实现

一、前言介绍: 1.1 项目摘要 随着互联网技术的飞速发展,电子商务已成为现代商业活动的重要组成部分。在线商城作为电子商务的一种重要形式,以其便捷性、高效性和广泛覆盖性,受到了越来越多消费者的青睐。同时,随着消…

【安全测试相关知识】

安全测试介绍 背景 在当前信息技术快速发展的背景下,网络安全问题日益严峻,数据泄露、黑客攻击、病毒传播等安全事件层出不穷,给个人、企业乃至国家带来严重威胁。所以安全测试已成为企业和国家关注的重心 作用 安全测试是确保软件系统安…

WPS如何快速将数字金额批量转换成中文大写金额,其实非常简单

大家好,我是小鱼。 在日常的工作中经常会遇到需要使用金额大写的情况,比如说签订业务合同时一般都会标注大写金额,这样是为了安全和防止串改。但是很多人也许不太熟悉金额大写的方法和习惯,其它没有关系,我们在用WPS制…

Element-ui的使用教程 基于HBuilder X

文章目录 1.Element-ui简介2.使用HBuilderX 创建一个基于Vue3的项目 (由于是使用的基于Vue3的Element-ui)3.安装element-ui4.在项目里完全引用element-ui5.引用组件6.运行项目 1.Element-ui简介 Element,一套为开发者、设计师和产品经理准备…

MySQL的架构设计和设计模式

1. 数据库设计模式与范式 数据库设计模式是解决数据库设计中常见问题的一种思维方式,它提供了一套解决方案。以下是一些常见的数据库设计模式和范式: 实体-关系模型(Entity-Relationship Model):通过实体和实体之间的…

【MySQL】十三,关于MySQL的全文索引

MySQL的全文索引用于搜索文本中的关键字,类似于like查询。 演示 建表 CREATE TABLE demo (id INT(11) NOT NULL,name CHAR(30) NOT NULL,age INT(11) NOT NULL,info VARCHAR(255),primary key(id),fulltext index futxt_idx_info(info) );此表的默认存储引擎为In…