【保姆级教程】YOLOv8_Pose多目标+关键点检测:训练自己的数据集

Yolov8官方给出的是单类别的人体姿态关键点检测,本文将记录如果实现训练自己的多类别的关键点检测。

一、YOLOV8环境准备

1.1 下载安装最新的YOLOv8代码

 仓库地址: https://github.com/ultralytics/ultralytics

1.2 配置环境

  pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

二、数据准备

2.1 安装labelme标注软件

pip install labelme

2.1.2 打开roLabelImg软件

使用Anaconda Prompt启动labeme标注工具

在这里插入图片描述

2.2 标注自己的数据

不同的目标的关键点可以自己定义,关键点数量少的目标,再下一步转换标签格式的时候,需要将关键点的数量补齐。例如,下图所示的摩托车有9个关键点,汽车有2个关键点,那边汽车的标签还需要补上7个 0 0 0 (坐标为0,且不可见)。这样才能进行多目标的关键点检测的训练。
我这里的关键点命名按 11 , 22, 33, … ,99来命名的,读者可以根据自己的实际情况来命名。
在这里插入图片描述

2.3 数据转换

2.3.1 运行下面代码,将xml标签格式转为txt标签格式

关键点数量对齐

在这里插入图片描述

# 将labelme标注的json文件转为yolo格式
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
import tqdm
# 物体类别

class_list = ["motorbike","car","cone"]
# 关键点的顺序
keypoint_list = ["11", "22", "33", "44","55", "66", "77", "88", "99"]
def json_to_yolo(img_data ,json_data):
    h ,w = img_data.shape[:2]
    # 步骤:
    # 1. 找出所有的矩形,记录下矩形的坐标,以及对应group_id
    # 2. 遍历所有的head和tail,记下点的坐标,以及对应group_id,加入到对应的矩形中
    # 3. 转为yolo格式
    rectangles = {}
    # 遍历初始化
    for shape in json_data["shapes"]:
        label = shape["label"] # pen, head, tail
        group_id = shape["group_id"] # 0, 1, 2, ...
        points = shape["points"] # x,y coordinates
        shape_type = shape["shape_type"]

        # 只处理矩形,读矩形
        if shape_type == "rectangle":
            if group_id not in rectangles:
                rectangles[group_id] = {
                "label": label,
                "rect": points[0] + points[1],  # Rectangle [x1, y1, x2, y2]
                "keypoints_list": []
        }
    # 遍历更新,将点加入对应group_id的矩形中,读关键点,根据group_id匹配
    for keypoint in keypoint_list:
        for shape in json_data["shapes"]:
            label = shape["label"]
            group_id = shape["group_id"]
            points = shape["points"]
            # 如果匹配到了对应的keypoint
            if label == keypoint:
                rectangles[group_id]["keypoints_list"].append(points[0])
            #else:
             #   rectangles[group_id]["keypoints_list"].append([0,0])

    # 转为yolo格式
    yolo_list = []
    for id, rectangle in rectangles.items():
        result_list  = []
        if rectangle['label'] not in class_list:
            continue
        label_id = class_list.index(rectangle["label"])
        # x1,y1,x2,y2
        x1 ,y1 ,x2 ,y2 = rectangle["rect"]
        # center_x, center_y, width, height
        center_x = (x1 +x2 ) /2
        center_y = (y1 +y2 ) /2
        width = abs(x1 -x2)
        height = abs(y1 -y2)
        # normalize
        center_x /= w
        center_y /= h
        width /= w
        height /= h

        # 保留6位小数
        center_x = round(center_x, 6)
        center_y = round(center_y, 6)
        width = round(width, 6)
        height = round(height, 6)

        # 添加 label_id, center_x, center_y, width, height
        result_list = [label_id, center_x, center_y, width, height]

        # 添加 p1_x, p1_y, p1_v, p2_x, p2_y, p2_v
        for point in rectangle["keypoints_list"]:
            x ,y = point
            x ,y = int(x), int(y)
            x /= w
            y /= h
            # 保留6位小数
            x = round(x, 6)
            y = round(y, 6)
            result_list.extend([x ,y ,2])
        if len(rectangle["keypoints_list"]) == 4:
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])

        if len(rectangle["keypoints_list"]) == 2:
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
            result_list.extend([0, 0, 0])
        
        yolo_list.append(result_list)
    return yolo_list
