基于PyTorch的视频分类实战

1、数据集下载

官方链接:https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/#Downloads

百度网盘连接:

https://pan.baidu.com/s/1sSn--u_oLvTDjH-BgOAv_Q?pwd=xsri

提取码: xsri 

        官方链接有详细的数据集介绍,下载的是压缩包 ‘hmdb51_org.rar’,解压后里面是 51 个.rar 压缩包,每个压缩包名是一个类别,里面的是对应类别的视频片段(.avi 文件)。因为资源有限,这里只解压了 5 个类别的视频如图 1 所示:

图1 'hmdb5/org'

        这里新建了 ‘hmdb5’ 文件夹,并新建了 ‘org’ 子文件夹,然后把 ‘hmdb51_org’ 文件夹的 5 个子文件夹放到 ‘org’ 中。作为这次实践的源视频数据。

2、utils.py
        在这里先实现 utils.py,即取帧(get_frames)和存帧(store_frames)函数,取帧函数的功能为从视频中等间距抽取 n_frame 帧,并返回这些帧组成的列表。存帧函数的功能即为将帧列表按序存到 path 中。

import os
import cv2
import numpy as np


def get_frames(path, n_frames=1):
    """
    :param path: 视频文件路径
    :param n_frames: 读取的帧数
    :return: 读取的帧列表 frames
    """
    frames = []

    # 实例化一个用于捕获视频流的对象, 若参数为整数则用于读取摄像头视频, 若参数为字符串则用于读取视频文件
    v_cap = cv2.VideoCapture(path)

    '''
    cv2.CAP_PROP_FRAME_COUNT 是 cv2.VideoCapture 类的一个属性标识符,用于获取视频流或视频文件中的总帧数
    cv2.VideoCapture 的 get 方法用于获取视频流或视频文件的属性(返回值均为实数):
        propId 是属性标识符,整数:
            cv2.CAP_PROP_FRAME_WIDTH:视频的帧宽度(以像素为单位)
            cv2.CAP_PROP_FRAME_HEIGHT:视频的帧高度(以像素为单位)
            cv2.CAP_PROP_FPS:视频的帧率(每秒的帧数)
            cv2.CAP_PROP_POS_FRAMES:当前读取帧的位置(基于 0 的索引)
            cv2.CAP_PROP_POS_AVI_RATIO:视频文件的相对位置(播放进度)
            cv2.CAP_PROP_FRAME_COUNT:视频文件中的总帧数
    '''
    v_len = int(v_cap.get(propId=cv2.CAP_PROP_FRAME_COUNT))

    '''
    在指定区间返回等距的数字数组:
        start: 区间起点
        stop: 区间终点
        num: 采样数量
        endpoint: 默认为 True,若为 False 则区间不包括 stop
    '''
    frame_list = np.linspace(start=0, stop=v_len - 1, num=n_frames + 1, dtype=np.int16)

    for fn in range(v_len):
        # 读取下一帧。它返回两个值:一个布尔值 success 表示是否成功读取帧和一个数组 frame 表示读取到的帧。
        success, frame = v_cap.read()
        if success is False:
            continue
        if fn in frame_list:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(frame)
    v_cap.release()
    return frames


def store_frames(frames, path):
    """
    :param frames: 待保存为 jpg 图片的帧列表
    :param path: 存储路径
    :return:
    """
    for i, frame in enumerate(frames):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        path2img = os.path.join(path, "frame" + str(i) + ".jpg")
        cv2.imwrite(path2img, frame)

3、数据抽帧,并划分训练集和测试集

        先在 ‘hmdb5’ 文件夹中新建子文件夹 ‘train’ 和 ‘test',再运行以下代码即可数据抽帧,并划分训练集和测试集。       

import os
from utils import get_frames, store_frames

path = "hmdb5"
org_dir = "org"
org_path = os.path.join(path, org_dir)
categories_list = os.listdir(org_path)
# brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

# 输出每个类别的视频数量
for c in categories_list:
	print("category:", c)
	p = os.path.join(org_path, c)
	video_list = os.listdir(p)
	print("number of videos:", len(video_list))
	print("-" * 50)
