手把手写深度学习(23):视频扩散模型之Video DataLoader

手把手写深度学习(0):专栏文章导航

前言:训练自己的视频扩散模型的第一步就是准备数据集,而且这个数据集是text-video或者image-video的多模态数据集,这篇博客手把手教读者如何写一个这样扩散模型的的Video DataLoader。

目录

准备工作

下载数据集

视频数据打标签

代码讲解

纯视频文件夹+txt描述prompt 读取方式

CSV描述文件读取方式


准备工作

下载数据集

一般会去下载webvid数据集,但是这个数据集非常大,如果读者不做预训练的话不建议下载。

《Animating Pictures with Eulerian Motion Fields》提供了一个比较小的测试数据集:Animating Pictures with Eulerian Motion Fields

大概一个GB左右,谷歌云盘的链接如下:

https://drive.google.com/file/d/1-MKuNxO1mjopgY6UoEVGDVt5I_QvVeDn/view

下载之后的.pth文件我们暂时不用管,可以先删除掉,只保留.mp4文件。

视频数据打标签

很多数据集是没有一个比较好的文字描述的,如果我们要训练text-to-video的任务,第一步要做的事情是对视频数据打上文字标签。

如果有,那么就算了,主打一个淘气(不是)

还是下一讲专门讲一下如何用V-BLIP给视频数据打上text标签吧

代码讲解

纯视频文件夹+txt描述prompt 读取方式

第一个DataLoader只需要输入视频的文件夹路径,prompt要么是全部指定成相同的(那肯定不行),要么从同名的txt文件中读取:

        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
            with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
                prompt = f.read()
        else:
            prompt = self.fallback_prompt

注意这里的text我们直接用预训练的tokenizer编码了,如果不想要的话也可以把这里注释掉:

    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

获取视频的部分需要特别注意的是,要把"f h w c"转换成"f c h w":

        video = rearrange(video, "f h w c -> f c h w")

完整代码如下:

class VideoFolderDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        path: str = "./data",
        fallback_prompt: str = "",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        self.fallback_prompt = fallback_prompt

        self.video_files = glob(f"{path}/*.mp4")

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        h, w, c = vr[0].shape        
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()
        
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)
        
        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video, vr
        
    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
                vid_path,
                self.use_bucketing,
                self.width, 
                self.height, 
                self.get_frame_buckets, 
                self.get_frame_batch
            )
        return video, vr
    
    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__(): return 'folder'

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, index):

        video, _ = self.process_video_wrapper(self.video_files[index])

        if os.path.exists(self.video_files[index].replace(".mp4", ".txt")):
            with open(self.video_files[index].replace(".mp4", ".txt"), "r") as f:
                prompt = f.read()
        else:
            prompt = self.fallback_prompt

        prompt_ids = self.get_prompt_ids(prompt)

        return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

CSV描述文件读取方式

这种方法每次都要打开一个txt文件去读取prompt,很不方便。而且如果读取的量级大了之后IO的开销会很大!

所以建议使用CSV方式的读取,CSV文件中存放着video-prompt的对应关系,样例如下:

video_path,prompt

...

video_path建议写成绝对路径,这样更方便读取。

完整代码如下:

class VideoCSVDataset(Dataset):
    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 16,
        fps: int = 8,
        csv_path: str = "./data",
        use_bucketing: bool = False,
        **kwargs
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"The csv path does not exist: {csv_path}")
        self.csv_data = pd.read_csv(csv_path)

        self.width = width
        self.height = height

        self.n_sample_frames = n_sample_frames
        self.fps = fps

    def get_frame_buckets(self, vr):
        h, w, c = vr[0].shape        
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def get_frame_batch(self, vr, resize=None):
        n_sample_frames = self.n_sample_frames
        native_fps = vr.get_avg_fps()
        
        every_nth_frame = max(1, round(native_fps / self.fps))
        every_nth_frame = min(len(vr), every_nth_frame)
        
        effective_length = len(vr) // every_nth_frame
        if effective_length < n_sample_frames:
            n_sample_frames = effective_length

        effective_idx = random.randint(0, (effective_length - n_sample_frames))
        idxs = every_nth_frame * np.arange(effective_idx, effective_idx + n_sample_frames)

        video = vr.get_batch(idxs)
        video = rearrange(video, "f h w c -> f c h w")

        if resize is not None: video = resize(video)
        return video, vr
        
    def process_video_wrapper(self, vid_path):
        video, vr = process_video(
                vid_path,
                self.use_bucketing,
                self.width, 
                self.height, 
                self.get_frame_buckets, 
                self.get_frame_batch
            )
        return video, vr
    
    def get_prompt_ids(self, prompt):
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    @staticmethod
    def __getname__(): 
        return 'csv'

    def __len__(self):
        return len(self.csv_data)

    def __getitem__(self, index):
        
        print(self.csv_data.iloc[index])
        video_path, prompt = self.csv_data.iloc[index]
        video, _ = self.process_video_wrapper(video_path)

        prompt_ids = self.get_prompt_ids(prompt)

        return {"pixel_values": normalize_input(video[0]), "prompt_ids": prompt_ids, "text_prompt": prompt, 'dataset': self.__getname__()}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/453714.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

