【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型

前言

Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到,大大减少了目标在遮挡后,追踪失败的可能。

一、数据集简介

Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低清摄像头)拍摄到的 1501 个行人、32668 个检测到的行人矩形框。每个行人至少由2个摄像头捕获到,并且在一个摄像头中可能具有多张图像。训练集有 751 人,包含 12,936 张图像,平均每个人有 17.2 张训练数据;测试集有 750 人,包含 19,732 张图像,平均每个人有 26.3 张测试数据。3368 张查询图像的行人检测矩形框是人工绘制的,而 gallery 中的行人检测矩形框则是使用DPM检测器检测得到的。

该数据集提供的固定数量的训练集和测试集均可以在single-shot或multi-shot测试设置下使用。

目录结构

Market-1501-v15.09.15

  ├── bounding_box_test

       ├── 0000_c1s1_000151_01.jpg

       ├── 0000_c1s1_000376_03.jpg

       ├── 0000_c1s1_001051_02.jpg

  ├── bounding_box_train

       ├── 0002_c1s1_000451_03.jpg

       ├── 0002_c1s1_000551_01.jpg

       ├── 0002_c1s1_000801_01.jpg

  ├── gt_bbox

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c1s1_009376_00.jpg

       ├── 0001_c2s1_001976_00.jpg

  ├── gt_query

       ├── 0001_c1s1_001051_00_good.mat

       ├── 0001_c1s1_001051_00_junk.mat

  ├── query

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c2s1_000301_00.jpg

       ├── 0001_c3s1_000551_00.jpg

  └── readme.txt

目录介绍

(1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)

(2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像

(3) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

(4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)

(5) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box

命名规则

以 0001_c1s1_000151_01.jpg 为例

1) 0001 表示每个人的标签编号,从0001到1501;

2) c1 表示第一个摄像头(camera1),共有6个摄像头;

3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;

4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;

5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

二、跟踪模型介绍

特征提取的模型有很多,可以替换特征提取模型网络。

下述给出的是 deep_sort/deep/model.py 里面的模型代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, c_in, c_out, is_downsample=False):
        super(BasicBlock, self).__init__()
        self.is_downsample = is_downsample
        if is_downsample:
            self.conv1 = nn.Conv2d(
                c_in, c_out, 3, stride=2, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(
                c_in, c_out, 3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.relu = nn.ReLU(True)
        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(c_out)
        if is_downsample:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
                nn.BatchNorm2d(c_out)
            )
        elif c_in != c_out:
            self.downsample = nn.Sequential(
                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
                nn.BatchNorm2d(c_out)
            )
            self.is_downsample = True

    def forward(self, x):
        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        if self.is_downsample:
            x = self.downsample(x)
        return F.relu(x.add(y), True)


def make_layers(c_in, c_out, repeat_times, is_downsample=False):
    blocks = []
    for i in range(repeat_times):
        if i == 0:
            blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
        else:
            blocks += [BasicBlock(c_out, c_out), ]
    return nn.Sequential(*blocks)


class Net(nn.Module):
    def __init__(self, num_classes=751, reid=False):
        super(Net, self).__init__()
        # 3 128 64
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # nn.Conv2d(32,32,3,stride=1,padding=1),
            # nn.BatchNorm2d(32),
            # nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, padding=1),
        )
        # 32 64 32
        self.layer1 = make_layers(64, 64, 2, False)
        # 32 64 32
        self.layer2 = make_layers(64, 128, 2, True)
        # 64 32 16
        self.layer3 = make_layers(128, 256, 2, True)
        # 128 16 8
        self.layer4 = make_layers(256, 512, 2, True)
        # 256 8 4
        self.avgpool = nn.AvgPool2d((8, 4), 1)
        # 256 1 1
        self.reid = reid
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # B x 128
        if self.reid:
            x = x.div(x.norm(p=2, dim=1, keepdim=True))
            return x
        # classifier
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    net = Net()
    x = torch.randn(4, 3, 128, 64)
    y = net(x)
    import ipdb
    ipdb.set_trace()

三、数据集处理

splitDataset.py

用于存放训练的图片

# -*- coding:utf-8 -*-
# @author: 牧锦程
# @微信公众号: AI算法与电子竞赛
# @Email: m21z50c71@163.com
# @VX:fylaicai

import os
from shutil import copyfile

# You only need to change this line to your dataset download path
download_path = 'Market-1501-v15.09.15'

if not os.path.isdir(download_path):
    print('please change the download_path')

save_path = 'pytorch'
if not os.path.isdir(save_path):
    os.mkdir(save_path)

# ------------------- query ----------------------
query_path = download_path + '/query'
query_save_path = save_path + '/query'
print("process: ", query_path)
if not os.path.isdir(query_save_path):
    os.mkdir(query_save_path)

