Hands-On Deep Learning (0): Series Index
Preface: The first step in training your own video diffusion model is preparing a dataset, and it has to be a multimodal text-video or image-video dataset. This post walks you step by step through writing a Video DataLoader for such a diffusion model.
Contents
Preparation
Downloading a dataset
Captioning the video data
Code walkthrough
Reading from a plain video folder + per-video .txt prompts
Reading from a CSV description file
Preparation
Downloading a dataset
The usual choice is the WebVid dataset, but it is very large; unless you are doing pretraining, I don't recommend downloading it.
The paper Animating Pictures with Eulerian Motion Fields provides a fairly small test set, about 1 GB in size. Google Drive link:
https://drive.google.com/file/d/1-MKuNxO1mjopgY6UoEVGDVt5I_QvVeDn/view
We don't need the .pth files that come with the download; you can delete them and keep only the .mp4 files.
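As a small convenience, here is a minimal cleanup sketch. It assumes the archive was extracted to ./data; adjust the path to your own setup:

import os
from glob import glob

# Remove the .pth files we don't need, keeping only the .mp4 videos.
for pth_file in glob("./data/**/*.pth", recursive=True):
    os.remove(pth_file)

print(f"Remaining videos: {len(glob('./data/**/*.mp4', recursive=True))}")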
Captioning the video data
Many datasets do not come with good textual descriptions. If we want to train a text-to-video task, the first thing to do is attach text captions to the video data.
If your data already comes with captions, then never mind this step — just me being cheeky (not really).
I'll save the details of captioning video data with V-BLIP for a dedicated next installment.
Code walkthrough
Reading from a plain video folder + per-video .txt prompts
The first DataLoader only needs the path to a folder of videos. The prompt is either the same fallback string for every video (clearly not good enough) or read from a .txt file with the same name as the video:
if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
    with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
        prompt = f.read()
else:
    prompt = self.fallback_prompt
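So the expected folder layout looks like this (the file names are just for illustration):

data/
├── fire.mp4
├── fire.txt          # caption for fire.mp4
├── waterfall.mp4
└── waterfall.txt     # caption for waterfall.mp4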
Note that we encode the text here directly with a pretrained tokenizer; if you don't want that, you can comment this part out:
def get_prompt_ids(self, prompt):
    return self.tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
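For reference, a minimal sketch of how the tokenizer might be wired up. I'm assuming the CLIP tokenizer from Hugging Face transformers here; any tokenizer with a model_max_length attribute and the same call signature works:

from transformers import CLIPTokenizer

# The tokenizer used by Stable-Diffusion-style pipelines (an assumption,
# not mandated by the dataset code).
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt_ids = tokenizer(
    "a waterfall flowing in a forest",
    truncation=True,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    return_tensors="pt",
).input_ids
print(prompt_ids.shape)  # torch.Size([1, 77]) for CLIP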
In the part that loads the video, take special care to rearrange "f h w c" into "f c h w" (frames, channels, height, width), which is the layout the rest of the pipeline expects:
video = rearrange(video, "f h w c -> f c h w")
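A quick sanity check of what this rearrange does, on a dummy tensor:

import torch
from einops import rearrange

video = torch.zeros(16, 256, 256, 3)            # (frames, height, width, channels)
video = rearrange(video, "f h w c -> f c h w")
print(video.shape)                               # torch.Size([16, 3, 256, 256])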
The complete code is as follows:
import os
import random
from glob import glob

import decord
import numpy as np
import torchvision.transforms as T
from einops import rearrange
from torch.utils.data import Dataset

# Make decord return torch tensors directly instead of its own NDArray type,
# so that rearrange and Resize below work on the batch.
decord.bridge.set_bridge("torch")

# process_video, sensible_buckets and normalize_input are helpers from the
# accompanying training code and are assumed to be importable here.


class VideoFolderDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        self.fallback_prompt = fallback_prompt
        self.video_files = glob(f"{path}/*.mp4")
        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        # Pick a resize that matches the video's aspect ratio to the target size.
        h, w, c = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)
        return resize

    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()

        # Subsample frames so the clip plays back at roughly self.fps.
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)

        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        # Choose a random window of n_sample_frames from the subsampled clip.
        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__():
        return 'folder'

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, index):
        video, _ = self.process_video_wrapper(self.video_files[index])

        # Use the caption in the same-named .txt file if it exists,
        # otherwise fall back to the shared prompt.
        txt_path = self.video_files[index].replace(".mp4", ".txt")
        if os.path.exists(txt_path):
            with open(txt_path, "r") as f:
                prompt = f.read()
        else:
            prompt = self.fallback_prompt

        prompt_ids = self.get_prompt_ids(prompt)

        return {
            "pixel_values": normalize_input(video[0]),
            "prompt_ids": prompt_ids,
            "text_prompt": prompt,
            "dataset": self.__getname__(),
        }
Reading from a CSV description file
The previous approach has to open a .txt file to read the prompt on every access, which is inconvenient, and once you scale up the reads, the I/O overhead becomes significant!
So I recommend the CSV approach instead: the CSV file stores the video-to-prompt mapping, for example:
video_path,prompt
...
I recommend writing video_path as an absolute path, which makes loading more robust.
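As an illustration, a minimal sketch that builds such a CSV from the folder layout used above; the paths and the placeholder caption are assumptions for your own setup:

import os
from glob import glob

import pandas as pd

rows = []
for video_path in sorted(glob("./data/*.mp4")):
    txt_path = video_path.replace(".mp4", ".txt")
    if os.path.exists(txt_path):
        with open(txt_path, "r") as f:
            caption = f.read().strip()
    else:
        caption = "a short video clip"  # placeholder fallback caption
    # Absolute paths make the CSV usable from any working directory.
    rows.append({"video_path": os.path.abspath(video_path), "prompt": caption})

pd.DataFrame(rows).to_csv("./data/videos.csv", index=False)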
The complete code is as follows:
# Same imports and helper assumptions as VideoFolderDataset, plus pandas:
import pandas as pd


class VideoCSVDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        csv_path: str = "./data",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"The csv path does not exist: {csv_path}")
        self.csv_data = pd.read_csv(csv_path)

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        # Pick a resize that matches the video's aspect ratio to the target size.
        h, w, c = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)
        return resize

    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()

        # Subsample frames so the clip plays back at roughly self.fps.
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)

        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        # Choose a random window of n_sample_frames from the subsampled clip.
        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None:
            video = resize(video)
        return video, vr

    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )
        return video, vr

    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__():
        return 'csv'

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, index):
        # Each row is (video_path, prompt); see the CSV layout above.
        video_path, prompt = self.csv_data.iloc[index]
        video, _ = self.process_video_wrapper(video_path)
        prompt_ids = self.get_prompt_ids(prompt)

        return {
            "pixel_values": normalize_input(video[0]),
            "prompt_ids": prompt_ids,
            "text_prompt": prompt,
            "dataset": self.__getname__(),
        }