挑战杯 多目标跟踪算法 实时检测 - opencv 深度学习 机器视觉

文章目录 0 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习多目标跟踪 …

4G安卓核心板T310_紫光展锐平台方案

紫光展锐T310应用 DynamlQ架构 12nm 制程工艺&#xff0c;采用 1*Cortex-A753*Cortex-A55处理器&#xff0c;搭载Android11.0操作系统&#xff0c;主频最高达2.0GHz.此外&#xff0c;DynamlQ融入了AI神经网络技术&#xff0c;新增机器学习指令&#xff0c;让其在运算方面的机器…

绝对省事!多微信聚合聊天神器大揭秘!

在如今社交网络发达的时代&#xff0c;微信已成为人们生活中不可或缺的通讯工具。然而&#xff0c;对于拥有多个微信账号的用户来说&#xff0c;经常需要来回切换不同账号&#xff0c;给日常使用带来一定的不便。 那么&#xff0c;有没有一种办法能够让我们摆脱这种繁琐的操作…

掼蛋-掌握出牌权

掼蛋游戏中&#xff0c;出牌权往往能决定一局牌的走向&#xff0c;掌握出牌权可以主动控制局势。出牌权是指在每一轮的出牌环节中谁先出牌。出牌权的重要性主要体现在以下两个方面&#xff1a; 一、控制节奏 出牌权可以让我们主动控制游戏的节奏&#xff0c;可以根据自己的出牌…

Post请求出现Request header is too large

问题描述&#xff1a; 在做项目的时候&#xff0c;前端请求体太大的时候&#xff0c;出现Request header is too large问题&#xff0c;后端接口如下&#xff1a; 前端请求接口返回问题如下&#xff1a; 解决方案&#xff1a; 问题原因&#xff1a;这是因为我们在做Springboo…

BUG:RuntimeError: input.size(-1) must be equal to input_size. Expected 1, got 3