"""
category: brush_hair
number of videos: 107
--------------------------------------------------
category: cartwheel
number of videos: 107
--------------------------------------------------
category: clap
number of videos: 130
--------------------------------------------------
category: catch
number of videos: 102
--------------------------------------------------
category: chew
number of videos: 109
--------------------------------------------------
"""



extension = '.avi'
n_frames = 16
train_rate = 0.9

for i, c in enumerate(categories_list):
	p = os.path.join(org_path, c)
	videos = [v for v in os.listdir(p) if v.endswith(extension)]
	train_size = int(len(videos) * train_rate)
	for j, name in enumerate(videos):
		video_path = os.path.join(p, name)
		frames = get_frames(video_path, n_frames=n_frames)
		path2store = os.path.join(path, "train")
		if j >= train_size:
			path2store = os.path.join(path, "test")
		path2store = os.path.join(path2store, str(i)+"_"+name[:-4])
		print(path2store)
		os.makedirs(path2store, exist_ok=True)
		store_frames(frames, path2store)

        第一段代码输出五个类别的视频数量,可以看到 brush_hair、cartwheel、clap、catch 和  chew 依次有 107、107、130、102、109 个视频。最后一段代码的功能是依次对每个类别的每个视频抽帧,并将抽帧结果存置指定路径,同时划分训练集和测试集。这里设置每个视频的抽帧数量 n_frame=16,按 9:1(498:57) 划分训练集和测试集。每个样本(视频文件夹)名都在原来的名字前拼接上 ‘类别编号_’,其中类别编号为:

brush_hair: 0, chartwheel: 1, clap: 2, catch: 3, chew: 4

        这段代码的运行结果如图 2 所示(以测试集为例)即所有类别样本都在一个文件夹中,不再有类别目录,样本名字最前面的数字即为该样本的类别。

图2 测试集部分样本

4、train.py

4.1 导包

import os
import re
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torchvision.models import video
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

4.2 设置环境变量

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

4.3 定义超参数

lr = 3e-5
gamma = 0.5
epochs = 20
step_size = 5
batch_size = 16
weight_decay = 1e-2

        这里定义初始学习率为 lr=3e-5,训练轮次为 epochs=20,batch_size=16,正则化系数为 weight_decay=1e-2。gamma 和 step_size 分为 torch.optim.lr_scheduler.ReduceLROnPlateau 类构造函数的入参 factor 和 patience。factor 是学习率降低的因子,新的学习率将是当前学习率乘以这个因子;patience 指观察验证指标在多少个 epoch 内没有改善后降低学习率。

4.4 定义图像变换函数

train_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(p=0.5),
    # 用于对图像进行随机的仿射变换, degrees 为旋转角度, translate 为水平和垂直平移的最大绝对分数
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170])
])

test_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize([0.4322, 0.3947, 0.3765], [0.2280, 0.2215, 0.2170]),
])

        这里定义了两个图像变换函数,即用于训练集的 train_transform 和用于测试集的 test_transform, train_transform 在训练前依次对图片进行 resize 操作,以 0.5 的概率水平镜像变换操作,随机仿射操作(随机沿 x,y 方向分别平移 (-0.1*w,0.1*w)、(-0.1*h,0.1*h)),转换为 tensor 操作和标准化操作。test_transform 相较于 train_transform 去掉了起数据增强作用的两个操作。

4.5 定义训练集和测试集路径

# 训练集(498):测试机(57)=9:1
train_dir = 'hmdb5/train'
test_dir = 'hmdb5/test'

4.6 定义数据集类

