AIGC系列博文:
【AIGC系列】1:自编码器(AutoEncoder, AE)
【AIGC系列】2:DALL·E 2模型介绍(内含扩散模型介绍)
【AIGC系列】3:Stable Diffusion模型原理介绍
【AIGC系列】4:Stable Diffusion应用实践和代码分析
【AIGC系列】5:视频生成模型数据处理和预训练流程介绍(Sora、MovieGen、HunyuanVideo)
【AIGC系列】6:HunyuanVideo视频生成模型部署和代码分析
目录
- 1 前言
- 2 部署
- 2.1 环境配置
- 2.1.1 方法一:使用Open-R1的环境
- 2.1.2 方法二:使用官方Docker
- 2.2 下载预训练模型
- 2.2.1 混元Diffusion模型和VAE模型
- 2.2.2 text-encoder-tokenizer
- 2.2.3 CLIP模型
- 2.3 视频生成命令
- 3 源码分析
- 3.1 推理流程
- 3.2 模型初始化
- 3.3 模型推理
- 3.4 模型结构
- 3.4.1 双流块 (MMDoubleStreamBlock)
- 3.4.2 单流块 (MMSingleStreamBlock)
- 3.4.3 混元主干网络(HYVideoDiffusionTransformer)
1 前言
先展示一下结果。
生成540p视频,推理占用的显存超过40G:(720p x 1280p 分辨率大约需要76G显存,544p x 960p分辨率大约需要43G显存)
迭代50步,需要超过一个半小时,还是挺久的。
Prompt:A cat walks on the grass, realistic style.
生成的视频结果如下,效果还是很逼真的,包括光影、猫毛的细节、景深等等。
视频链接:哔哩哔哩:https://b23.tv/Boazciy
当然,用8卡推理可以更快一点,生成5s的视频耗时约20分钟,整体效果也还可以。
哔哩哔哩:https://b23.tv/vBiXMKT
2 部署
2.1 环境配置
2.1.1 方法一:使用Open-R1的环境
HuyuanVideo也使用到flash attention了,为了方便起见,我们使用Open-R1的环境来跑HunyuanVideo。
Open-R1的环境配置详细步骤请参考博文:【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)。
将源码clone下来之后,接下来我们安装HunyuanVideo的依赖库:
python -m pip install -r requirements.txt
2.1.2 方法二:使用官方Docker
当然,我们也可以直接使用官方提供的docker:
# 1. Use the following link to download the docker image tar file (For CUDA 12).
wget https://aivideo.hunyuan.tencent.com/download/HunyuanVideo/hunyuan_video_cu12.tar
# 2. Import the docker tar file and show the image meta information (For CUDA 12).
docker load -i hunyuan_video_cu12.tar
docker image ls
# 3. Run the container based on the image
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged docker_image_tag
推荐使用方法一,方法二我尝试了一下,报了个torch的错误,后来我没接着往下解决,GitHub上也有小伙伴反馈cuda12的docker跑起来会有些问题,感兴趣的小伙伴也可使用cuda11.8的docker。
# For CUDA 12.4 (updated to avoid float point exception)
docker pull hunyuanvideo/hunyuanvideo:cuda_12
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
# For CUDA 11.8
docker pull hunyuanvideo/hunyuanvideo:cuda_11
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_11
2.2 下载预训练模型
我们需要下载的模型包括:混元Diffusion模型、VAE模型、text-encoder-tokenizer模型以及CLIP模型。
2.2.1 混元Diffusion模型和VAE模型
HuggingFace:https://huggingface.co/tencent/HunyuanVideo/tree/main。
建议下载FP8量化模型,推理时总共占用显存40多G。
Diffusion模型放入ckpts/hunyuan-video-t2v-720p/transformers目录,VAE模型放入ckpts/hunyuan-video-t2v-720p/vae目录。
2.2.2 text-encoder-tokenizer
我们可以直接下载text-encoder-tokenizer模型:https://huggingface.co/Kijai/llava-llama-3-8b-text-encoder-tokenizer,保存在ckpts/text_encoder文件夹中。
官方的操作是先下载完整版LLaVA 模型,然后再把文本编码器(language_model)和分词器(tokenizer)提取出来,多了一步提取的操作。
2.2.3 CLIP模型
下载CLIP-ViT-L模型:https://huggingface.co/openai/clip-vit-large-patch14,将模型存放到ckpts/text_encoder_2文件夹中。
最终保存的预训练模型路径如下所示:
HunyuanVideo
├──ckpts
│ ├──README.md
│ ├──hunyuan-video-t2v-720p
│ │ ├──transformers
│ │ │ ├──mp_rank_00_model_states.pt
│ │ │ ├──mp_rank_00_model_states_fp8.pt
│ │ │ ├──mp_rank_00_model_states_fp8_map.pt
├ │ ├──vae
│ ├──text_encoder
│ ├──text_encoder_2
├──…
2.3 视频生成命令
使用FP8模型推理:
cd HunyuanVideo
python3 sample_video.py \
--dit-weight ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt \
--video-size 1280 720 \
--video-length 129 \
--infer-steps 50 \
--prompt "A cat walks on the grass, realistic style." \
--seed 42 \
--embedded-cfg-scale 6.0 \
--flow-shift 7.0 \
--flow-reverse \
--use-cpu-offload \
--use-fp8 \
--save-path ./results
–dit-weight指定FP8模型的路径;–use-fp8是指模型的格式是FP8。
参数说明:
Argument | Default | Description |
---|---|---|
--prompt | None | The text prompt for video generation |
--video-size | 720 1280 | The size of the generated video |
--video-length | 129 | The length of the generated video |
--infer-steps | 50 | The number of steps for sampling |
--embedded-cfg-scale | 6.0 | Embedded Classifier free guidance scale |
--flow-shift | 7.0 | Shift factor for flow matching schedulers |
--flow-reverse | False | If reverse, learning/sampling from t=1 -> t=0 |
--seed | None | The random seed for generating video, if None, we init a random seed |
--use-cpu-offload | False | Use CPU offload for the model load to save more memory, necessary for high-res video generation |
--save-path | ./results | Path to save the generated video |
当然,我们也可以用多卡推理:
cd HunyuanVideo
torchrun --nproc_per_node=8 sample_video.py \
--video-size 1280 720 \
--video-length 129 \
--infer-steps 50 \
--prompt "A cat walks on the grass, realistic style." \
--flow-reverse \
--seed 42 \
--ulysses-degree 8 \
--ring-degree 1 \
--save-path ./results
3 源码分析
3.1 推理流程
上面我们展示的demo的主文件是sample_video.py,定义了一个main函数,用于加载模型并生成视频样本,包含整体推理流程。
def main():
# 调用 parse_args() 函数来解析命令行参数,并将结果存储在 args 中。
args = parse_args()
print(args)
#使用 args.model_base 指定模型的基础路径。检查路径是否存在,如果路径不存在,抛出异常
models_root_path = Path(args.model_base)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
# 创建保存目录
save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
if not os.path.exists(args.save_path):
os.makedirs(save_path, exist_ok=True)
# 加载指定路径下的预训练模型,并传入 args 参数
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
# 更新 args 为模型内部参数,确保之后使用的一致性
args = hunyuan_video_sampler.args
# 生成视频样本
outputs = hunyuan_video_sampler.predict(
prompt=args.prompt, # 文本提示词,用于引导生成内容。
height=args.video_size[0],# 视频帧的分辨率(高度和宽度)
width=args.video_size[1],
video_length=args.video_length, # 视频时长(帧数)。
seed=args.seed, # 随机种子,用于结果的可重复性
negative_prompt=args.neg_prompt, # 负向提示词,指定生成时需避免的特性
infer_steps=args.infer_steps, # 推理步数
guidance_scale=args.cfg_scale,# 引导系数,控制生成质量。
num_videos_per_prompt=args.num_videos, # 每个提示生成的视频数量。
flow_shift=args.flow_shift, # 时间帧间的流动控制参数
batch_size=args.batch_size, # 批处理大小
embedded_guidance_scale=args.embedded_cfg_scale # 内嵌引导系数,用于调节特定特征
)
samples = outputs['samples']
# 保存视频样本
# 检查环境变量 LOCAL_RANK 是否存在,用于分布式训练的本地进程控制:如果不存在,或者值为 0(即主进程),则继续保存样本。
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
# 遍历生成的samples
for i, sample in enumerate(samples):
sample = samples[i].unsqueeze(0)
# 添加时间戳
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
save_videos_grid(sample, save_path, fps=24) # 保存视频
logger.info(f'Sample save to: {save_path}') # 日志记录
3.2 模型初始化
加载模型的时候,调用了HunYuanVideoSampler的from_pretrained()初始化实例,改方法是在其Inference实现的(hyvideo/inference.py文件),功能包括初始化vae、text_encoder、扩散模型等核心部件,然后通过cls初始化并返回一个实例,而HunYuanVideoSampler类继承了父类的from_pretrained()方法,因此这里cls返回的是HunYuanVideoSampler的实例。
class Inference(object):
...
@classmethod
def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
...
in_channels = args.latent_channels # 16
out_channels = args.latent_channels # 16
model = load_model( # HYVideoDiffusionTransformer
args,
in_channels=in_channels,
out_channels=out_channels,
factor_kwargs=factor_kwargs,
)
...
# VAE
vae, _, s_ratio, t_ratio = load_vae( # AutoencoderKLCausal3D
args.vae,
args.vae_precision,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
# Text encoder
if args.prompt_template_video is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get( # 95
"crop_start", 0
)
...
max_length = args.text_len + crop_start # 256+95
# prompt_template
prompt_template = (
PROMPT_TEMPLATE[args.prompt_template]
if args.prompt_template is not None
else None
)
# prompt_template_video
prompt_template_video = (
PROMPT_TEMPLATE[args.prompt_template_video]
if args.prompt_template_video is not None
else None
)
text_encoder = TextEncoder( # text-encoder-tokenizer
text_encoder_type=args.text_encoder, # llava
max_length=max_length,
text_encoder_precision=args.text_encoder_precision,
tokenizer_type=args.tokenizer,
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=args.hidden_state_skip_layer,
apply_final_norm=args.apply_final_norm,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
text_encoder_2 = None
if args.text_encoder_2 is not None:
text_encoder_2 = TextEncoder(
text_encoder_type=args.text_encoder_2, # clipL
max_length=args.text_len_2,
text_encoder_precision=args.text_encoder_precision_2,
tokenizer_type=args.tokenizer_2,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
return cls( # 初始化本类的一个实例
args=args,
vae=vae, # AutoencoderKLCausal3D
vae_kwargs=vae_kwargs,
text_encoder=text_encoder, # llm
text_encoder_2=text_encoder_2,
model=model,
use_cpu_offload=args.use_cpu_offload,
device=device,
logger=logger,
)
最后使用了cls(),会调用HunyuanVideoSampler的初始化方法__init__(),指定所有组件并将他们组合到一个pipeline,包括模型、调度器(scheduler)、设备配置等必要的组件,然后指定负面提示词。
class HunyuanVideoSampler(Inference):
def __init__(...):
super().__init__(...)
self.pipeline = self.load_diffusion_pipeline( # 组合所有原件
args=args,
vae=self.vae,
text_encoder=self.text_encoder,
text_encoder_2=self.text_encoder_2,
model=self.model,
device=self.device,
)
self.default_negative_prompt = NEGATIVE_PROMPT # 负面提示词
def load_diffusion_pipeline(
self,
args,
vae,
text_encoder,
text_encoder_2,
model,
scheduler=None,
device=None,
progress_bar_config=None,
data_type="video",
):
"""Load the denoising scheduler for inference."""
# 去噪调度器的初始化
if scheduler is None:
if args.denoise_type == "flow":
# 流动匹配的去噪策略,离散去噪调度器,可能用于视频生成任务中时间帧之间的一致性建模。
# 负责指导扩散模型逐步还原噪声,生成清晰的视频帧。
scheduler = FlowMatchDiscreteScheduler(
shift=args.flow_shift, # 流动偏移值。
reverse=args.flow_reverse, # 是否反向计算。
solver=args.flow_solver, # 去噪求解器的类型
)
else:
raise ValueError(f"Invalid denoise type {args.denoise_type}")
# 构建推理pipeline
pipeline = HunyuanVideoPipeline(
vae=vae, # 负责特征编码和解码的模块。
text_encoder=text_encoder, # 用于处理文本提示,生成与视频生成相关的特征。
text_encoder_2=text_encoder_2,
transformer=model, # 主扩散模型,生成视频的核心模块。
scheduler=scheduler, # 去噪调度器,控制扩散生成的时间步长和过程
progress_bar_config=progress_bar_config, # 可选的进度条配置,用于显示推理进度。
args=args, # 配置参数的集合
)
# 配置计算资源
if self.use_cpu_offload:
# 将部分计算任务卸载到 CPU。这是显存不足时的优化策略,可以大幅降低 GPU 的显存占用。
pipeline.enable_sequential_cpu_offload()
else:
# 如果为 False,直接将管道加载到指定的 device(如 GPU)上运行
pipeline = pipeline.to(device)
return pipeline
提示词如下:
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
3.3 模型推理
位置编码使用的是RoPE,主要根据输入的视频维度、网络配置以及位置嵌入参数,生成对应的正弦和余弦频率嵌入。
def get_rotary_pos_embed(self, video_length, height, width):
# video_length: 视频的帧长度。
# height, width: 视频的帧高和帧宽。 目标是根据这些维度计算位置嵌入。
# 表示生成的 RoPE 的目标维度(3D: 时间维度 + 空间高度和宽度)
target_ndim = 3
# 推导潜在特征(latent feature)所需维度的辅助变量
ndim = 5 - 2
# 根据 self.args.vae 中的配置(例如 VAE 模型类型 884 或 888),确定潜在特征的空间尺寸 latents_size
# 884: 时间维度下采样 4 倍(1/4),空间高宽下采样 8 倍(1/8)。
if "884" in self.args.vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
# 888: 时间维度下采样 8 倍(1/8),空间高宽下采样 8 倍。
elif "888" in self.args.vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
# 默认情况下不对时间维度下采样,但高宽依然下采样 8 倍。
else:
latents_size = [video_length, height // 8, width // 8]
# 检查潜在空间尺寸是否与 Patch 尺寸兼容
# 如果 self.model.patch_size 是单个整数,检查潜在特征维度的每一维是否能被 patch_size 整除。
if isinstance(self.model.patch_size, int):
assert all(s % self.model.patch_size == 0 for s in latents_size), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
# 如果整除,计算 RoPE 的输入尺寸 rope_sizes(将 latents_size 每一维除以 patch_size)
rope_sizes = [s // self.model.patch_size for s in latents_size]
# 如果 self.model.patch_size 是一个列表,分别对每一维进行整除检查和计算。
elif isinstance(self.model.patch_size, list):
assert all(
s % self.model.patch_size[idx] == 0
for idx, s in enumerate(latents_size)
), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
f"but got {latents_size}."
)
rope_sizes = [
s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
]
# 如果 rope_sizes 的维度数不足 target_ndim,在开头补充时间维度(值为 1)。
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
# head_dim 是单个注意力头的维度大小,由模型的 hidden_size 和 heads_num 计算得出。
head_dim = self.model.hidden_size // self.model.heads_num
# rope_dim_list 是用于位置嵌入的维度分配列表:
# 如果未定义,默认将 head_dim 平均分配到 target_ndim(时间、高度、宽度)。
rope_dim_list = self.model.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert (
sum(rope_dim_list) == head_dim
), "sum(rope_dim_list) should equal to head_dim of attention layer"
# 调用 get_nd_rotary_pos_embed 函数,计算基于目标尺寸 rope_sizes 和维度分配 rope_dim_list 的多维旋转位置嵌入。
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.args.rope_theta, #控制位置嵌入频率。
use_real=True, # 表示使用真实数值而非复数形式。
theta_rescale_factor=1, # 无缩放因子。
)
#返回: freqs_cos: 余弦频率嵌入。freqs_sin: 正弦频率嵌入。
return freqs_cos, freqs_sin
predict 函数用于从文本生成视频或图像的预测函数,通过输入文本 prompt,结合其他参数(如视频分辨率、帧数、推理步数等),生成指定数量的视频或图像。
@torch.no_grad()
def predict(
self,
prompt,
height=192,
width=336,
video_length=129,
seed=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6,
flow_shift=5.0,
embedded_guidance_scale=None,
batch_size=1,
num_videos_per_prompt=1,
**kwargs,
):
"""
Predict the image/video from the given text.
Args:
prompt (str or List[str]): The input text.
kwargs:
height (int): The height of the output video. Default is 192.
width (int): The width of the output video. Default is 336.
video_length (int): The frame number of the output video. Default is 129.
seed (int or List[str]): The random seed for the generation. Default is a random integer.
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
num_images_per_prompt (int): The number of images per prompt. Default is 1.
infer_steps (int): The number of inference steps. Default is 100.
"""
# 分布式环境检查
if self.parallel_args['ulysses_degree'] > 1 or self.parallel_args['ring_degree'] > 1:
assert seed is not None, \
"You have to set a seed in the distributed environment, please rerun with --seed <your-seed>."
# 满足分布式环境的条件,调用 parallelize_transformer 函数并行化模型
parallelize_transformer(self.pipeline)
# 初始化一个空字典 out_dict,用于存储最终的生成结果。
out_dict = dict()
# ========================================================================
# 根据传入的 seed 参数生成一组随机种子,并将这些种子用于初始化随机数生成器 (torch.Generator) 来控制生成过程的随机性。
# ========================================================================
# 根据 seed 参数的类型(None、int、list、tuple 或 torch.Tensor),执行不同的逻辑,生成用于控制随机数生成器的 seeds 列表
if isinstance(seed, torch.Tensor):
seed = seed.tolist()
if seed is None:
seeds = [
random.randint(0, 1_000_000)
for _ in range(batch_size * num_videos_per_prompt)
]
elif isinstance(seed, int):
seeds = [
seed + i
for _ in range(batch_size)
for i in range(num_videos_per_prompt)
]
elif isinstance(seed, (list, tuple)):
if len(seed) == batch_size:
seeds = [
int(seed[i]) + j
for i in range(batch_size)
for j in range(num_videos_per_prompt)
]
elif len(seed) == batch_size * num_videos_per_prompt:
seeds = [int(s) for s in seed]
else:
raise ValueError(
f"Length of seed must be equal to number of prompt(batch_size) or "
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
)
else:
raise ValueError(
f"Seed must be an integer, a list of integers, or None, got {seed}."
)
# 对每个种子,在指定设备(self.device)上创建一个 PyTorch 的随机数生成器 torch.Generator,并使用对应的种子进行手动初始化
# (manual_seed(seed))。将这些生成器存储在列表 generator 中。
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
# 将生成的 seeds 列表存储在 out_dict 中,供后续使用(可能用于复现生成结果或记录生成过程的随机性)。
out_dict["seeds"] = seeds
# ========================================================================
# 检查和调整视频生成的输入参数(height、width 和 video_length)的合法性与对齐要求,并计算出调整后的目标尺寸。
# ========================================================================
# 检查输入的 height、width 和 video_length 是否为正整数:
if width <= 0 or height <= 0 or video_length <= 0:
raise ValueError(
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
)
# 检查 video_length - 1 是否为 4 的倍数
if (video_length - 1) % 4 != 0:
raise ValueError(
f"`video_length-1` must be a multiple of 4, got {video_length}"
)
# 日志记录
logger.info(
f"Input (height, width, video_length) = ({height}, {width}, {video_length})"
)
# 目标高度和宽度对齐到 16 的倍数
target_height = align_to(height, 16)
target_width = align_to(width, 16)
target_video_length = video_length
# 存储目标尺寸
out_dict["size"] = (target_height, target_width, target_video_length)
# ========================================================================
# 检查和处理文本生成任务中的 prompt 和 negative_prompt 参数
# ========================================================================
# 确保输入的 prompt 是字符串类型
if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
prompt = [prompt.strip()] # 对 prompt 去除首尾多余的空格(使用 .strip()),然后包装成一个单元素列表
# 处理 negative_prompt 参数
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(
f"`negative_prompt` must be a string, but got {type(negative_prompt)}"
)
negative_prompt = [negative_prompt.strip()]
# ========================================================================
# 设置调度器 (Scheduler)
# ========================================================================
scheduler = FlowMatchDiscreteScheduler( # 处理流(Flow)的调度
shift=flow_shift, # 控制流动调度器的偏移量。flow_shift 通常与时序或流动模型相关,例如调整时间步之间的关系。
reverse=self.args.flow_reverse, # 决定是否反向调度(可能是在推理过程中逆序生成帧)
solver=self.args.flow_solver # 指定用于调度的解算器类型(solver),例如选择数值方法来优化时间步间的计算。
)
self.pipeline.scheduler = scheduler
# ========================================================================
# 构建旋转位置嵌入 (Rotary Positional Embedding)
# ========================================================================
# 根据目标视频长度、高度和宽度生成正弦 (freqs_sin) 和余弦 (freqs_cos) 的频率嵌入。
freqs_cos, freqs_sin = self.get_rotary_pos_embed(
target_video_length, target_height, target_width
)
# 表示视频中总的编码标记数(tokens),通常等于时间步数(帧数)与空间分辨率(像素数)相乘。
n_tokens = freqs_cos.shape[0]
# ========================================================================
# 打印推理参数
# ========================================================================
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
logger.debug(debug_str)
# ========================================================================
# Pipeline inference
# ========================================================================
start_time = time.time()
samples = self.pipeline(
prompt=prompt, # 文本提示,用于指导生成内容。
height=target_height, # 生成图像或视频帧的分辨率。
width=target_width, #
video_length=target_video_length, # 视频的帧数。如果 video_length > 1,表示生成视频;否则生成单张图像。
num_inference_steps=infer_steps, # 推理步数,决定生成过程的细粒度程度,步数越多,生成结果越精细。
guidance_scale=guidance_scale, # 指导比例,控制生成与 prompt 的一致性程度。
negative_prompt=negative_prompt, # 负面提示,用于约束生成内容,避免不期望的结果。
num_videos_per_prompt=num_videos_per_prompt, # 每条提示生成的视频数量。
generator=generator, # 随机生成器对象,用于控制生成过程中的随机性,通常与随机种子结合。
output_type="pil", # 指定输出格式为 PIL.Image 对象,便于后续处理
freqs_cis=(freqs_cos, freqs_sin), # 旋转位置嵌入 (RoPE) 的频率矩阵,增强时空位置感知能力。
n_tokens=n_tokens, # 输入序列的总标记数,用于指导生成过程。
embedded_guidance_scale=embedded_guidance_scale, # 嵌入式指导比例,用于进一步优化嵌入向量的生成。
data_type="video" if target_video_length > 1 else "image", # 指定生成目标为视频或图像,取决于帧数。
is_progress_bar=True, # 显示推理进度条,方便监控生成进度。
vae_ver=self.args.vae, # 使用指定版本的 VAE(变分自编码器),决定生成内容的潜在空间。
enable_tiling=self.args.vae_tiling, # 启用 VAE 分块处理,提高内存效率,特别适用于高分辨率生成。
)[0] # 返回生成的样本,通常是一个 PIL.Image 或视频帧序列
# 保存生成结果
out_dict["samples"] = samples
out_dict["prompts"] = prompt
# 计算并记录推理时间
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
return out_dict
推理pipeline(hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py):call 方法接受用户输入的提示、生成图像或视频的尺寸,以及其他生成过程的参数,完成推理并返回生成的图像或视频。
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
height: int,
width: int,
video_length: int,
data_type: str = "video",
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[
Union[
Callable[[int, int, Dict], None],
PipelineCallback,
MultiPipelineCallbacks,
]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
vae_ver: str = "88-4c-sd",
enable_tiling: bool = False,
n_tokens: Optional[int] = None,
embedded_guidance_scale: Optional[float] = None,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
video_length (`int`):
The number of frames in the generated video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~HunyuanVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# 处理与回调函数相关的参数,同时对已弃用的参数发出警告(deprecation warnings)。
# 它还检查了新的回调函数机制 callback_on_step_end 是否符合预期类型。
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. 验证输入参数是否合法。
self.check_inputs(
prompt,
height,
width,
video_length,
callback_steps, # 回调频率,指定在生成过程中每隔多少步执行一次回调。
negative_prompt,
prompt_embeds, # 预嵌入的提示词和反向提示词。如果已经对文本进行了嵌入处理,可以直接传递这些值,而不是原始文本。
negative_prompt_embeds,
callback_on_step_end_tensor_inputs, # 与回调机制相关的数据张量。
vae_ver=vae_ver, # 可选参数,可能指定生成内容时使用的 VAE(变分自动编码器)的版本。
)
# 控制生成内容的引导强度。一般用于调整模型对 prompt(提示词)的依赖程度。较大的值会让生成内容更接近提示词,但可能导致丢失多样性。
self._guidance_scale = guidance_scale
# 用于重新调整指导比例,可能是对 guidance_scale 的一种动态调整。用于平衡模型在特定生成任务中的表现。
self._guidance_rescale = guidance_rescale
# 控制是否在 CLIP 模型中跳过某些层的计算。在某些生成任务中,跳过部分层可以改善生成质量。
self._clip_skip = clip_skip
# 与交叉注意力(Cross Attention)相关的参数。可能包括对注意力权重的控制,比如调整注意力机制如何在提示词和生成内容之间分配权重。
self._cross_attention_kwargs = cross_attention_kwargs
# 标志可能在生成过程的某些阶段被动态修改。
# 如果 _interrupt 被设置为 True,生成过程可能会被中止。这种设计通常用于在用户希望终止长时间生成任务时使用。
self._interrupt = False
# 2. 根据输入的提示词 prompt 或嵌入 prompt_embeds,确定生成任务的批量大小(batch_size)。
if prompt is not None and isinstance(prompt, str):
batch_size = 1 # 如果 prompt 是单个字符串,说明只有一个提示词。批量大小设置为 1。
elif prompt is not None and isinstance(prompt, list):
# 如果 prompt 是一个列表,说明有多个提示词。
# 此时,批量大小等于提示词的数量,即 len(prompt)。
batch_size = len(prompt)
else:
#如果 prompt 是 None,说明提示词未提供,可能直接使用预先计算的嵌入 prompt_embeds。
# 此时,批量大小由 prompt_embeds 的第一维(通常是样本数量)决定。
batch_size = prompt_embeds.shape[0]
# 确定设备的device
device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
# 3. Encode input prompt
# 处理 LoRA(Low-Rank Adaptation)缩放系数:通过 cross_attention_kwargs 提取或设置缩放比例 lora_scale
lora_scale = (
self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None
else None
)
# 对提示词进行编码:将文本提示词 prompt 和负向提示词 negative_prompt 编码为嵌入向量,并生成对应的注意力掩码。
(
prompt_embeds, # 正向提示词的嵌入向量。
negative_prompt_embeds, # 负向提示词的嵌入向量。
prompt_mask, # 正向提示词的注意力掩码。
negative_prompt_mask, # 负向提示词的注意力掩码。
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
attention_mask=attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_attention_mask=negative_attention_mask,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
data_type=data_type,
)
# 处理多文本编码器:若存在额外的文本编码器 text_encoder_2,使用该编码器再次处理提示词。
if self.text_encoder_2 is not None:
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_mask_2,
negative_prompt_mask_2,
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=None,
attention_mask=None,
negative_prompt_embeds=None,
negative_attention_mask=None,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
text_encoder=self.text_encoder_2,
data_type=data_type,
)
else:
prompt_embeds_2 = None
negative_prompt_embeds_2 = None
prompt_mask_2 = None
negative_prompt_mask_2 = None
# 处理自由分类指导(Classifier-Free Guidance):为实现该技术,合并正向和负向提示词嵌入,避免多次前向传递。
if self.do_classifier_free_guidance:
# 功能:如果启用了自由分类指导,则将正向和负向提示词的嵌入和掩码合并为一个批次。
# 原因:自由分类指导需要两次前向传递:一次处理负向提示词(指导无条件生成),一次处理正向提示词(指导条件生成)。
# 为了提高效率,将两组嵌入拼接在一起,作为一个批次传递给模型,避免两次单独的前向传递。
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if prompt_mask is not None:
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
if prompt_embeds_2 is not None:
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
if prompt_mask_2 is not None:
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
# 4. Prepare timesteps
# 准备调度器的额外参数
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
)
# 获取推理过程中需要用到的时间步 (timesteps) 和推理步数 (num_inference_steps)。
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
**extra_set_timesteps_kwargs,
)
# 根据 vae_ver 调整视频长度
if "884" in vae_ver:
video_length = (video_length - 1) // 4 + 1
elif "888" in vae_ver:
video_length = (video_length - 1) // 8 + 1
else:
video_length = video_length
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
video_length,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.step, # 扩散模型的调度器中的 step 方法,负责更新噪声预测结果。
{"generator": generator, "eta": eta}, # 一个字典,包含生成器 generator 和步长相关参数 eta。
)
# 确定目标数据类型及自动混合精度的设置
target_dtype = PRECISION_TO_TYPE[self.args.precision]
autocast_enabled = (
target_dtype != torch.float32
) and not self.args.disable_autocast
# 确定 VAE 的数据类型及自动混合精度设置
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
vae_autocast_enabled = (
vae_dtype != torch.float32
) and not self.args.disable_autocast
# 7. 初始化去噪循环的预处理步骤
# timesteps:调度器生成的时间步序列。
# num_inference_steps:推理过程中真正的去噪步数。
# self.scheduler.order:调度器的阶数(通常与预测算法的高阶插值相关)。
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
# if is_progress_bar:
# progress_bar 用于显示推理过程的进度,num_inference_steps 是总推理步数。
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# 如果启用了 分类器自由指导(do_classifier_free_guidance),则将 latents 复制两份,用于同时计算 条件预测 和 无条件预测。
# 否则,仅使用原始 latents。
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance
else latents
)
# 调用 scheduler 的 scale_model_input 方法,对 latent_model_input 在当前时间步 t 上进行预处理。
# 这个缩放操作可能根据调度器的实现涉及到归一化或其他调整。
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
# t_expand 将时间步 t 扩展到与 latent_model_input 的批量维度一致。
# 如果 embedded_guidance_scale 存在,则创建扩展的指导参数 guidance_expand,用于对模型预测进行额外控制。
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = (
torch.tensor(
[embedded_guidance_scale] * latent_model_input.shape[0],
dtype=torch.float32,
device=device,
).to(target_dtype)
* 1000.0
if embedded_guidance_scale is not None
else None
)
# 使用 Transformer 模型预测噪声残差
with torch.autocast(
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
):
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
latent_model_input, # 当前的潜变量输入 [2, 16, 33, 24, 42]
t_expand, # 时间步信息 [2]
text_states=prompt_embeds, # 与文本提示相关的嵌入向量 [2, 256, 4096]
text_mask=prompt_mask, # [2, 256]
text_states_2=prompt_embeds_2, # [2, 768]
freqs_cos=freqs_cis[0], # 频率信息,用于特定的时间步缩放 [seqlen, head_dim]
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
guidance=guidance_expand, # 指导参数,用于条件生成
return_dict=True,
)[
"x"
]
# 分类器自由指导的噪声调整
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) # 无条件预测的噪声;条件预测的噪声(基于文本提示)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# 噪声重缩放
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=self.guidance_rescale,
)
# 使用调度器更新潜变量
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
# callback_on_step_end 函数,则在每步结束时调用,用于自定义操作(如日志记录、结果保存)。
# 更新潜变量和提示嵌入向量。
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
# 进度条更新与其他回调
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if progress_bar is not None:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 从潜变量(latent space)解码生成图像
if not output_type == "latent":
# 潜变量维度的扩展检查
expand_temporal_dim = False
# 如果形状为 4D ([batch_size, channels, height, width]):
# 如果 VAE 是 3D 自回归模型(AutoencoderKLCausal3D),则对潜变量增加一个时间维度 (unsqueeze(2))。
# 设置 expand_temporal_dim=True,标记后续需要移除该额外维度。
if len(latents.shape) == 4:
if isinstance(self.vae, AutoencoderKLCausal3D):
latents = latents.unsqueeze(2)
expand_temporal_dim = True
# 如果形状为 5D ([batch_size, channels, frames, height, width]),则不需要操作。
elif len(latents.shape) == 5:
pass
else:
raise ValueError(
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
)
# 潜变量的缩放与偏移
# 检查 VAE 配置中是否定义了 shift_factor(偏移因子)
if (
hasattr(self.vae.config, "shift_factor")
and self.vae.config.shift_factor
): # 如果存在,则对潜变量执行缩放和偏移操作
latents = (
latents / self.vae.config.scaling_factor
+ self.vae.config.shift_factor
)
else: # 如果 shift_factor 不存在,仅进行缩放操作
latents = latents / self.vae.config.scaling_factor
with torch.autocast(
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
):
if enable_tiling:
# 调用 VAE 的 enable_tiling() 方法,可能用于解码较大的图像块。
self.vae.enable_tiling()
# 使用 VAE(变分自编码器)的 decode 方法将潜变量解码为图像。
image = self.vae.decode(
latents, return_dict=False, generator=generator
)[0]
else:
image = self.vae.decode(
latents, return_dict=False, generator=generator
)[0]
# 如果添加了时间维度(expand_temporal_dim=True),或者解码出的图像在时间维度上只有一个帧,则移除时间维度。
if expand_temporal_dim or image.shape[2] == 1:
image = image.squeeze(2)
else:
image = latents
# 图像归一化
image = (image / 2 + 0.5).clamp(0, 1)
# 将图像移动到 CPU,并转换为 float32 类型。这是为了确保图像兼容性,无论之前是否使用了混合精度
image = image.cpu().float()
# 调用 maybe_free_model_hooks() 方法,可能会释放模型占用的内存资源,尤其是在内存有限的 GPU 上有用。
self.maybe_free_model_hooks()
# 如果不需要返回字典(return_dict=False),则直接返回处理后的图像
if not return_dict:
return image
return HunyuanVideoPipelineOutput(videos=image)
3.4 模型结构
模型结构文件是:hyvideo/modules/models.py,主要包括双流块、单流块和主干网络。
3.4.1 双流块 (MMDoubleStreamBlock)
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with seperate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int, # 模型隐藏层维度。
heads_num: int, # 多头注意力的头数。
mlp_width_ratio: float, # MLP 中隐藏层宽度与 hidden_size 的比率。
mlp_act_type: str = "gelu_tanh", # 激活函数的类型(默认 gelu_tanh)
qk_norm: bool = True, # 是否对 Query 和 Key 启用归一化。
qk_norm_type: str = "rms", # Query 和 Key 归一化的方法(默认 rms)。
qkv_bias: bool = False, # QKV 投影中是否启用偏置项。
dtype: Optional[torch.dtype] = None, # 张量的数据类型和设备。
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
### 图像模态
# 模态调制模块:使用 ModulateDiT,为图像和文本生成 6 组参数(shift、scale、gate)。
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
# 归一化
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
# QKV 投影层:通过全连接层计算 Query、Key 和 Value
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
# 归一化模块
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
### 文本模态
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor, # 图像张量 (B, L_img, hidden_size)
txt: torch.Tensor, # 文本张量 (B, L_txt, hidden_size)
vec: torch.Tensor, # 特征向量,用于调制
cu_seqlens_q: Optional[torch.Tensor] = None, # Query 的累积序列长度
cu_seqlens_kv: Optional[torch.Tensor] = None, # Key/Value 的累积序列长度
max_seqlen_q: Optional[int] = None, # Query 最大序列长度
max_seqlen_kv: Optional[int] = None, # Key/Value 最大序列长度
freqs_cis: tuple = None, # 可选的旋转位置编码参数
) -> Tuple[torch.Tensor, torch.Tensor]:
# vec 特征向量通过 ModulateDiT 模块分别为图像和文本模态生成 6 组调制参数:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
'''图像模态的前向处理'''
# Layernorm 归一化
img_modulated = self.img_norm1(img)
# 调制函数 modulate 进行标准化和缩放
img_modulated = modulate(
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
)
# 得到 Query、Key 和 Value
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# 对 Query 和 Key 进行归一化。
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# 对 Query 和 Key 应用旋转位置编码。
if freqs_cis is not None:
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
'''文本模态的前向处理'''
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
# 将图像和文本的 Query、Key、Value 拼接
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
assert (
cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
# 多模态融合注意力计算
if not self.hybrid_seq_parallel_attn:
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=img_k.shape[0],
)
else:
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# 最终将注意力结果拆分为图像部分 img_attn 和文本部分 txt_attn
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
'''图像模态的更新'''
# 将注意力结果通过残差连接更新图像特征,并通过 MLP 进一步增强
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(
modulate(
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
)
),
gate=img_mod2_gate,
)
'''文本模态的更新'''
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(
modulate(
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
)
),
gate=txt_mod2_gate,
)
# 返回更新后的图像特征和文本特征
return img, txt
3.4.2 单流块 (MMSingleStreamBlock)
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int, # 隐藏层的维度大小,用于表示特征的维度。
heads_num: int, # 注意力头的数量。
mlp_width_ratio: float = 4.0, # 用于确定多层感知机 (MLP) 的隐藏层宽度比例,默认值为 4.0
mlp_act_type: str = "gelu_tanh", # 激活函数类型
qk_norm: bool = True, # 决定是否对 Query 和 Key 应用归一化
qk_norm_type: str = "rms", # 指定 Query 和 Key 的归一化方式,例如 rms(均方根归一化)
qk_scale: float = None, # 自定义缩放因子(用于注意力分数计算中的缩放)
dtype: Optional[torch.dtype] = None, # 控制数据类型
device: Optional[torch.device] = None, # 控制缩放因子
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor, # x: 输入特征张量,形状为 (batch_size, seq_len, hidden_size)
vec: torch.Tensor, # vec: 辅助特征向量,通常来自调制器
txt_len: int, # txt_len: 文本序列长度,用于区分图像和文本部分。
cu_seqlens_q: Optional[torch.Tensor] = None, # 累积序列长度,用于高效的分段注意力计算。
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,# Query 和 Key/Value 的最大序列长度。
max_seqlen_kv: Optional[int] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, # 可选的旋转位置编码(RoPE)
) -> torch.Tensor:
# 调用 modulation 获取调制参数 mod_shift、mod_scale 和 mod_gate。
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
# 对输入 x 应用 LayerNorm,并进行调制(即元素级缩放和偏移)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
# 将 x_mod 映射到 qkv 和 mlp 两个部分。
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
# qkv 被分为 Query (q)、Key (k)、Value (v) 三个张量,形状为 (batch_size, seq_len, heads_num, head_dim)。
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# 对 Query 和 Key 应用归一化。
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
# 旋转位置编码 (RoPE)
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
# 分别对图像和文本部分应用旋转位置编码
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# 图像部分和文本部分的 Query/Key 在编码后重新拼接。
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
# Compute attention.
assert (
cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
# attention computation start
if not self.hybrid_seq_parallel_attn:
# 如果没有启用并行注意力机制,调用标准注意力函数 attention
attn = attention(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
batch_size=x.shape[0],
)
else:
# 否则,使用并行注意力机制 parallel_attention
attn = parallel_attention(
self.hybrid_seq_parallel_attn,
q,
k,
v,
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv
)
# attention computation end
# 将注意力结果和 MLP 激活结果拼接,通过线性层投影回输入维度。
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
# 使用 mod_gate 进行门控融合,将残差连接后的结果返回。
return x + apply_gate(output, gate=mod_gate)
3.4.3 混元主干网络(HYVideoDiffusionTransformer)
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
该类继承了 ModelMixin 和 ConfigMixin,使其与 diffusers 库的采样器(例如 StableDiffusionPipeline)兼容
ModelMixin: 来自 diffusers 的模块,提供了模型的保存和加载功能。
ConfigMixin: 使模型能够以字典形式保存和加载配置信息。
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
args: argparse.Namespace
传入的命令行参数,用于设置模型的配置。
patch_size: list
输入特征的分块尺寸。一般用于图像或视频的分块操作。
in_channels: int
输入数据的通道数(如 RGB 图像为 3 通道)。
out_channels: int
模型输出的通道数。
hidden_size: int
Transformer 模块中隐藏层的维度。
heads_num: int
多头注意力机制中的注意力头数量,通常用来分配不同的注意力特征。
mlp_width_ratio: float
MLP(多层感知机)中隐藏层维度相对于 hidden_size 的比例。
mlp_act_type: str
MLP 使用的激活函数类型,例如 ReLU、GELU 等。
depth_double_blocks: int
双 Transformer 块的数量。双块可能是指包含多层结构的单元。
depth_single_blocks: int
单 Transformer 块的数量。
rope_dim_list: list
为时空维度(t, h, w)设计的旋转位置编码(ROPE)的维度。
qkv_bias: bool
是否在 QKV(查询、键、值)线性层中使用偏置项。
qk_norm: bool
是否对 Q 和 K 应用归一化。
qk_norm_type: str
QK 归一化的类型。
guidance_embed: bool
是否使用指导嵌入(guidance embedding)来支持蒸馏训练。
text_projection: str
文本投影类型,默认为 single_refiner,可能用于文本引导的视频生成。
use_attention_mask: bool
是否在文本编码器中使用注意力掩码。
dtype: torch.dtype
模型参数的数据类型,例如 torch.float32 或 torch.float16。
device: torch.device
模型的运行设备,如 CPU 或 GPU。
"""
@register_to_config
def __init__(
self,
args: Any, #
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
# 用来传递设备和数据类型(如torch.float32)的参数,方便后续模块的初始化。
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels # 用来重新拼接patch时的通道数。
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = args.text_states_dim
self.text_states_dim_2 = args.text_states_dim_2
# 确保每个头的维度是整数。
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
# 确保位置嵌入的维度与Transformer头的维度一致。
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# 将输入图像分割为小块(patch),并映射到Transformer的隐藏空间hidden_size。
# 每个patch相当于一个Transformer的输入token。
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# 根据text_projection参数选择不同的文本投影方式:
# TextProjection:线性投影,直接将文本特征映射到模型隐藏空间。
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
# SingleTokenRefiner:使用小型Transformer(深度为2)对文本特征进行细化处理。
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# TimestepEmbedder:时间步嵌入模块,输入时间信息(例如视频帧的索引),并嵌入到Transformer隐藏空间。
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
# MLPEmbedder:用于处理来自文本或其他辅助信息的特征,并投影到隐藏空间。
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance_in:引导嵌入模块,用于处理额外的控制信号(如扩散模型中的引导提示)。
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# MMDoubleStreamBlock:多模态双流块,融合了图像流和文本流信息。
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# MMSingleStreamBlock:单流块,用于进一步处理多模态融合后的单一流特征。
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
# FinalLayer:将Transformer隐藏空间中的token重新解码为图像patch,并还原到完整图像的分辨率。
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
# 分别在模型中的 双流模块(double_blocks) 和 单流模块(single_blocks) 中启用或禁用确定性行为。
def enable_deterministic(self):
# 在深度学习中,启用确定性行为意味着模型在同样的输入和参数初始化条件下,无论多少次运行都能产生相同的输出结果。
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
# 禁用确定性行为可能会允许使用非确定性的操作(如某些高效的并行实现),从而提升计算效率。
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor, # 输入图像张量,形状为 (N, C, T, H, W)。批量大小为 N,通道数为 C,时间步为 T,高度和宽度为 H 和 W。
t: torch.Tensor, # 时间步张量,用于时间嵌入。范围应为 [0, 1000],可能对应扩散模型或时间相关的特征。
text_states: torch.Tensor = None, # 文本嵌入,表示与图像配对的文本特征。
text_mask: torch.Tensor = None, # 文本掩码张量(可选)。当前未使用,可能用于控制哪些文本特征参与计算。
text_states_2: Optional[torch.Tensor] = None, # 额外的文本嵌入,用于进一步调制(modulation)。在模型中可能是辅助的文本特征表示
freqs_cos: Optional[torch.Tensor] = None, # 正弦和余弦频率,用于位置编码或调制。
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # 引导调制强度,形状可能是 cfg_scale x 1000。通常用于引导生成(如扩散模型的分类引导)。
return_dict: bool = True, # 是否返回一个字典结果。默认为 True。
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
out = {}
img = x
txt = text_states
_, _, ot, oh, ow = x.shape
# 得到划分patch后的t,h,w
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
# 时间嵌入通过 self.time_in(t) 提取特征。
vec = self.time_in(t)
# text modulation
# 如果有额外文本嵌入 text_states_2,则通过 self.vector_in 模块对 vec 进行调制。
vec = vec + self.vector_in(text_states_2)
# 启用了引导调制(self.guidance_embed),通过 self.guidance_in 引入引导特征。
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
# Embed image and text.
# 图像嵌入
img = self.img_in(img)
# 文本嵌入
if self.text_projection == "linear": # 线性投影
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner": # 结合时间步 t 和文本掩码进行更复杂的处理。
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
# 计算序列长度和累积序列索引
# 用于 Flash Attention 的高效计算,cu_seqlens_* 和 max_seqlen_* 控制序列长度和最大长度。
# Compute cu_squlens and max_seqlen for flash attention
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
cu_seqlens_kv = cu_seqlens_q
max_seqlen_q = img_seq_len + txt_seq_len
max_seqlen_kv = max_seqlen_q
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for _, block in enumerate(self.double_blocks):
double_block_args = [
img,
txt,
vec,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
freqs_cis,
]
# 并行处理图像和文本信息,使用输入参数(包括嵌入和序列长度等)逐步更新 img 和 txt。
img, txt = block(*double_block_args)
# 合并图像和文本并通过单流模块
x = torch.cat((img, txt), 1)
if len(self.single_blocks) > 0:
for _, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
(freqs_cos, freqs_sin),
]
x = block(*single_block_args)
# 分离图像特征
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
# 图像特征通过 final_layer 提取最终结果
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
# 通过 unpatchify 恢复到原始分辨率。
img = self.unpatchify(img, tt, th, tw)
if return_dict:
out["x"] = img
return out
return img
def unpatchify(self, x, t, h, w):
# 是将被切分为小块(patches)的特征重新还原成原始的张量形状,通常用于图像处理任务中,
# 例如在 ViT(Vision Transformer)模型的输出阶段将 patch 还原为完整图像的形式。
"""
x: (N, T, patch_size**2 * C) (批量大小,时间帧数,每个patch中的通道数)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
# 计算模型的参数数量,并将其按类别进行统计。它返回一个包含不同类别参数数量的字典,通常用于分析模型的规模或复杂度。
counts = {
"double": sum( # double_blocks 模块的所有参数数量
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum( # single_blocks 模块的所有参数数量
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
# double 和 single 参数的总和,主要聚焦于注意力和 MLP 层。
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts