人体关键点检测2:Pytorch实现人体关键点检测(人体姿势估计)含训练代码
目录
人体关键点检测2:Pytorch实现人体关键点检测(人体姿势估计)含训练代码
1. 前言
2.人体关键点检测方法
(1)Top-Down(自上而下)方法
(2)Bottom-Up(自下而上)方法:
3.人体关键点检测数据集
4.人体检测模型训练
5.人体关键点检测模型训练
(1)项目安装
(2)准备Train和Test数据
(3)配置文件configs
(4)开始训练
(5)Tensorboard可视化训练过程
6.人体关键点检测检测模型效果
7.人体关键点检测(推理代码)下载
8.人体关键点检测(训练代码)下载
9.人体关键点检测C++/Android版本
1. 前言
人体关键点检测(Human Keypoints Detection)又称为人体姿态估计2D Pose,是计算机视觉中一个相对基础的任务,是人体动作识别、行为分析、人机交互等的前置任务。一般情况下可以将人体关键点检测细分为单人/多人关键点检测、2D/3D关键点检测,同时有算法在完成关键点检测之后还会进行关键点的跟踪,也被称为人体姿态跟踪。
本项目将实现人体关键点检测算法,其中使用YOLOv5模型实现人体检测(Person Detection),使用HRNet,LiteHRNet和Mobilenet-v2模型实现人体关键点检测。项目分为数据集说明,模型训练和C++/Android部署等多个章节,本篇是项目《人体关键点检测(人体姿势估计)》系列文章之Pytorch实现人体关键点检测(人体姿势估计);为了方便后续模型工程化和Android平台部署,项目支持高精度HRNet检测模型,轻量化模型LiteHRNet和Mobilenet模型训练和测试,并提供Python/C++/Android多个版本;
轻量化Mobilenet-v2模型在普通Android手机上可以达到实时的检测效果,CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求。下表格给出HRNet,以及轻量化模型LiteHRNet和Mobilenet的计算量和参数量,以及其检测精度
模型 | input-size | params(M) | GFLOPs | AP |
HRNet-w32 | 192×256 | 28.48M | 5734.05M | 0.7585 |
LiteHRNet18 | 192×256 | 1.10M | 182.15M | 0.6237 |
Mobilenet-v2 | 192×256 | 2.63M | 529.25M | 0.6181 |
先展示一下人体关键点检测效果:
Android人体关键点检测APP Demo体验(下载):https://download.csdn.net/download/guyuealian/88610359
【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/134837816
更多项目《人体关键点检测(人体姿势估计)》系列文章请参考:
- 人体关键点检测1:人体姿势估计数据集(含下载链接) https://blog.csdn.net/guyuealian/article/details/134703548
- 人体关键点检测2:Pytorch实现人体关键点检测(人体姿势估计)含训练代码和数据集 https://blog.csdn.net/guyuealian/article/details/134837816
- 人体关键点检测3:Android实现人体关键点检测(人体姿势估计)含源码 可实时检测 https://blog.csdn.net/guyuealian/article/details/134881797
- 人体关键点检测4:C/C++实现人体关键点检测(人体姿势估计)含源码 可实时检测 https://blog.csdn.net/guyuealian/article/details/134881797
2.人体关键点检测方法
目前主流的人体关键点检测(人体姿势估计)方法主要两种:一种是Top-Down(自上而下)方法,另外一种是Bottom-Up(自下而上)方法;
(1)Top-Down(自上而下)方法
将人体检测和人体关键点检测(人体姿势估计)检测分离,在图像上首先进行人体目标检测,定位人体位置;然后crop每一个人体图像,再估计人体关键点;这类方法往往比较慢,但姿态估计准确度较高。目前主流模型主要有CPN,Hourglass,CPM,Alpha Pose,HRNet等。
(2)Bottom-Up(自下而上)方法:
先估计图像中所有人体关键点,然后在通过Grouping的方法组合成一个一个实例;因此这类方法在测试推断的时候往往更快速,准确度稍低。典型就是COCO2016年人体关键点检测冠军Open Pose。
通常来说,Top-Down具有更高的精度,而Bottom-Up具有更快的速度;就目前调研而言, Top-Down的方法研究较多,精度也比Bottom-Up(自下而上)方法高。本项目采用Top-Down(自上而下)方法,先使用YOLOv5模型实现人体检测,然后再使用HRNet进行人体关键点检测(人体姿势估计);
本项目基于开源的HRNet进行改进,关于HRNet项目请参考GitHub
HRNet: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
3.人体关键点检测数据集
本项目主要使用COCO数据集和MPII数据集,关于人体关键点检测数据集说明,请参考《人体关键点检测1:人体姿势估计数据集》https://blog.csdn.net/guyuealian/article/details/134703548
4.人体检测模型训练
本项目采用Top-Down(自上而下)方法,使用YOLOv5模型实现人体目标检测,使用HRNet进行人体关键点检测(人体姿势估计);关于人体检测模型训练方法,可参考 :
行人检测(人体检测)2:YOLOv5实现人体检测(含人体检测数据集和训练代码)
5.人体关键点检测模型训练
整套工程项目基本结构如下:
.
├── configs # 训练配置文件
├── data # 一些数据
├── libs # 一些工具库
├── pose # 姿态估计模型文件
├── work_space # 训练输出工作目录
├── demo.py # 模型推理demo文件
├── README.md # 项目工程说明文档
├── requirements.txt # 项目相关依赖包
└── train.py # 训练文件
(1)项目安装
推荐使用Python3.8或Python3.7,更高版本可能存在版本差异问题,项目依赖python包请参考requirements.txt,使用pip安装即可,项目代码都在Ubuntu系统和Windows系统验证正常运行,请放心使用;若出现异常,大概率是相关依赖包版本没有完全对应
numpy==1.21.6
matplotlib==3.2.2
Pillow==8.4.0
bcolz==1.2.1
easydict==1.9
onnx==1.8.1
onnx-simplifier==0.2.28
onnxoptimizer==0.2.0
onnxruntime==1.6.0
opencv-contrib-python==4.5.2.52
opencv-python==4.5.1.48
pandas==1.1.5
PyYAML==5.3.1
scikit-image==0.17.2
scikit-learn==0.24.0
scipy==1.5.4
seaborn==0.11.2
sklearn==0.0
tensorboard==2.5.0
tensorboardX==2.1
torch==1.7.1+cu110
torchvision==0.8.2+cu110
tqdm==4.55.1
xmltodict==0.12.0
pycocotools==2.0.2
pybaseutils==0.9.4
basetrainer
项目安装教程请参考(初学者入门,麻烦先看完下面教程,配置好Python开发环境):
- 推荐使用Python3.8或Python3.7,更高版本可能存在版本差异问题
- 项目开发使用教程和常见问题和解决方法
- 视频教程:1 手把手教你安装CUDA和cuDNN(1)
- 视频教程:2 手把手教你安装CUDA和cuDNN(2)
- 视频教程:3 如何用Anaconda创建pycharm环境
- 视频教程:4 如何在pycharm中使用Anaconda创建的python环境
(2)准备Train和Test数据
下载COCO数据集或者MPII数据集(建议使用COCO数据集),然后:
- COCO数据集下载并解压到本地,存储目录结构参考如下(原始图片目录和标注信息文件在同一级目录)
─── COCO
├── train2017
│ ├── images # COCO训练集原始图片目录
│ └── person_keypoints_train2017.json # COCO训练集标注信息文件
└── val2017
├── images # COCO验证集原始图片目录
└── person_keypoints_val2017.json # COCO验证集标注信息文件
- MPII数据集下载并解压到本地,存储目录结构参考如下
─── MPII
├── images # MPII数据集原始图片目录
├── train.json # MPII训练集标注信息文件
└── valid.json # MPII训练集标注信息文件
(3)配置文件configs
项目支持HRNet以及轻量化模型LiteHRNet和Mobilenet模型训练,并提供对应的配置文件;你需要修改对应配置文件的数据路径;本篇以训练HRNet-w32为例子,其配置文件在configs/coco/hrnet/w32_adam_192_192.yaml,修改该文件的训练数据集路径TRAIN_FILE(支持多个数据集训练)和测试数据集TEST_FILE的数据路径为你本地数据路径,其他参数保持默认即可,如下所示:
WORKERS: 8
PRINT_FREQ: 10
DATASET:
DATASET: 'custom_coco'
TRAIN_FILE:
- 'D:/COCO/train2017/person_keypoints_train2017.json'
TEST_FILE: 'D:/COCO/val2017/person_keypoints_val2017.json'
FLIP: true
ROT_FACTOR: 45
SCALE_FACTOR: 0.3
SCALE_RATE: 1.25
JOINT_IDS: [0,1]
FLIP_PAIRS: [ ]
SKELETON: [ ]
配置文件的一些参数说明,请参考
参数 | 类型 | 参考值 | 说明 |
WORKERS | int | 8 | 数据加载处理的进程数 |
PRINT_FREQ | int | 10 | 打印LOG信息的间隔 |
DATASET | str | custom_coco | 数据集类型,目前仅支持COCO数据格式 |
TRAIN_FILE | List | - | 训练数据集文件列表(COCO数据格式),支持多个数据集 |
TEST_FILE | string | - | 测试数据集文件(COCO数据格式),仅支持单个数据集 |
FLIP | bool | True | 是否翻转图片进行测试,可提高测试效果 |
ROT_FACTOR | float | 45 | 训练数据随机旋转的最大角度,用于数据增强 |
SCALE_FACTOR | float | 1.25 | 图像缩放比例因子 |
SCALE_RATE | float | 0.25 | 图像缩放率 |
JOINT_IDS | list | [ ] | [ ]表示所有关键点,也可以指定需要训练的关键点序号ID |
FLIP_PAIRS | list | [ ] | 图像翻转时,关键点不受翻转影响的ID号 |
SKELETON | list | [ ] | 关键点连接线的序列列表,用于可视化效果 |
(4)开始训练
修改好配置文件后,就可以开始准备训练了:
- 训练高精度模型HRNet-w48或者HRNet-w32
# 高精度模型:HRNet-w32
python train.py -c "configs/coco/hrnet/w48_adam_192_192.yaml" --workers=8 --batch_size=32 --gpu_id=0 --work_dir="work_space/person"
# 高精度模型:HRNet-w48
python train.py -c "configs/coco/hrnet/w32_adam_192_192.yaml" --workers=8 --batch_size=32 --gpu_id=0 --work_dir="work_space/person"
- 训练轻量化模型LiteHRNet
# 轻量化模型:LiteHRNet
python train.py -c "configs/coco/litehrnet/litehrnet18_192_192.yaml" --workers=8 --batch_size=32 --gpu_id=0 --work_dir="work_space/person"
- 训练轻量化模型Mobilenetv2
# 轻量化模型:Mobilenet
python train.py -c "configs/coco/mobilenet/mobilenetv2_192_192.yaml" --workers=8 --batch_size=32 --gpu_id=0 --work_dir="work_space/person"
下表格给出HRNet,以及轻量化模型LiteHRNet和Mobilenet的计算量和参数量,以及其检测精度AP; 高精度检测模型HRNet-w32,AP可以达到0.7585,但其参数量和计算量比较大,不合适在移动端部署;LiteHRNet18和Mobilenet-v2参数量和计算量比较少,合适在移动端部署;虽然LiteHRNet18的理论计算量和参数量比Mobilenet-v2低,但在实际测试中,发现Mobilenet-v2运行速度更快。轻量化Mobilenet-v2模型在普通Android手机上可以达到实时的检测效果,CPU(4线程)约50ms左右,GPU约30ms左右 ,基本满足业务的性能需求
模型 | input-size | params(M) | GFLOPs | AP |
HRNet-w32 | 192×256 | 28.48M | 5734.05M | 0.7585 |
LiteHRNet18 | 192×256 | 1.10M | 182.15M | 0.6237 |
Mobilenet-v2 | 192×256 | 2.63M | 529.25M | 0.6181 |
(5)Tensorboard可视化训练过程
训练过程可视化工具是使用Tensorboard,使用方法,在终端输入:
# 基本方法
tensorboard --logdir=path/to/log/
# 例如
tensorboard --logdir="work_space/person/hrnet_w32_16_192_256_mpii_20231127_113836_6644/log"
点击终端TensorBoard打印的链接,即可在浏览器查看训练LOG信息等:
6.人体关键点检测检测模型效果
demo.py文件用于推理和测试模型的效果,填写好配置文件,模型文件以及测试图片即可运行测试了;demo.py命令行参数说明如下:
参数 | 类型 | 参考值 | 说明 |
---|---|---|---|
-c,--config_file | str | - | 配置文件 |
-m,--model_file | str | - | 模型文件 |
target | str | - | 骨骼点类型,如hand,coco_person,mpii_person |
image_dir | str | data/image | 测试图片的路径 |
video_file | str,int | - | 测试的视频文件 |
out_dir | str | output | 保存结果,为空不保存 |
threshold | float | 0.3 | 关键点检测置信度 |
device | str | cuda:0 | GPU ID |
下面以运行HRNet-w32为样例,其他模型修改--config_file或者--model_file即可
- 测试图片
python demo.py -c work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/w32_adam_192_192.yaml -m work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/model/best_model_195_0.7585.pth --image_dir data/test_images --out_dir output
- 测试视频文件
python demo.py -c work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/w32_adam_192_192.yaml -m work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/model/best_model_195_0.7585.pth --video_file data/video-test.mp4 --out_dir output
- 测试摄像头
python demo.py -c work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/w32_adam_192_192.yaml -m work_space/person/hrnet_w32_17_192_256_custom_coco_20231115_092948_1789/model/best_model_195_0.7585.pth --video_file 0 --out_dir output
项目同时支持MPII数据集格式人体关键点检测
- 测试图片(MPII格式的人体关键点检测)
python demo.py -c work_space/person/hrnet_w32_16_192_256_mpii_20231127_113836_6644/w32_adam_192_192.yaml -m work_space/person/hrnet_w32_16_192_256_mpii_20231127_113836_6644/model/best_model_148_89.4041.pth --image_dir data/test_images --out_dir output --target mpii_person
运行效果(支持单人和多人人体关键点检测):
7.人体关键点检测(推理代码)下载
人体关键点检测推理代码下载地址:Pytorch实现人体关键点检测(人体姿势估计)推理代码
资源内容包含:人体关键点检测推理代码(Pytorch)
- 提供YOLOv5人体检测推理代码(不包含训练代码)
- 提供人体关键点检测(人体姿势估计)推理代码demo.py(不包含训练代码)
- 提供高精度版本HRNet人体关键点检测(人体姿势估计)(不包含训练代码)
- 提供轻量化模型LiteHRNet,以及Mobilenet-v2人体关键点检测(人体姿势估计)(不包含训练代码)
- 提供训练好的模型:HRNet-w32,LiteHRNet和Mobilenet-v2模型权重文件,配置好环境,可直接运行demo.py
- 推理代码demo.py支持图片,视频和摄像头测试
如果你需要配套的训练数据集和训练代码,请查看下面部分
8.人体关键点检测(训练代码)下载
人体关键点检测训练代码下载地址:
资源内容包含:
- 提供YOLOv5人体检测推理代码
- 提供整套完整的项目工程代码,包含人体关键点检测(人体姿势估计)的训练代码train.py和推理测试代码demo.py
- 提供高精度版本HRNet人体关键点检测(人体姿势估计)训练和测试代码
- 提供轻量化模型LiteHRNet以及Mobilenet-v2人体关键点检测(人体姿势估计)训练和测试代码
- 项目代码支持MPII数据集和COCO数据集人体关键点检测模型训练和测试
- 根据本篇博文说明,简单配置即可开始训练:train.py
- 提供训练好的模型:HRNet-w32,LiteHRNet和Mobilenet-v2模型权重文件,配置好环境,可直接运行demo.py
- 推理代码demo.py支持图片,视频和摄像头测试
9.人体关键点检测C++/Android版本
- 人体关键点检测3:Android实现人体关键点检测(人体姿势估计)含源码 可实时检测 https://blog.csdn.net/guyuealian/article/details/134881797
- 人体关键点检测4:C/C++实现人体关键点检测(人体姿势估计)含源码 可实时检测 https://blog.csdn.net/guyuealian/article/details/134881797
Android人体关键点检测APP Demo体验(下载):https://download.csdn.net/download/guyuealian/88610359