class HMDB5Dataset(Dataset):
    def __init__(self, directory, transform):
        self.dir = directory
        self.transform = transform
        self.names = os.listdir(directory)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        path = os.path.join(self.dir, self.names[idx])
        frames = []
        for i in range(16):
            frame = Image.open(os.path.join(path, 'frame' + str(i) + '.jpg'))
            frames.append(self.transform(frame))

        frames = torch.stack(frames)
        # 返回 input 的转置版本, 即交换 input 的 dim0 和 dim1
        frames = torch.transpose(input=frames, dim0=0, dim1=1)
        # 编译正则表达式, ^ 表示匹配字符串的开始, + 表示一个或多个
        pattern = re.compile(r'^(\d+)_')
        match = re.search(pattern, self.names[idx])
        return frames, int(match.group(1))

        数据集类的构造函数定义了 3 个属性:dir(数据集路径)、transform(数据预处理方式)和names(样本名列表)。

        __getitem__ 函数根据 idx 按序取出一个样本的所有帧,并对所有帧执行了 transform 操作,最后返回的样本 frames 是 shape 为(channels,n_frames,h,w)的 tensor,该函数还利用 re 库从样本名中获取该样本的标签并返回。

4.7 定义模型

def init_model(mi):
    m = None
    if mi == 1:
        m = video.r3d_18(num_classes=5)  # epochs = 20, correct = 0.754

    return m.to(device)

        这里使用的模型为  torchvision.models.video.r3d_18[1],原文链接:https://arxiv.org/abs/1711.11248。实现可以参考 torch 源码。

4.8 计算评价指标

def correct_loss(data_loader, desc, test):
    results = []
    correct = 0.0
    test_loss = 0.0

    for img, tag in tqdm(data_loader, desc, total=len(data_loader)):
        img = img.to(device)
        tag = tag.to(device)
        pre = model(img)
        if test:
            test_loss += loss_fn(pre, tag)
        correct += torch.sum((pre.argmax(dim=1) == tag).float())

    results.append(correct / len(data_loader.dataset))
    if test:
        results.append(test_loss)

    return results

        correct_loss 函数用于计算 model 在 data_loader 上的 correct 和 loss(如果 test=True ,即data_loader 是测试集的数据加载器)。并将结果以列表的形式返回。

4.9 训练

if __name__ == '__main__':
    model = init_model(1)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr, weight_decay=weight_decay)

    train_ds = HMDB5Dataset(train_dir, train_transform)
    test_ds = HMDB5Dataset(test_dir, test_transform)
    train_dl = DataLoader(train_ds, batch_size, True, num_workers=2)
    test_dl = DataLoader(test_ds, batch_size, False, num_workers=2)
    
    '''
     在验证指标停止改善时降低学习率:
         mode(str): 值域为 {'min', 'max'}。指定优化器应该监视的指标是应该最小化还是最大化
         factor(float): 学习率降低的因子。新的学习率将是当前学习率乘以这个因子
         patience(int): 观察验证指标在多少个 epoch 内没有改善后降低学习率
    '''
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=gamma, patience=step_size, verbose=True)
    best_loss = float('inf')
    for epoch in range(epochs):
        s_loss = 0.0
        print('Epoch:', epoch + 1, '/', epochs)
        for x, y in tqdm(train_dl, total=len(train_dl)):
            x = x.to(device)
            y = y.to(device)
            pred = model(x)
            loss = loss_fn(pred, y)
            s_loss += loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        model.eval()  # 将模型设置为评估模式
        with torch.no_grad():
            print("s_loss:%.3f" % s_loss)
            train_metrics = correct_loss(train_dl, 'compute train_metrics:', False)
            test_metrics = correct_loss(test_dl, 'compute test_metrics:', True)
            if test_metrics[1] < best_loss:
                best_loss = test_metrics[1]
            print("train_correct:%.3f,test_correct:%.3f" % (train_metrics[0], test_metrics[0]))

        model.train()
        scheduler.step(best_loss)

        这里使用交叉墒损失函数,AdamW 优化器,学习率使用 ReduceLROnPlateau scheduler,该 scheduler 监视的指标为 test loss。训练过程中得到的最高 test_correct=0.754。

5、项目目录结构

参考文献

[1] Du Tran, Heng Wang, Lorenzo Torresani, Jamie Ray, Yann LeCun, and Manohar Paluri. A closer look at spatiotemporal convolutions for action recognition. In CVPR, pages 6450–6459, 2018. 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/468019.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ASP .Net Core 8.0 依赖注入的三种注入模式

