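1. Install the environment
For this step, follow the official guide below and install according to your own hardware. The main variable is that different hardware needs different software versions, so there is no single set of versions that works for everyone. MMPose installation guide:
https://mmpose.readthedocs.io/zh-cn/latest/installation.html
My installed versions (A30, CUDA Version: 11.4), exported as a conda environment file: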
name: mmpose
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py38hfa26641_7
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.8.30=hbcca054_0
- certifi=2024.8.30=pyhd8ed1ab_0
- charset-normalizer=3.4.0=pyhd8ed1ab_0
- cudatoolkit=11.3.1=h9edb442_10
- ffmpeg=4.3=hf484d3e_0
- freetype=2.12.1=h4a9f257_0
- giflib=5.2.2=h5eee18b_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.10=pyhd8ed1ab_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9b=h024ee3a_2
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.40=h12ee557_0
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.1.0=h2733197_1
- libunistring=0.9.10=h27cfd23_0
- libuv=1.48.0=h5eee18b_0
- libwebp=1.2.0=h89dd481_0
- lz4-c=1.9.4=h6a678d5_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py38h7f8727e_0
- mkl_fft=1.3.1=py38hd3c417c_0
- mkl_random=1.2.2=py38h51133e4_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- ninja=1.10.2=h06a4308_5
- ninja-base=1.10.2=hd09550d_5
- numpy=1.24.3=py38h14f4228_0
- numpy-base=1.24.3=py38h31eccc5_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1o=h166bdaf_0
- pillow=9.3.0=py38hace64e9_1
- pip=24.2=py38h06a4308_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.8.8=hdb3f193_5
- python_abi=3.8=2_cp38
- pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0
- pytorch-mutex=1.0=cuda
- readline=8.2=h5eee18b_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0
- tk=8.6.14=h39e8969_0
- torchaudio=0.12.1=py38_cu113
- torchvision=0.13.1=py38_cu113
- typing_extensions=4.11.0=py38h06a4308_0
- wheel=0.44.0=py38h06a4308_0
- xz=5.4.6=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- zstd=1.4.9=haebb681_0
- pip:
  - addict==2.4.0
  - aliyun-python-sdk-core==2.16.0
  - aliyun-python-sdk-kms==2.16.5
  - attrs==24.2.0
  - cffi==1.17.1
  - chumpy==0.70
  - click==8.1.7
  - colorama==0.4.6
  - contourpy==1.1.1
  - coverage==7.6.1
  - crcmod==1.7
  - cryptography==43.0.3
  - cycler==0.12.1
  - cython==3.0.11
  - exceptiongroup==1.2.2
  - filelock==3.14.0
  - flake8==7.1.1
  - fonttools==4.54.1
  - importlib-metadata==8.5.0
  - importlib-resources==6.4.5
  - iniconfig==2.0.0
  - interrogate==1.7.0
  - isort==4.3.21
  - jmespath==0.10.0
  - json-tricks==3.17.3
  - kiwisolver==1.4.7
  - markdown==3.7
  - markdown-it-py==3.0.0
  - matplotlib==3.7.5
  - mccabe==0.7.0
  - mdurl==0.1.2
  - mmcv==2.0.0
  - mmdet==3.1.0
  - mmengine==0.10.5
  - model-index==0.1.11
  - munkres==1.1.4
  - opencv-python==4.10.0.84
  - opendatalab==0.0.10
  - openmim==0.3.9
  - openxlab==0.1.2
  - ordered-set==4.1.0
  - oss2==2.17.0
  - packaging==24.1
  - pandas==2.0.3
  - parameterized==0.9.0
  - platformdirs==4.3.6
  - pluggy==1.5.0
  - py==1.11.0
  - pycocotools==2.0.7
  - pycodestyle==2.12.1
  - pycparser==2.22
  - pycryptodome==3.21.0
  - pyflakes==3.2.0
  - pygments==2.18.0
  - pyparsing==3.1.4
  - pytest==8.3.3
  - pytest-runner==6.0.1
  - python-dateutil==2.9.0.post0
  - pytz==2023.4
  - pyyaml==6.0.2
  - requests==2.28.2
  - rich==13.4.2
  - scipy==1.10.1
  - setuptools==60.2.0
  - shapely==2.0.6
  - tabulate==0.9.0
  - termcolor==2.4.0
  - terminaltables==3.1.10
  - tomli==2.0.2
  - tqdm==4.65.2
  - tzdata==2024.2
  - urllib3==1.26.20
  - xdoctest==1.2.0
  - xtcocotools==1.14.3
  - yapf==0.40.2
  - zipp==3.20.2
prefix: ./tools/anaconda3/envs/mmpose
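Because the right versions depend on the hardware, it can help to print what actually ended up installed and compare it with the listing above. The following is a minimal sketch using only the Python standard library. The package list is my own selection from the environment above, and `mmpose` is included on the assumption that it was installed as a package (for example from source); adjust the names as needed.

```python
from importlib import metadata

# Distribution names chosen from the environment listing above (an assumption;
# "mmpose" only shows up if it was installed as a package).
KEY_PACKAGES = ["torch", "torchvision", "mmengine", "mmcv", "mmdet", "mmpose"]


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(KEY_PACKAGES).items():
        print(f"{name:12s} {version or 'NOT INSTALLED'}")
```

A package reported as NOT INSTALLED, or a version that differs from the listing, is the first thing to check if the verification step fails.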
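After the environment is installed, download the mmpose code to a local directory.
Verify that the environment works by running the command below from the mmpose directory. The config file and checkpoint it references must already be in the current directory; the installation guide downloads them with `mim download mmpose --config td-hm_hrnet-w48_8xb32-210e_coco-256x192 --dest .`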
python demo/image_demo.py \
tests/data/coco/000000000785.jpg \
td-hm_hrnet-w48_8xb32-210e_coco-256x192.py \
td-hm_hrnet-w48_8xb32-210e_coco-256x192-0e67c616_20220913.pth \
--out-file vis_results.jpg \
--draw-heatmap
2. Download config files and models
- 2.1 Object detection model
- MMDetection model zoo:
https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md
- Detection model config file (path relative to the mmpose repository root):
demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py
- Download the detection checkpoint into the local checkpoint directory:
wget https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth -P checkpoint
- 2.2 MMPose human pose estimation model
- HRNet pose estimation config file (path relative to the mmpose repository root):
configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py
- Download the HRNet pose estimation checkpoint into the local checkpoint directory:
wget https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth -P checkpoint
- 2.3 RTMPose human pose estimation model
- RTMPose home page:
https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmpose
> RTMPose-S
- Config file (path relative to the mmpose repository root):
projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-s_8xb256-420e_coco-256x192.py
- Download the RTMPose-S checkpoint into the local checkpoint directory:
wget https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-s_simcc-aic-coco_pt-aic-coco_420e-256x192-fcb2599b_20230126.pth -P checkpoint
> RTMPose-L
- Config file (path relative to the mmpose repository root):
projects/rtmpose/rtmpose/body_2d_keypoint/rtmpose-l_8xb256-420e_coco-384x288.py
- Download the RTMPose-L checkpoint into the local checkpoint directory:
wget https://download.openmmlab.com/mmpose/v1/projects/rtmpose/rtmpose-l_simcc-aic-coco_pt-aic-coco_420e-384x288-97d6cb0f_20230228.pth -P checkpoint
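If `wget` is not available, the same downloads can be scripted in Python. This is only a sketch built on the standard library: `checkpoint_path` and `fetch_checkpoint` are hypothetical helper names of my own, not part of MMPose, and unlike `wget` the helper skips a file that already exists instead of saving a second copy.

```python
import os
import urllib.request
from urllib.parse import urlparse


def checkpoint_path(url, dest_dir="checkpoint"):
    """Local path that `wget <url> -P <dest_dir>` would write to."""
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(dest_dir, filename)


def fetch_checkpoint(url, dest_dir="checkpoint"):
    """Download url into dest_dir unless the file is already there."""
    target = checkpoint_path(url, dest_dir)
    if not os.path.exists(target):
        os.makedirs(dest_dir, exist_ok=True)
        # Download to a temporary name first so that an interrupted
        # transfer is never mistaken for a complete checkpoint.
        partial = target + ".part"
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, target)
    return target
```

For example, `fetch_checkpoint("https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth")` would place the HRNet checkpoint under `checkpoint/`, the directory the prediction code in the next section reads from.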
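3. Run the models on an image
The code below is based on code by 同济子豪兄.
The main idea is to load the object detection model and get the bounding box of each person in the image, then call the keypoint detection model to detect the human keypoints inside each person's bounding box.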
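Between the two models sits a small selection step: keep only confident "person" detections and drop near-duplicate boxes with non-maximum suppression (NMS). The sketch below shows that step in plain Python on made-up detections. It is an illustration, not the implementation of the `nms` function in MMPose, and `select_person_boxes` and the sample values are hypothetical. The person label `0` and the thresholds `0.5` and `0.3` are the ones used in this tutorial's code.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def select_person_boxes(detections, person_label=0,
                        conf_thres=0.5, iou_thres=0.3):
    """Keep confident person boxes, then apply greedy NMS.

    detections: list of (x1, y1, x2, y2, score, label) tuples.
    Returns the surviving boxes as (x1, y1, x2, y2), highest score first.
    """
    candidates = [d for d in detections
                  if d[5] == person_label and d[4] > conf_thres]
    candidates.sort(key=lambda d: d[4], reverse=True)
    kept = []
    for det in candidates:
        # Keep a box only if it does not overlap a better box too much.
        if all(iou(det[:4], k[:4]) <= iou_thres for k in kept):
            kept.append(det)
    return [tuple(k[:4]) for k in kept]


# Made-up detections, for illustration only.
detections = [
    (10, 10, 110, 210, 0.95, 0),     # person, confident
    (12, 14, 112, 208, 0.80, 0),     # near-duplicate of the first box
    (300, 20, 380, 220, 0.70, 0),    # a second person
    (150, 150, 200, 200, 0.90, 16),  # confident, but not a person
    (400, 40, 450, 140, 0.30, 0),    # person, below the threshold
]
boxes = select_person_boxes(detections)
print(boxes)  # [(10, 10, 110, 210), (300, 20, 380, 220)]
```

Each surviving box is then passed to the pose model, which crops that region and predicts keypoints for one person at a time.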
## Import packages
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import mmcv
from mmcv import imread
import mmengine
from mmengine.registry import init_default_scope
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples
from mmdet.apis import inference_detector, init_detector

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

## Path of the image to test
img_path = './data/test/multi-person.jpeg'

## Build the object detection model
# Faster R-CNN
detector = init_detector(
    './demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py',
    './checkpoint/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
    device=device)

## Build the human pose estimation model
pose_estimator = init_pose_estimator(
    './configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py',
    './checkpoint/hrnet_w32_coco_256x192-c78dce93_20200708.pth',
    device=device,
    cfg_options={'model': {'test_cfg': {'output_heatmaps': True}}})

## Inference: object detection
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
# Get the object detection results
detect_result = inference_detector(detector, img_path)
print(detect_result.keys())
# Predicted class labels
print(detect_result.pred_instances.labels)
# Confidence scores
print(detect_result.pred_instances.scores)

## Filter by confidence threshold to get the final detections
# Confidence threshold
CONF_THRES = 0.5
pred_instance = detect_result.pred_instances.cpu().numpy()
# Append the score to each box: one row is (x1, y1, x2, y2, score)
bboxes = np.concatenate((pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
# Keep only the person class (label 0) above the confidence threshold
bboxes = bboxes[np.logical_and(pred_instance.labels == 0, pred_instance.scores > CONF_THRES)]
# Remove overlapping boxes with NMS, then drop the score column
bboxes = bboxes[nms(bboxes, 0.3)][:, :4]

## Inference: keypoints
# Get the keypoint predictions for each bbox
pose_results = inference_topdown(pose_estimator, img_path, bboxes)
print(len(pose_results))
# Merge the pose results of all bboxes into one data sample
data_samples = merge_data_samples(pose_results)
print(data_samples.keys())

## Results: keypoint coordinates
# 17 keypoints per person, with coordinates
print(data_samples.pred_instances.keypoints.shape)
# Coordinates of every keypoint of the person at index 0
print(data_samples.pred_instances.keypoints[0, :, :])

## Results: keypoint heatmaps
# One predicted heatmap per keypoint class
print(data_samples.pred_fields.heatmaps.shape)
idx_point = 13
heatmap = data_samples.pred_fields.heatmaps[idx_point, :, :]
print(heatmap.shape)
# Predicted heatmap over the whole image for the keypoint at index idx_point
# plt.imshow(heatmap)
# plt.show()

## MMPose official visualization tool `visualizer`
# Keypoint radius
pose_estimator.cfg.visualizer.radius = 10
# Line width
pose_estimator.cfg.visualizer.line_width = 8
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
# Dataset metadata
visualizer.set_dataset_meta(pose_estimator.dataset_meta)

## Render the visualization
img = mmcv.imread(img_path)
img = mmcv.imconvert(img, 'bgr', 'rgb')
img_output = visualizer.add_datasample(
    'result',
    img,
    data_sample=data_samples,
    draw_gt=False,
    draw_heatmap=True,
    draw_bbox=True,
    show_kpt_idx=True,
    show=False,
    wait_time=0,
    out_file='./outputs/B2.jpg')
# Convert the RGB image to BGR, the channel order OpenCV expects
img_output = cv2.cvtColor(img_output, cv2.COLOR_RGB2BGR)
# Save the image
cv2.imwrite('./outputs/output_image.jpg', img_output)
# img_output.shape
# plt.figure(figsize=(10,10))
# plt.imshow(img_output)
# plt.show()
print("Done!")
Original image
Keypoint detection result
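To make the heatmap output more concrete: each of the 17 COCO keypoints gets its own heatmap, and the keypoint sits at that heatmap's peak. Index 13, the value of `idx_point` in the code above, is the left knee in the standard COCO keypoint order. The sketch below finds the peak of a toy heatmap in plain Python. It only illustrates the idea; the real decoder in MMPose also refines the peak and maps it back to image coordinates. `heatmap_peak` and the toy values are hypothetical.

```python
# Standard COCO keypoint order (17 keypoints).
COCO_KEYPOINTS = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]


def heatmap_peak(heatmap):
    """Return (x, y, value) of the largest entry of a 2D heatmap.

    heatmap: list of rows, indexed as heatmap[y][x].
    """
    best = (0, 0, heatmap[0][0])
    for y, row in enumerate(heatmap):
        for x, value in enumerate(row):
            if value > best[2]:
                best = (x, y, value)
    return best


# Toy 4x5 heatmap (made-up values) with a single peak at x=3, y=1.
toy = [
    [0.0, 0.1, 0.1, 0.2, 0.0],
    [0.0, 0.2, 0.5, 0.9, 0.3],
    [0.0, 0.1, 0.3, 0.4, 0.1],
    [0.0, 0.0, 0.1, 0.1, 0.0],
]
idx_point = 13
print(COCO_KEYPOINTS[idx_point], heatmap_peak(toy))  # left_knee (3, 1, 0.9)
```

With a real heatmap from `data_samples.pred_fields.heatmaps`, the same idea applies after converting the tensor to a NumPy array or a nested list.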