出现的bug为:RuntimeError: input.size(-1) must be equal to input_size. Expected 1, got 3 出现问题的截图: 问题产生原因:题主使用pytorch调用的nn.LSTM里面的input_size和外面的数据维度大小不对。问题代码如下: self.lstm nn.LSTM(input_size, hidden_size, num_laye…

计算机网络-第6章 应用层(2)

6.5 电子邮件 电子邮件&#xff0c;把邮件发送到收件人使用的邮件服务器&#xff0c;并放在其中的收件人邮箱中。最重要的两个标准&#xff1a;简单邮件传送协议SMTP&#xff0c;互联网文本报文格式。 SMTP只能传7位ASCII码邮件&#xff0c;93年提出互联网邮件扩充MIME。邮件…

关于YOLOv9去掉辅助分支脚本使用的一些说明。

专栏介绍&#xff1a;YOLOv9改进系列 | 包含深度学习最新创新&#xff0c;主力高效涨点&#xff01;&#xff01;&#xff01; B站链接&#xff1a;YOLOv9去除辅助训练分支&#xff01;_哔哩哔哩_bilibili 一、说明 在subbranch_removal.py脚本中&#xff0c;我们需要填入上方…

新西兰 eSIM 卡 ONE NZ充值、激活

新西兰One NZ 保号规则和费用 先说大家比较关注的保号条件和费用吧。 新买的卡有效期 720 天&#xff0c;能够充值续期&#xff0c;但是充值后的有效期反而变为 360 天&#xff08;用于保号的兄弟就快过期再充值&#xff09;如果到期后不去充值&#xff0c;账户将变为非活跃状…

SAP 工单CO02 TECO时检查的增强BADI:WORKORDER_UPDATE

需求&#xff1a;需要在CO02进行TECO时检查一下 第三代增强&#xff1a;BADI&#xff1a;WORKORDER_UPDATE中的REORG_STATUS_ACT_CHECK方法 第一步&#xff1a;SE19输入BADI&#xff0c;然后创建 填入名称&#xff1a;ZWORKORDER_UPDATE和描述 输入类名&#xff1a;ZCL_WORKORD…

C语言函数—自定义函数

如果库函数能干所有的事情&#xff0c;那还要程序员干什么&#xff1f; 所有更加重要的是自定义函数。 自定义函数和库函数一样&#xff0c;有函数名&#xff0c;返回值类型和函数参数。 但是不一样的是这些都是我们自己来设计。 这给程序员一个很大的发挥空间。 函数的组…

第十四届蓝桥杯蜗牛

蜗牛 线性dp 目录 蜗牛 线性dp 先求到达竹竿底部的状态转移方程 求蜗牛到达第i根竹竿的传送门入口的最短时间​编辑 题目链接&#xff1a;蓝桥杯2023年第十四届省赛真题-蜗牛 - C语言网 关键在于建立数组将竹竿上的每个状态量表示出来&#xff0c;并分析出状态转移方程 in…

在Linux中进行OpenSSH升级

由于OpenSSH有严重漏洞&#xff0c;因此需要升级OpenSSH到最新版本。 OpenSSL和OpenSSH都要更新&#xff0c;OpenSSH依赖于OpenSSL。 第一步&#xff0c;查看当前的OpenSSH服务版本。 命令&#xff1a;ssh -V 第二步&#xff0c;安装、启动telnet&#xff0c;关闭安全文件&a…

免费AI软件开发工具测评:iFlyCode VS CodeFlying

前言 Hello&#xff0c;各位看官&#xff0c;今天为大家带来两款人工智能的软件开发工具的测评&#xff0c;他们分别是iFlyCode和CodeFlying&#xff0c;我相信当大家看到这两款产品名字的时候不禁都会有些好奇&#xff0c;两个产品都有Code 和Fly两个元素&#xff0c;那他们之…

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料

Python语言在编程业界的地位——《跟老吕学Python编程》附录资料 ⭐️Python语言在编程业界的地位2024年3月编程语言排行榜&#xff08;TIOBE前十&#xff09; ⭐️Python开发语言开发环境介绍1.**IDLE**2.⭐️PyCharm3.**Anaconda**4.**Jupyter Notebook**5.**Sublime Text** …

若依上传文件/common/upload踩坑

前言&#xff1a;作者用的mac系统&#xff08;这个是个坑&#xff09;&#xff0c;前端用的uniapp&#xff0c;调用若依通用上传方法报错NoSuchFileException: /home/ruoyi/uploadPath/upload... 前端上传代码示例如下: uni.chooseImage({count: 1,success(res){ uni.uploa…

在centos8中部署Tomcat和Jenkins

参考链接&#xff1a;tomcat安装和部署jenkins_jenkins和tomcat-CSDN博客 1、进入centos中 /usr/local 目录文件下 [rootlocalhost webapps]# cd /usr/local2、使用通过wget命令下下载tomcat或者直接在官网下载centos版本的包后移动到centos中的local路径下 3、下载tomcat按…

1307页字节跳动Android面试全套真题解析在互联网火了-,完整版开放下载

多进程带来的问题 AIDL 使用浅析binder 原理解析binder 最底层解析多进程通信方式以及带来的问题多进程通信方式对比 Android 高级必备 &#xff1a;AMS,WMS,PMS AMS,WMS,PMS 创建过程 AMS,WMS,PMS全解析AMS启动流程WindowManagerService启动过程解析PMS 启动流程解析 A…

PyTorch之完整的神经网络模型训练

简单的示例&#xff1a; 在PyTorch中&#xff0c;可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码&#xff1a; import torch import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()# 定义神经…

DPN网络

DPN DPN&#xff08;Dual Path Networks&#xff09;是一种网络结构&#xff0c;它结合了DensNet和ResNetXt两种思想的优点。这种结构的目的是通过不同的路径来利用神经网络的不同特性&#xff0c;从而提高模型的效率和性能。 DenseNet 的特点是其稠密连接路径&#xff0c;使…