&#x1f433;前言 &#x1f340;在.NET中&#xff0c;依赖注入&#xff08;Dependency Injection&#xff0c;简称DI&#xff09;是一种设计模式&#xff0c;用于解耦组件之间的依赖关系。 依赖注入的核心思想是将对象的依赖关系&#xff08;即对象所需的其他服务或组件&#…

【系统架构设计师】软件架构设计 02

系统架构设计师 - 系列文章目录 01 系统工程与信息系统基础 02 软件架构设计 文章目录 系统架构设计师 - 系列文章目录 文章目录 前言 一、软件架构的概念 ★★★ 二、基于架构的软件开发 ★★★★ 三、软件架构风格 ★★★★ 1.数据流风格 2.调用/返回风格 3.独立构件风格…

鸿蒙Harmony应用开发—ArkTS声明式开发(绘制组件:Polygon)

多边形绘制组件。 说明&#xff1a; 该组件从API Version 7开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 子组件 无 接口 Polygon(value?: {width?: string | number, height?: string | number}) 从API version 9开始&#xff0…

面试经典150题【81-90】

文章目录 面试经典150题【81-90】530.二叉搜索树的最小绝对差值230.二叉搜索树中第k小的元素98.验证二叉搜索树92.反转链表II25.K个一组翻转链表146.LRU缓存909. 蛇梯棋&#xff08;未做&#xff09;433.最小基因变化127.单词接龙&#xff08;未做&#xff09;17.电话号码的字母…

C++Qt学习——QFile、QPainter、QChart

目录 1、QFile&#xff08;文本读写&#xff09;——概念 1.1、拖入三个控件&#xff0c;对pushButton进行水平布局&#xff0c;之后整体做垂直布局 1.2、按住控件&#xff0c;转到槽&#xff0c;写函数 1.3、打开文件控件 A、首先引入以下两个头文件 B、设置点击打开文件控…

c语言,联合体

一.什么是联合体&#xff1a; 像结构体一样&#xff0c;联合体也是由一个或多个成员变量组成的这些成员变量可以是不同的类型&#xff0c;但编译器只给最大成员分配足够的内存&#xff0c;联合体体内的成员都是公用一块空间的&#xff0c;因此联合体也叫做共同体 二.联合体类…

MySql安装与卸载—我耀学IT

1.MySql安装 打开下载的mysql安装文件mysql-5.5.27-win32.zip&#xff0c;双击解压缩&#xff0c;运行“setup.exe”。 选择安装类型&#xff0c;有“Typical&#xff08;默认&#xff09;”、“Complete&#xff08;完全&#xff09;”、“Custom&#xff08;用户自定义&…

04|调用模型:使用OpenAI API还是微调开源Llama2/ChatGLM?

大语言模型发展史 Transformer是几乎所有预训练模型的核心底层架构。基于Transformer预训练所得的大规模语言模型也被叫做“基础模型”&#xff08;Foundation Model 或Base Model&#xff09;。 在预训练模型出现的早期&#xff0c;BERT毫无疑问是最具代表性的&#xff0c;也…

HTTP 工作流程请求响应 - 面试常问

文章目录 HTTP 工作流程请求和响应格式HTTP请求格式请求行&#xff1a;请求头部字段&#xff1a;空行&#xff1a;消息正文&#xff08;请求正文&#xff09;&#xff1a; HTTP响应格式状态行&#xff1a;响应头部字段&#xff1a;空行&#xff1a; HTTP方法HTTP状态码常用HTTP…

详解基于快速排序算法的qsort的模拟实现

目录 1. 快速排序 1.1 快速排序理论分析 1.2 快速排序的模拟实现 2. qsort的模拟实现 2.1 qsort的理论分析 2.2 qsort的模拟实现 qsort函数是基于快速排序思想设计的可以针对任意数据类型的c语言函数。要对qsort进行模拟实现&#xff0c;首先就要理解快速排序。 1. 快…

Odoo17免费开源ERP开发技巧:如何在表单视图中调用JS类