# 获取所有的图片
img_list = glob.glob("D:/study/cnn/yolo/yolov8-mokpt/ultralytics/data_mokpt/*.png")
for img_path in tqdm.tqdm( img_list ):

    img = cv2.imread(img_path)
    print(img_path)
    json_file = img_path.replace('png', 'json')
    with open(json_file) as json_file:
        json_data = json.load(json_file)

    yolo_list = json_to_yolo(img, json_data)
    yolo_txt_path = img_path.replace('png', 'txt')

    with open(yolo_txt_path, "w") as f:
        for yolo in yolo_list:
            for i in range(len(yolo)):
                if i == 0:
                    f.write(str(yolo[i]))
                else:
                    f.write(" " + str(yolo[i]))
            f.write("\n")
运行上面代码,就可以获得TXT格式标签文件

在这里插入图片描述

2.3.2 运行下面代码,检查txt标签转换是否正确

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob

img_path = "D:/study/cnn/yolo/yolov8-mokpt/ultralytics/data_mokpt/1.png"

plt.figure(figsize=(15, 10))
img = cv2.imread(img_path)
plt.imshow(img[:, :, ::-1])
plt.axis('off')

yolo_txt_path = img_path.replace('png', 'txt')
print(yolo_txt_path)

with open(yolo_txt_path, 'r') as f:
    lines = f.readlines()

lines = [x.strip() for x in lines]

label = np.array([x.split() for x in lines], dtype=np.float32)

# 物体类别
class_list = ["motorbike","car","cone"]

# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0)]
# 关键点的顺序
keypoint_list = ["11", "22", "33", "44","55", "66", "77", "88", "99"]
# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0)]

# 绘制检测框
img_copy = img.copy()
h, w = img_copy.shape[:2]
for id, l in enumerate(label):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw / 2)
    ymin = int(cy - bh / 2)
    xmax = int(cx + bw / 2)
    ymax = int(cy + bh / 2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, class_color[int(label_id)], 2)

# display the image
plt.figure(figsize=(15, 10))
plt.imshow(img_copy[:, :, ::-1])
plt.axis('off')
# save the image
cv2.imwrite("./tmp.png", img_copy)

img_copy = img.copy()
h, w = img_copy.shape[:2]
for id, l in enumerate(label):
    # label_id ,center x,y and width, height
    label_id, cx, cy, bw, bh = l[0:5]
    label_text = class_list[int(label_id)]
    # rescale to image size
    cx *= w
    cy *= h
    bw *= w
    bh *= h

    # draw the bounding box
    xmin = int(cx - bw / 2)
    ymin = int(cy - bh / 2)
    xmax = int(cx + bw / 2)
    ymax = int(cy + bh / 2)
    cv2.rectangle(img_copy, (xmin, ymin), (xmax, ymax), class_color[int(label_id)], 2)
    cv2.putText(img_copy, label_text, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 2, class_color[int(label_id)], 2)

    # draw 17 keypoints, px,py,pv,px,py,pv...
    for i in range(5, len(l), 3):
        px, py = l[i:i + 2]
        # rescale to image size
        px *= w
        py *= h
        # puttext the index
        index = int((i - 5) / 2)
        # draw the keypoints
        if(int(px)>0):
         cv2.circle(img_copy, (int(px), int(py)), 10, (0,255,255), -1)

plt.figure(figsize=(15, 10))
plt.imshow(img_copy[:, :, ::-1])
plt.axis('off')
# save
cv2.imwrite('./tmp.png', img_copy)
cv2.imshow('tmp', img_copy)
cv2.waitKey(0)

可视化结果如下

在这里插入图片描述

ultralytics\ultralytics\路径下,创建data文件夹,将图片和标签按下面的结构摆放:
在这里插入图片描述

三、配置文件设置

3.1 修改coco-pose.yaml

修改ultralytics\ultralytics\cfg\datasets\coco-pose.yaml配置文件内容:

path: ultralytics/data/images   
train: train 
val: val

# Keypoints
# 9:多目标中关键点最多的那个关键点数量
# 3: x, y和关键点可见性
kpt_shape: [9, 3]
flip_idx: [0, 1, 2, 3,4,5,6,7,8,9]