for root, dirs, files in os.walk(query_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = query_path + '/' + name
        dst_path = query_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ----------------- multi-query ------------------------
query_path = download_path + '/gt_bbox'
print("process: ", query_path)
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):
    query_save_path = save_path + '/multi-query'
    if not os.path.isdir(query_save_path):
        os.mkdir(query_save_path)

    for root, dirs, files in os.walk(query_path, topdown=True):
        for name in files:
            if not name[-3:] == 'jpg':
                continue
            ID = name.split('_')
            src_path = query_path + '/' + name
            dst_path = query_save_path + '/' + ID[0]
            if not os.path.isdir(dst_path):
                os.mkdir(dst_path)
            copyfile(src_path, dst_path + '/' + name)

# ------------------- gallery ----------------------
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = save_path + '/gallery'
print("process: ", gallery_path)
if not os.path.isdir(gallery_save_path):
    os.mkdir(gallery_save_path)

for root, dirs, files in os.walk(gallery_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = gallery_path + '/' + name
        dst_path = gallery_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

# ------------------ train ---------------------
train_path = download_path + '/bounding_box_train'
train_save_path = save_path + '/train'
val_save_path = save_path + '/test'
if not os.path.isdir(train_save_path):
    os.mkdir(train_save_path)
    os.mkdir(val_save_path)

print("process: ", train_path)
for root, dirs, files in os.walk(train_path, topdown=True):
    for name in files:
        if not name[-3:] == 'jpg':
            continue
        ID = name.split('_')
        src_path = train_path + '/' + name
        dst_path = train_save_path + '/' + ID[0]
        if not os.path.isdir(dst_path):
            os.mkdir(dst_path)
            # first image is used as val image
            dst_path = val_save_path + '/' + ID[0]
            os.mkdir(dst_path)
        copyfile(src_path, dst_path + '/' + name)

四、模型训练

修改数据集路径

修改 data-dir 参数为自己的数据集路径

修改数据增强

增加一个尺寸修改

修改类别数量

代码中通过dataloader来获取,因此可以不进行修改

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

修改保存模型名称

这里的修改是为例与原始的模型进行区分,可以做对比

训练结果

查看精度

先运行test.py,生成 features.pth

在运行 evaluate.py,得到如下精度:

五、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/694066.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32 proteus + STM32Cubemx仿真教程(第二课按键教程)

文章目录 前言一、STM32按键的原理二、STM32Cubemx创建工程三、proteus仿真电路图四、程序代码编写4.1函数介绍4.2使用按键点亮LED灯 总结 前言 本篇文章开始带大家学习如何使用proteus和STM32Cubemx来完成STM32的学习,第二节课我们先来学习按键的用法。 proteus使…

景区ar互动大屏游戏化体验提升营销力度

从20世纪60年代的初步构想,到如今全球范围内无数企业的竞相投入,AR增强现实技术已成为引领科技潮流的重要力量。而在这一浪潮中,中国的AR公司正以其独特的魅力和创新力,崭露头角。 中国的AR市场正在迎来前所未有的发展机遇。如今&…

qmt量化交易策略小白学习笔记第16期【qmt编程之获取北向南向资金(沪港通,深港通和港股通)】

qmt编程之获取北向南向资金 qmt更加详细的教程方法,会持续慢慢梳理。 也可找寻博主的历史文章,搜索关键词查看解决方案 ! 北向南向资金(沪港通,深港通和港股通) #北向南向资金交易日历 获取交易日列表…

UltraEditUEStudio软件最新版下载及详细安装教程

UEStudio简介: UEStudio建立在上文本编辑器UltraEdit的功能基础上,并为团队和开发人员提供了其他功能,例如深度Git集成。您可以直接在UEStudio中克隆,签出,更新,提交,推入/拉入等操作&#xff…

SpringBoot整合钉钉实现消息推送

前言 钉钉作为一款企业级通讯工具,具有广泛的应用场景,包括但不限于团队协作、任务提醒、工作汇报等。 通过Spring Boot应用程序整合钉钉实现消息推送,我们可以实现以下功能: 实时向指定用户或群组发送消息通知。自定义消息内容…

知识图谱的应用---智慧交通

文章目录 智慧交通典型应用 智慧交通 现代城市发展过程中的一大问题是交通拥堵,为解决城市发展中的这一顽疾,有必要以现代化高科技技术为支撑,建造城市中的智慧交通系统,从源头入手缓解城市拥挤问题。当前,“智慧交通”…

如何获取当前dll或exe模块所在路径?

有时我们需要在当前运行的dll或exe模块中去动态加载当前模块同路径中的另一个库,或者启动当前模块同路径中的另一个exe程序,一般需要获取当前模块的路径,然后去构造同路径下目标模块的绝对路径,然后通过该绝对路径去加载或启动该目…

java自学阶段二:JavaWeb开发50(Spring和Springboot学习)

Spring、Springboot基础知识学习 目录 学习目标Spring基础概念IOC控制反转DI依赖注入事务管理AOP面向切面编程Spring案例说明(Postman使用、Restful开发规范、lombok、Restful、nginx了解) 一:学习目标: 1)了解Sprin…