文/Odoo亚太金牌服务开源智造 老杨 在Odoo最新V17新版中&#xff0c;其突出功能之一是能够构建个性化视图&#xff0c;允许用户以独特的方式与数据互动。本文深入探讨了如何使用 JavaScript 类来呈现表单视图来创建自定义视图。通过学习本教程&#xff0c;你将获得关于开发Odo…

在IDEA中设置使用鼠标滚轮控制字体大小

IDEA是我们常用的程序编程工具&#xff0c;有时为了方便&#xff0c;我们需要随时的调整字体的大小 本篇文章我使用了两种方式来设置IDEA中的字体大小 方式一&#xff1a;使用传统的方式来设置 首先在IDEA顶部的菜单栏中选择“file”菜单 然后在“file”菜单中选择“Setting…

Linux进程通信补充——System V通信

三、System V进程通信 ​ System V是一个单独设计的内核模块&#xff1b; ​ 这套标准的设计不符合Linux下一切皆文件的思想&#xff0c;尽管隶属于文件部分&#xff0c;但是已经是一个独立的模块&#xff0c;并且shmid与文件描述符之间的兼容性做的并不好&#xff0c;网络通…

TCL管理Vivado工程

文章目录 TCL管理Vivado工程1. 项目目录2. 导出脚本文件3. 修改TCL脚本3.1 project.tcl3.2 bd.tcl 4. 工程恢复 TCL管理Vivado工程 工程结构 1. 项目目录 config: 配置文件、coe文件等。doc: 文档fpga: 最后恢复的fpga工程目录ip: ip文件mcs: bit流文件等,方便直接使用src: .…

蓝桥杯物联网竞赛_STM32L071_12_按键中断与串口中断

按键中断&#xff1a; 将按键配置成GPIO_EXTI中断即外部中断 模式有三种上升沿&#xff0c;下降沿&#xff0c;上升沿和下降沿都会中断 external -> 外部的 interrupt -> 打断 trigger -> 触发 detection -> 探测 NVIC中将中断线ENABLE 找接口函数 在接口函数中写…

UE 蓝图场景搭建全流程

点个关注不迷人&#xff0c;可私信帮您解决问题&#xff1a; 学习的知识点&#xff1a; 1、cityEngine 2、bleneder 3、M3工具 4、Uv 基础学习 5、Gameplay 框架 内容很长详情关注&#xff0c;回复ue 获取UE 全流程免费视频教程

OpenCV 单目相机标定

文章目录 一、简介二、实现代码三、实现效果参考资料一、简介 单目相机的标定过程与双目相机的标定过程很类似,具体过程如下所述: 1、首先我们需要获取一个已知图形的图像(这里我们使用MATLAB所提供的数据)。 2、找到同名像点(匹配点),这里主要是探测黑白格子之间的角点…

本地虚拟机平台Proxmox VE结合Cpolar内网穿透实现公网远程访问

&#x1f525;博客主页&#xff1a; 小羊失眠啦. &#x1f3a5;系列专栏&#xff1a;《C语言》 《数据结构》 《C》 《Linux》 《Cpolar》 ❤️感谢大家点赞&#x1f44d;收藏⭐评论✍️ 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&…

【数据挖掘算法与应用】——数据挖掘导论

数据挖掘导论 导入一、为什么要进行数据挖掘1.数据爆炸但知识贫乏2.数据在爆炸式增长3.数据安全4.从商业数据到商业智能的进化5.KDD的出现 二、什么是数据挖掘1.广义技术角度的定义2.狭义技术角度的定义3.商业角度的定义4.数据挖掘与其他科学的关系5.数据挖掘对象6.挖掘到什么知…

数据结构——B树和B+树

数据结构——B树和B树 一、B树1.B树的特征2.B树的插入操作3.B树的删除操作4.B树的缺点 二、B树B树的特征 平衡二叉树或红黑树的查找效率最高&#xff0c;时间复杂度是O(nlogn)。但不适合用来做数据库的索引树。 因为磁盘和内存读写速度有明显的差距&#xff0c;磁盘中存储的数…