# Classes
names:
  0: motorbike
  1: car
  2: cone

四、训练

4.1 下载预训练权重

在YOLOv8 github上下载预训练权重:yolov8n-pose.pt,ultralytics\ultralytics\路径下,新建weight文件夹,预训练权重放入其中。
在这里插入图片描述

4.2 训练

步骤一:修改ultralytics\ultralytics\cfg\default.yaml文件中的训练参数(根据自己的实际情况决定)
步骤二:执行下面代码:

from ultralytics import YOLO

# Load a model
model = YOLO('ultralytics/weights/yolov8n-pose.pt')

# Train the model
results = model.train(data='D:/study/cnn/yolo/yolov8-mokpt/ultralytics/ultralytics/cfg/datasets/coco-pose.yaml', epochs=300, imgsz=640)

五、验证

from ultralytics import YOLO
 
def main():
    model = YOLO(r'runs/pose/train/weights/best.pt')
    model.val(data='data/multi-pose.yaml', imgsz=1024, batch=4, workers=4)
if __name__ == '__main__':
    main()

六、推理

根据自己实际的情况,修改

# 测试图片
from ultralytics import YOLO
import cv2
import numpy as np
import sys

# 读取命令行参数
weight_path = 'E:/YOLO/yolov8-mokpt/ultralytics/runs/pose/best.pt'
media_path = "demo/bev_2_1034.png"

# 加载模型
model = YOLO(weight_path)

# 获取类别
objs_labels = model.names  # get class labels
print(objs_labels)

# 类别的颜色
class_color = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),(255, 0, 0), (0, 255, 0)]
# 关键点的顺序
class_list = ["motorbike","car","cone"]

# 关键点的颜色
keypoint_color = [(255, 0, 0), (0, 255, 0),(255, 0, 0), (0, 255, 0),(255, 0, 0), (0, 255, 0),(255, 0, 0), (0, 255, 0),(255, 0, 0), (0, 255, 0)]

# 读取图片
frame = cv2.imread(media_path)
frame = cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2))
# rotate
# 检测
result = list(model(frame, conf=0.3, stream=True))[0]  # inference,如果stream=False,返回的是一个列表,如果stream=True,返回的是一个生成器
boxes = result.boxes  # Boxes object for bbox outputs
boxes = boxes.cpu().numpy()  # convert to numpy array

# 遍历每个框
for box in boxes.data:
    l, t, r, b = box[:4].astype(np.int32)  # left, top, right, bottom
    conf, id = box[4:]  # confidence, class
    id = int(id)
    # 绘制框
    cv2.rectangle(frame, (l, t), (r, b), (0, 0, 255), 2)
    # 绘制类别+置信度(格式:98.1%)
    cv2.putText(frame, f"{objs_labels[id]} {conf * 100:.1f}", (l, t - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                (0, 0, 255), 1)

# 遍历keypoints
keypoints = result.keypoints  # Keypoints object for pose outputs
keypoints = keypoints.cpu().numpy()  # convert to numpy array

# draw keypoints, set first keypoint is red, second is blue
for keypoint in keypoints.data:
    for i in range(len(keypoint)):
        x, y ,_ = keypoint[i]
        x, y = int(x), int(y)
        cv2.circle(frame, (x, y), 3, (0, 255, 0), -1)
        #cv2.putText(frame, f"{keypoint_list[i]}", (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, keypoint_color[i], 2)

    if len(keypoint) >= 2:
        # draw arrow line from tail to half between head and tail
        x0, y0 ,_= keypoint[0]
        x1, y1 ,_= keypoint[1]
        x2, y2 ,_= keypoint[2]
        x3, y3 ,_= keypoint[3]
        x4, y4 ,_= keypoint[4]
        x5, y5 ,_= keypoint[5]
        x6, y6 ,_= keypoint[6]
        x7, y7 ,_= keypoint[7]
        x8, y8 ,_= keypoint[8]


        cv2.line(frame, (int(x0), int(y0)), (int(x1), int(y1)), (255, 0, 255), 1)
        cv2.line(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 255), 1)
        cv2.line(frame, (int(x2), int(y2)), (int(x3), int(y3)), (255, 0, 255), 1)
        cv2.line(frame, (int(x3), int(y3)), (int(x4), int(y4)), (255, 0, 255), 1)
        cv2.line(frame, (int(x4), int(y4)), (int(x5), int(y5)), (255, 0, 255), 1)
        cv2.line(frame, (int(x5), int(y5)), (int(x6), int(y6)), (255, 0, 255), 1)
        cv2.line(frame, (int(x6), int(y6)), (int(x7), int(y7)), (255, 0, 255), 1)
        cv2.line(frame, (int(x7), int(y7)), (int(x8), int(y8)), (255, 0, 255), 1)
        cv2.line(frame, (int(x8), int(y8)), (int(x0), int(y0)), (255, 0, 255), 1)


        #center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
       # cv2.arrowedLine(frame, (int(x2), int(y2)), (int(center_x), int(center_y)), (255, 0, 255), 4,
        #                line_type=cv2.LINE_AA, tipLength=0.1)