海洋日特别活动—深海来客——可燃冰

深海中有一种神奇的物质,似冰又不是冰。 别看它其貌不扬,但本领不小,遇火即燃,能量巨大,可谓是能源家族的新宠。它就是被国务院正式批准列为我国第173个矿种的“可燃冰”! 可燃冰到底是个啥?它…

C++ | Leetcode C++题解之第133题克隆图

题目&#xff1a; 题解&#xff1a; class Solution { public:Node* cloneGraph(Node* node) {if (node nullptr) {return node;}unordered_map<Node*, Node*> visited;// 将题目给定的节点添加到队列queue<Node*> Q;Q.push(node);// 克隆第一个节点并存储到哈希…

Java Web学习笔记31——Maven介绍

Maven&#xff1a;Java项目的构建工具。 Maven&#xff1a; Maven是Apache旗下的一个开源项目&#xff0c;是一款用于管理和构建Java项目的工具。 Apache软件基金会&#xff0c;成立于1999年7月&#xff0c;是目前世界上最大的最受欢迎的开源软件基金会&#xff0c;也是一个专…

手把手带你做一个自己的网络调试助手(2) - TCP服务器完善

服务器指定客户端发送 自定义控件comboBox - 刷新客户端列表 目的&#xff1a; 自定义控件&#xff0c;当鼠标点击这个comboBox控件的时候去刷新客户端列表 mycombobox.h #ifndef MYCOMBOBOX_H #define MYCOMBOBOX_H#include <QComboBox> #include <QWidget>cl…

服务部署:.NET项目使用Docker构建镜像与部署

前提条件 安装Docker&#xff1a;确保你的Linux系统上已经安装了Docker。如果没有&#xff0c;请参考官方文档进行安装。 步骤一&#xff1a;准备项目文件 将你的.NET项目从Windows系统复制到Linux系统。你可以使用Git、SCP等工具来完成这个操作。如何是使用virtualbox虚拟电…

SpringCloud Gateway中Filters详细说明

前面 https://blog.csdn.net/J080624/article/details/139494909 我们研究了GateWay中各种路由断言的使用。SpringCloud GateWay 还提供了各种过滤器用来对请求和响应进行处理。 官网地址&#xff1a;SpringCloud Gateway Filter 【1】GatewayFilter Factories 路由过滤器允…

20240609如何查询淘宝的历史价格

20240609如何查询淘宝的历史价格 2024/6/9 18:39 百度&#xff1a;淘宝历史价格 淘宝历史价格查询网站 https://zhuanlan.zhihu.com/p/670972171 30秒学会淘宝商品历史价格查询&#xff01; https://item.taobao.com/item.htm?id693104421622&pidmm_29415502_2422500430_1…

排序-读取数据流并实时返回中位数

目录 一、问题描述 二、解题思路 1.顺序表排序法 2.使用大根堆、小根堆 三、代码实现 1.顺序表排序法实现 2.大根堆、小根堆法实现 四、刷题链接 一、问题描述 二、解题思路 1.顺序表排序法 &#xff08;1&#xff09;每次读取一个数就对列表排一次序&#xff0c;对排…

python-微分方程计算

首先导入数据 import numpy as np from scipy.integrate import odeint from scipy.optimize import minimize import matplotlib.pyplot as pltdata np.array([[30, 4],[47.2, 6.1],[70.2, 9.8],[77.4, 35.2],[36.3, 59.4],[20.6, 41.7],[18.1, 19],[21.4, 13],[22, 8.3],[2…

初识 peerDependencies

目录 初步认识 peerDependencies semver 介绍 # 摘要 # 简介 # 语义化版本控制规范&#xff08;SemVer&#xff09; # 合法语义化版本的巴科斯范式语法 # 为什么要使用语义化的版本控制&#xff1f; # FAQ 示例讲解&#xff1a;vue-router 插件 # 说明 声明 验证 初…

电子阅览室有何作用

随着互联网的快速发展&#xff0c;电子阅览室逐渐成为人们获取知识的新方式。它为读者提供了便捷、高效的阅读体验&#xff0c;具有诸多作用。首先&#xff0c;电子阅览室拥有丰富的电子书籍资源&#xff0c;涵盖了各个领域的知识。无论是文学作品还是学术论文&#xff0c;读者…

商城项目【尚品汇】08异步编排-01基础篇

文章目录 1.线程的创建方式1.1继承Thread类&#xff0c;重写run方法1.2实现Runnable接口&#xff0c;重写run方法。1.3实现Callable接口&#xff0c;重新call方法1.4以上三种总结1.5使用线程池创建线程1.5.1线程池创建线程的方式1.5.2线程池的七大参数含义1.5.3线程池的工作流程…