# save image
cv2.imwrite("result.jpg", frame)
print("save result.jpg")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/468845.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JNDI+LDAP攻击手法

服务端: package com.naihe3; import java.net.InetAddress; import java.net.MalformedURLException; import java.net.URL;import javax.net.ServerSocketFactory; import javax.net.SocketFactory; import javax.net.ssl.SSLSocketFactory;import com.unboundid.…

27-4 文件上传漏洞 - 黑名单绕过

环境准备:构建完善的安全渗透测试环境:推荐工具、资源和下载链接_渗透测试靶机下载-CSDN博客 一、黑名单绕过和黑白名单机制: 黑名单:黑名单中的文件不允许通过。白名单:白名单中的文件允许通过。二、黑白名单判断: 当输入一串后缀如"sfahkfhakj"时,黑名单不…

docker安装配置dnsmasq

docker下载安装 参考:docker安装、卸载、配置、镜像 如果是低版本的额ubuntu,比如ubuntu16.04.7 LTS,为了加快下载速度,参考:Ubuntu16.04LTS安装Docker。 docker安装dnsmasq 下载dnsmasq镜像 首先镜像我们可以选择…

Solidity Uniswap V2 Output amount calculation

现在,我们即将实现高级交换,包括链式交换(例如,通过token B 将token A 交换为token C)。在实现之前,我们需要了解 Uniswap 如何计算输出量。让我们先弄清楚金额与价格的关系。 什么是价格?就是你…

KVM安装-kvm彻底卸载-docker安装Webvirtmgr

KVM安装和使用 一、安装 检测硬件是否支持KVM需要硬件的支持,使用命令查看硬件是否支持KVM。如果结果中有vmx(Intel)或svm(AMD)字样,就说明CPU的支持的 egrep ‘(vmx|svm)’ /proc/cpuinfo关闭selinux将 /etc/sysconfig/selinux 中的 SELinux=enforcing 修改为 SELinux=d…

工业智能网关的功能特点、应用及其对企业产生的价值-天拓四方

一、工业智能网关的功能特点 工业智能网关是一种具备数据采集、传输、处理能力的智能设备,它能够将工业现场的各种传感器、执行器、控制器等设备连接起来,实现设备间的信息互通与协同工作。同时,工业智能网关还具备强大的数据处理能力&#…

超火短剧分销推广项目cps,现在做还不晚(完整教程)

短剧是一种介于短视频和长视频之间的中视频模式,以爽点和反转为特点,讲究引人入胜,刺激消费。更白话一点表达,短剧就是压缩版的电视剧,易上头上瘾,易冲动消费。 所以,使用“蜂小推”进行短剧分…

xss.pwnfunction(DOM型XSS)靶场

环境进入该网站 Challenges (pwnfunction.com) 第一关&#xff1a;Ma Spaghet! 源码&#xff1a; <!-- Challenge --> <h2 id"spaghet"></h2> <script>spaghet.innerHTML (new URL(location).searchParams.get(somebody) || "Somebo…

实地研究降本增效的杀伤力,LSTM算法实现全国失业率分析预测

前言 ​ 降本增效降本增笑&#xff1f;增不增效暂且不清楚&#xff0c;但是这段时间大厂的产品频繁出现服务器宕机和产品BUG确实是十分增笑。目前来看降本增效这一理念还会不断渗透到各行各业&#xff0c;不单单只是互联网这块了&#xff0c;那么对于目前就业最为严峻的一段时…

五款软件让效率飞跃

幸运的是&#xff0c;随着信息技术的不断演进&#xff0c;一系列高效的软件工具应运而生&#xff0c;它们旨在简化我们的日常工作&#xff0c;帮助我们以更少的时间完成更多的任务。下面&#xff0c;将介绍五款能够有效提升您工作效率的软件神器。 1、亿可达 他是一款自动化工…

从自动化到测开,测试人员逆袭之路从此起步!

在当今竞争激烈的软件测试行业中&#xff0c;近期的招聘市场确实面临一些挑战。大量的求职者争相涌入岗位&#xff0c;许多热衷于功能测试的人士甚至难以找到理想的工作机会。更不幸的是&#xff0c;连自动化测试和性能测试这些专业领域也受到了测试开发人员的竞争压力。然而&a…

Ubuntu使用Docker部署Nginx容器并结合内网穿透实现公网访问本地服务

目录 ⛳️推荐 1. 安装Docker 2. 使用Docker拉取Nginx镜像 3. 创建并启动Nginx容器 4. 本地连接测试 5. 公网远程访问本地Nginx 5.1 内网穿透工具安装 5.2 创建远程连接公网地址 5.3 使用固定公网地址远程访问 ⛳️推荐 前些天发现了一个巨牛的人工智能学习网站&#…

作业:基于udp的tftp文件传输实例

#include <head.h> #include <sys/types.h> #include <sys/socket.h> #include <arpa/inet.h> #include <errno.h>#define PORT 69 //服务器绑定的端口号 #define IP "192.168.1.107" //服务器的IP地址int do_download(i…

adobe animate 时间轴找不到编辑多个帧按钮

如题&#xff0c;找了半天&#xff0c;在时间轴上找不到编辑多个帧按钮,导致无法批量处理帧 然后搜索发现原来是有些版本被隐藏了&#xff0c;需要再设置一下 勾选上就好了

怎么进行流程图制作?这种方法一看就会

怎么进行流程图制作&#xff1f;在当今这个信息爆炸的时代&#xff0c;流程图作为一种直观、高效的表达方式&#xff0c;被广泛应用于各种工作场景。无论是项目管理、流程优化&#xff0c;还是产品设计、教育培训&#xff0c;流程图都能帮助我们更好地理解、分析和优化工作流程…

如何查看chrome里network的payload

如何查看chrome的network的请求payload&#xff0c;点击漏斗形状的过滤器&#xff0c;过滤框清空&#xff0c;表示检测所有&#xff0c;右边按钮点击“全部”&#xff0c;“第三方请求”不要勾选。

数字化金融展厅设计要点,你get到了吗?

近年间随着各类数字化主题展厅的出圈&#xff0c;让这种数字多媒体的设计概念逐渐深入至各个领域&#xff0c;这其中也包含了金融主题展厅&#xff0c;与传统展厅不同的是&#xff0c;借助了先进的技术和设备的数字化展厅&#xff0c;能提供更为丰富、个性化的参观体验&#xf…

Java实现定时发送邮件(基于Springboot工程)

1、功能概述&#xff1f; 1、在企业中有很多需要定时提醒的任务&#xff1a;如每天下午四点钟给第二天的值班人员发送值班消息&#xff1f;如提前一天给参与第二天会议的人员发送参会消息等。 2、这种定时提醒有很多方式如短信提醒、站内提醒等邮件提醒是其中较为方便且廉价的…

opengl日记8-opengl创建三角形

文章目录 环境直接上代码一点小总结参考 环境 系统&#xff1a;ubuntu20.04opengl版本&#xff1a;4.6glfw版本&#xff1a;3.3glad版本&#xff1a;4.6cmake版本&#xff1a;3.16.3gcc版本&#xff1a;10.3.0 直接上代码 CMakeLists.txt cmake_minimum_required(VERSION 2…

开源离线语音识别输入工具CapsWriter v1.0——支持无限时长语音、音视频文件转录字幕。

分享一款开源离线语音识别输入工具&#xff0c;支持无限时长语音、音视频文件转录字幕。 软件简介&#xff1a; CapsWriter是一款免费开源且可完全离线识别的语音输入工具&#xff0c;无需担心因在线版本识别带来的各种隐私泄露问题。支持win7及以上的系统&#xff0c;已经更…