联合目标检测与图像分类提升数据不平衡场景下的准确率

联合目标检测与图像分类提升数据不平衡场景下的准确率

在一些数据不平衡的场景下,使用单一的目标检测模型很难达到99%的准确率。为了优化这一问题,适当将其拆解为目标检测模型图像分类模型的组合,可以更有效地控制最终效果,尤其是在添加焦点损失(focal loss)、调整超参数和数据预处理无效的情况下。以下是具体的实现方式及联合两个模型的推理代码。

整体功能概述

这段代码的主要功能包括:

  1. 加载目标检测和分类模型:使用两个 Ultralytics YOLO(YOLOv8/YOLOv11均可) 模型进行目标检测和分类。
  2. 处理图像:遍历指定输入文件夹中的所有图像,进行目标检测和分类。
  3. 绘制检测框和分类标签:在图像上绘制检测到的对象的边界框,并在框上方添加分类名称和置信度。
  4. 可选保存裁剪的对象图像:根据设置,裁剪检测到的对象区域并保存为单独的图像文件,文件名包含类别名称、置信度和坐标信息(便于调试)。

实现细节

1. 加载模型

代码加载了两个 YOLO 模型:

  • 目标检测模型:一个单一类别的 YOLO 模型,用于检测主体对象。
  • 图像分类模型:一个多类别的 YOLO 模型,用于对检测到的对象进行分类。

2. 处理图像

脚本处理输入文件夹中的每一张图像,步骤如下:

  • 目标检测:使用目标检测模型检测图像中的对象。
  • 裁剪检测到的对象:根据检测到的边界框坐标,裁剪出感兴趣的区域。
  • 图像分类:对裁剪出的对象区域进行分类。
  • 数据增强或欠采样:根据任务需求,对裁剪出的子图像进行数据增强或欠采样,以平衡数据集。

3. 绘制检测框和标签

对于每一个检测到的对象,脚本会:

  • 在图像上绘制一个边界框。
  • 在边界框上方添加分类名称和置信度标签。

4. 保存裁剪的对象图像

可选地,脚本会保存裁剪出的对象图像,文件名包含以下信息:

  • 分类名称
  • 置信度
  • 边界框坐标

这对于调试和分析特定的检测结果非常有帮助。

推理代码

import os
import cv2
import numpy as np
from pathlib import Path
from ultralytics import YOLO
import random

def generate_random_color_from_name(name):
    """根据类别名生成可重复的颜色。"""
    random.seed(name)  # 使用类别名作为随机种子
    return tuple(random.randint(0, 255) for _ in range(3))

def generate_class_colors(names):
    """为每个类别生成一个固定的颜色。"""
    class_colors = {}
    for class_name in names:
        class_colors[class_name] = generate_random_color_from_name(class_name)
    return class_colors

def draw_box_on_image(image, box, color=(0, 255, 0), thickness=2):
    """在图像上绘制检测框。"""
    x1, y1, x2, y2 = map(int, box)
    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

def add_classification_to_box(image, box, class_name, confidence, color=(0, 255, 0)):
    """在边界框上方添加分类名称和置信度。"""
    x1, y1, x2, y2 = map(int, box)
    label = f"{class_name}: {confidence:.2f}"
    cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2, cv2.LINE_AA)

def save_cropped_object(image, box, cls_class_name, confidence, output_folder, image_name):
    """将裁剪的对象区域保存为图像到子文件夹中,文件名包含类别名、置信度和坐标。"""
    x1, y1, x2, y2 = map(int, box)
    cropped_img = image[y1:y2, x1:x2]
    
    # 为当前图像创建一个以图像文件名命名的子文件夹
    image_subfolder = Path(output_folder) / Path(image_name).stem
    image_subfolder.mkdir(parents=True, exist_ok=True)
    
    # 为裁剪的对象创建文件名(class_name_confidence_x1_y1_x2_y2.jpg)
    # 确保置信度格式安全,使用两位小数,并用下划线分隔
    cropped_img_name = f"{cls_class_name}_{confidence:.2f}_{x1}_{y1}_{x2}_{y2}.jpg"
    cropped_img_path = image_subfolder / cropped_img_name
    cv2.imwrite(str(cropped_img_path), cropped_img)
    print(f"已保存裁剪对象: {cropped_img_path}")

def process_image_with_detection_and_classification(model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped=False, detection_size=1280, classification_size=640):
    """
    处理单张图像:执行对象检测,分类每个对象,并返回处理后的图像。

    :param model_det: 检测模型
    :param model_cls: 分类模型
    :param img_path: 图像路径
    :param names: 类别名称列表
    :param class_colors: 类别颜色映射字典
    :param output_folder: 输出文件夹路径
    :param save_cropped: 是否保存裁剪的对象图像
    :param detection_size: 检测模型输入图像大小
    :param classification_size: 分类模型输入图像大小
    :return: 处理后的图像
    """
    img = cv2.imread(str(img_path))
    if img is None:
        print(f"无法读取图像: {img_path}")
        return None

    # 创建图像副本用于绘制(不修改原始图像)
    img_copy = img.copy()

    # 执行对象检测
    results_det = model_det.predict(str(img_path), imgsz=detection_size, conf=0.25, iou=0.45)

    # 处理每个检测结果(每个检测框)
    for r in results_det:
        boxes = r.boxes.xyxy.cpu().numpy()  # xyxy 格式
        classes = r.boxes.cls.cpu().numpy()
        confidences = r.boxes.conf.cpu().numpy()

        for box, cls_id, confidence in zip(boxes, classes, confidences):
            # 检测模型的类别名
            det_class_name = names[int(cls_id)]
            
            # 使用检测到的类别名对应的颜色(该颜色是全局唯一的)
            color = class_colors.get(det_class_name, (255, 255, 255))
            
            # 裁剪对象区域
            x1, y1, x2, y2 = map(int, box)
            object_region = img[y1:y2, x1:x2]
            # 将对象区域调整为分类模型的输入大小
            object_region = cv2.resize(object_region, (classification_size, classification_size))

            # 执行分类
            results_cls = model_cls.predict(object_region, imgsz=classification_size)

            for result in results_cls:
                try:
                    # 获取Top1预测结果
                    classification_confidence = result.probs.cpu().numpy().top1conf
                    top1_index = result.probs.top1
                    cls_class_name = names[top1_index]

                    # 根据分类结果的类别名设置颜色
                    final_color = class_colors.get(cls_class_name, color)
                    add_classification_to_box(img_copy, box, cls_class_name, classification_confidence, color=final_color)

                    # 如果启用了保存裁剪对象,则保存
                    if save_cropped:
                        save_cropped_object(img, box, cls_class_name, classification_confidence, output_folder, img_path.name)
                except Exception as e:
                    print(f"分类时出错: {e}")

            # 在图像副本上绘制检测框
            draw_box_on_image(img_copy, box, color=color)

    return img_copy

def process_images(model_det, model_cls, input_folder, output_folder, names, class_colors, save_cropped=False, detection_size=1280, classification_size=640):
    """
    处理输入文件夹中的图像,执行对象检测和分类,并保存处理后的图像。

    :param model_det: 检测模型
    :param model_cls: 分类模型
    :param input_folder: 输入文件夹路径
    :param output_folder: 输出文件夹路径
    :param names: 类别名称列表
    :param class_colors: 类别颜色映射字典
    :param save_cropped: 是否保存裁剪的对象图像
    :param detection_size: 检测模型输入图像大小
    :param classification_size: 分类模型输入图像大小
    """
    Path(output_folder).mkdir(parents=True, exist_ok=True)

    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.webp']
    for ext in image_extensions:
        for img_path in Path(input_folder).glob(ext):
            print(f"正在处理: {img_path}")
            processed_img = process_image_with_detection_and_classification(
                model_det, model_cls, img_path, names, class_colors, output_folder, save_cropped, detection_size, classification_size
            )

            if processed_img is not None:
                output_image_path = Path(output_folder) / f"{img_path.stem}_with_boxes_and_classification.jpg"
                cv2.imwrite(str(output_image_path), processed_img)
                print(f"已保存处理后的图像: {output_image_path}")
            else:
                print(f"跳过图像: {img_path} (无法处理)")

if __name__ == '__main__':
    # 设置是否保存裁剪的对象图像(默认不保存)
    SAVE_CROPPED = True  # 设置为 True 以启用保存裁剪对象

    # 加载检测和分类模型
    model_det = YOLO('runs/device_train/exp9/weights/best.pt')
    model_cls = YOLO('runs/cls_99.4%_exp14/weights/best.pt')

    # 设置输入和输出文件夹路径
    input_folder = 'test1'
    output_folder = 'infer-1216'

    # 获取类别名(用于生成一致的类别颜色映射)
    # 这里使用一张全白的图像来获取类别名
    black_image = 255 * np.ones((224, 224, 3), dtype=np.uint8)
    results = model_cls.predict(source=black_image)
    name_dict = results[0].names
    names = list(name_dict.values())

    # 只在这里生成一次类别颜色映射
    class_colors = generate_class_colors(names)

    # 开始处理图像
    process_images(
        model_det, model_cls, input_folder, output_folder,
        names, class_colors,
        save_cropped=SAVE_CROPPED,
        detection_size=1280,
        classification_size=224
    )

执行完后的结果
在这里插入图片描述

下面贴一下目标检测和图像分类的ultralytics的训练代码

目标检测训练代码

注意把single_cls=False改成True,变成单类训练

# nohup python -m torch.distributed.launch --nproc_per_node=4 --master_port=25643 det_train.py > output-lane-1212.txt 2>&1 &
# nohup python -m torch.distributed.launch --nproc_per_node=5 --master_port=25698 det_train.py > output-lane-1212.txt 2>&1 &
from ultralytics import YOLO

if __name__ == '__main__':
    # 加载模型
    model = YOLO("checkpoints/yolo11l.pt")  # 使用预训练权重训练
    # 训练参数 ----------------------------------------------------------------------------------------------
    model.train(
        data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_1212_yolo_without_vdd/config.yaml',
        epochs=150,  # (int) 训练的周期数
        patience=50,  # (int) 等待无明显改善以进行早期停止的周期数
        batch=16,  # (int) 每批次的图像数量(-1 为自动批处理)
        imgsz=1280,  # (int) 输入图像的大小,整数或w,h
        save=True,  # (bool) 保存训练检查点和预测结果
        save_period=-1,  # (int) 每x周期保存检查点(如果小于1则禁用)
        cache=False,  # (bool) True/ram、磁盘或False。使用缓存加载数据
        device='1,2,3,5',  # (int | str | list, optional) 运行的设备,例如 cuda device=0 或 device=0,1,2,3 或 device=cpu
        workers=8,  # (int) 数据加载的工作线程数(每个DDP进程)
        project='runs/device_train',  # (str, optional) 项目名称
        name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下
        exist_ok=False,  # (bool) 是否覆盖现有实验
        pretrained=True,  # (bool | str) 是否使用预训练模型(bool),或从中加载权重的模型(str)
        optimizer='auto',  # (str) 要使用的优化器,选择=[SGD,Adam,Adamax,AdamW,NAdam,RAdam,RMSProp,auto]
        verbose=True,  # (bool) 是否打印详细输出
        seed=0,  # (int) 用于可重复性的随机种子
        deterministic=True,  # (bool) 是否启用确定性模式
        single_cls=False,  # (bool) 将多类数据训练为单类
        rect=False,  # (bool) 如果mode='train',则进行矩形训练,如果mode='val',则进行矩形验证
        cos_lr=True,  # (bool) 使用余弦学习率调度器
        close_mosaic=10,  # (int) 在最后几个周期禁用马赛克增强
        resume=False,  # (bool) 从上一个检查点恢复训练
        amp=True,  # (bool) 自动混合精度(AMP)训练,选择=[True, False],True运行AMP检查
        fraction=1.0,  # (float) 要训练的数据集分数(默认为1.0,训练集中的所有图像)
        profile=False,  # (bool) 在训练期间为记录器启用ONNX和TensorRT速度
        freeze= None,  # (int | list, 可选) 在训练期间冻结前 n 层,或冻结层索引列表。
        # 超参数 ----------------------------------------------------------------------------------------------
        lr0=0.01,  # (float) 初始学习率(例如,SGD=1E-2,Adam=1E-3)
        lrf=0.01,  # (float) 最终学习率(lr0 * lrf)
        momentum=0.937,  # (float) SGD动量/Adam beta1
        weight_decay=0.0005,  # (float) 优化器权重衰减 5e-4
        warmup_epochs=3.0,  # (float) 预热周期(分数可用)
        warmup_momentum=0.8,  # (float) 预热初始动量
        warmup_bias_lr=0.1,  # (float) 预热初始偏置学习率
        box=6,  # (float) 盒损失增益
        cls=1.5,  # (float) 类别损失增益(与像素比例)
        dfl=1.5,  # (float) dfl损失增益
        pose=12.0,  # (float) 姿势损失增益
        kobj=1.0,  # (float) 关键点对象损失增益
        label_smoothing=0.05,  # (float) 标签平滑(分数)
        nbs=64,  # (int) 名义批量大小
        hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)
        hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)
        hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)
        degrees=90.0,  # (float) 图像旋转(+/- deg)
        translate=0.5,  # (float) 图像平移(+/- 分数)
        scale=0.5,  # (float) 图像缩放(+/- 增益)
        shear=0.4,  # (float) 图像剪切(+/- deg)
        perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001
        flipud=0.5,  # (float) 图像上下翻转(概率)
        fliplr=0.5,  # (float) 图像左右翻转(概率)
        mosaic=1.0,  # (float) 图像马赛克(概率)
        mixup=0.0,  # (float) 图像混合(概率)
        copy_paste=0.0,  # (float) 分割复制-粘贴(概率)
    )



图像分类训练代码

from ultralytics import YOLO

model = YOLO("checkpoints/yolo11l-cls.pt")
model.train(
    data='/home/lizhijun/01.det/ultralytics-8.3.23/datasets/device_cls_merge_manual_with_21w_1218_train_val_224_truncate_grid_110%', 
    project='runs/cls_train',  # (str, optional) 项目名称
    name='exp',  # (str, optional) 实验名称,结果保存在'project/name'目录下
    epochs=20, 
    batch=1024,
    device='1,2,3,5',
    erasing=0.0,
    crop_fraction=1.0,
    augment=False,
    auto_augment=False,
    hsv_h=0.015,  # (float) 图像HSV-Hue增强(分数)
    hsv_s=0.7,  # (float) 图像HSV-Saturation增强(分数)
    hsv_v=0.4,  # (float) 图像HSV-Value增强(分数)
    degrees=0.0,  # (float) 图像旋转(+/- deg)
    translate=0.0,  # (float) 图像平移(+/- 分数)
    scale=0.0,  # (float) 图像缩放(+/- 增益)
    shear=0.0,  # (float) 图像剪切(+/- deg)
    perspective=0.0,  # (float) 图像透视(+/- 分数),范围为0-0.001
    flipud=0.5,  # (float) 图像上下翻转(概率)
    fliplr=0.5,  # (float) 图像左右翻转(概率)
    mosaic=1.0,  # (float) 图像马赛克(概率)
    mixup=0.0)  # (float) 图像混合(概率))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/942297.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++之红黑树模拟实现

目录 红黑树的概念 红黑树的性质 红黑树的查找效率 红黑树的实现 红黑树的定义 红黑树节点的插入 红黑树的平衡调整 判断红黑树是否平衡 红黑树整体代码 测试代码 上期我们学习了AVL树的模拟实现,在此基础上,我们本期将学习另一个数据结构-…

机器学习常用术语

目录 概要 机器学习常用术语 1、模型 2、数据集 3、样本与特征 4、向量 5、矩阵 6、假设函数与损失函数 7、拟合、过拟合与欠拟合 8、激活函数(Activation Function) 9、反向传播(Backpropagation) 10、基线(Baseline) 11、批量(Batch) 12、批量大小(Batch Size)…

nest 学习3

学习小册(nest通关秘籍) 邮箱验证码登陆 流程图: 邮箱作为key,生成随机验证码,然后放到redis中。调用邮箱api发送邮箱。 前端获取到code后,将验证码输入传给后端,后端根据邮箱取出redis数据,比对验证码&…

原点安全再次入选信通院 2024 大数据“星河”案例

近日,中国信息通信研究院和中国通信标准化协会大数据技术标准推进委员会(CCSA TC601)共同组织开展的 2024 大数据“星河(Galaxy)”案例征集活动结果正式公布。由工银瑞信基金管理有限公司、北京原点数安科技有限公司联…

RabbitMQ 的7种工作模式

RabbitMQ 共提供了7种⼯作模式,进⾏消息传递,. 官⽅⽂档:RabbitMQ Tutorials | RabbitMQ 1.Simple(简单模式) P:⽣产者,也就是要发送消息的程序 C:消费者,消息的接收者 Queue:消息队列,图中⻩⾊背景部分.类似⼀个邮箱,可以缓存消息;⽣产者向其中投递消息,消费者从其中取出消息…

Restaurants WebAPI(四)——Identity

文章目录 项目地址一、Authentication(身份认证)1.1 配置环境(解决类库包无法引用)1.2 使用Authentication控制Controller的访问1.3 获取User的Context1.3.1 在Application下创建User文件夹1. 创建User.cs record类封装角色信息2. 创建UserContext.cs提供…

010 Qt_输入类控件(LineEdit、TextEdit、ComboBox、SpinBox、DateTimeEdit、Dial、Slider)

文章目录 前言一、QLineEdit1.简介2.常见属性及说明3.重要信号及说明4.示例一:用户登录界面5.示例二:验证两次输入的密码是否一致显示密码 二、TextEdit1.简介2.常见属性及说明3.重要信号及说明4.示例一:获取多行输入框的内容5.示例二&#x…

Vue3:uv-upload图片上传

效果图&#xff1a; 参考文档&#xff1a; Upload 上传 | 我的资料管理-uv-ui 是全面兼容vue32、nvue、app、h5、小程序等多端的uni-app生态框架 (uvui.cn) 代码&#xff1a; <view class"greenBtn_zw2" click"handleAddGroup">添加班级群</vie…

通过Docker Compose来实现项目可以指定读取不同环境的yml包

通过Docker Compose来实现项目可以指定读取不同环境的yml包 1. 配置文件2. 启动命令 切换不同环境注意挂载的文件权限要777 1. 配置文件 version: 3.8 services:docker-test:image: openjdk:8-jdk-alpineports:- "${APP_PORT}:${CONTAINER_PORT}"volumes:- "${J…

华为实训课笔记 2024 1223-1224

华为实训 12/2312/24 12/23 [Huawei]stp enable --开启STP display stp brief --查询STP MSTID Port Role STP State Protection 实例ID 端口 端口角色 端口状态 是否开启保护[Huawei]display stp vlan xxxx --查询制定vlan的生成树计算结…

GitCode 光引计划投稿 | GoIoT:开源分布式物联网开发平台

GoIoT 是基于Gin 的开源分布式物联网&#xff08;IoT&#xff09;开发平台&#xff0c;用于快速开发&#xff0c;部署物联设备接入项目&#xff0c;是一套涵盖数据生产、数据使用和数据展示的解决方案。 GoIoT 开发平台&#xff0c;它是一个企业级物联网平台解决方案&#xff…

EasyGBS国标GB28181公网平台P2P远程访问故障诊断:云端服务端排查指南

随着信息技术的飞速发展&#xff0c;视频监控领域正经历从传统安防向智能化、网络化安防的深刻转变。EasyGBS平台&#xff0c;作为基于国标GB28181协议的视频流媒体平台&#xff0c;为用户提供了强大的视频监控直播功能。然而&#xff0c;在实际应用中&#xff0c;P2P远程访问可…

Vnlhun靶场Log4j2漏洞

相关概念 log4j2是Apache的⼀个java日志框架&#xff0c;我们借助它进行日志相关操作管理&#xff0c;然而在2021年末log4j2爆出了远程代码执行漏洞&#xff0c;属于严重等级的漏洞 漏洞原理 简单说就是当你使⽤log4j2中提供的⽅法去输出⽇志信息时&#xff0c;⽐如说最常⻅…

千兆网中的gmii与rgmii

物理链路上是千兆网。1 Gbps1000 Mb/s1000/8 MB/s125 MB/s&#xff0c;这是和你的测试设备相连的1 Gbps物理带宽下的极速。关键点是1 B&#xff08;byte&#xff09;8 b&#xff08;bit&#xff09;。实际下载速度还取决于下载源的限制、出口的物理链路和运营商的限制。

2024-12-24 NO1. XR Interaction ToolKit 环境配置

文章目录 1 软件配置2 安装 XRToolKit3 配置 OpenXR4 安装示例场景5 运行测试 1 软件配置 Unity 版本&#xff1a;Unity6000.0.26 ​ 2 安装 XRToolKit 创建新项目&#xff08;URP 3D&#xff09;&#xff0c;点击进入 Asset Store。 进入“Unity Registry”页签&#xff0…

重温设计模式--外观模式

文章目录 外观模式&#xff08;Facade Pattern&#xff09;概述定义 外观模式UML图作用 外观模式的结构C 代码示例1C代码示例2总结 外观模式&#xff08;Facade Pattern&#xff09;概述 定义 外观模式是一种结构型设计模式&#xff0c;它为子系统中的一组接口提供了一个统一…

【恶意软件检测】一种基于API语义提取的Android恶意软件检测方法(期刊等级:CCF-B、Q2)

一种基于API语义提取的Android恶意软件检测方法 A novel Android malware detection method with API semantics extraction 摘要 由于Android框架和恶意软件的持续演变&#xff0c;使用过时应用程序训练的传统恶意软件检测方法在有效识别复杂演化的恶意软件方面已显不足。为…

【微信小程序】2|轮播图 | 我的咖啡店-综合实训

轮播图 引言 在微信小程序中&#xff0c;轮播图是一种常见的用户界面元素&#xff0c;用于展示广告、产品图片等。本文将通过“我的咖啡店”小程序的轮播图实现&#xff0c;详细介绍如何在微信小程序中创建和管理轮播图。 轮播图数据准备 首先&#xff0c;在home.js文件中&a…

RT-DETR学习笔记(2)

七、IOU-aware query selection 下图是原始DETR。content query 是初始化为0的label embedding, position query 是通过nn.Embedding初始化的一个嵌入矩阵&#xff0c;这两部分没有任何的先验信息&#xff0c;导致DETR的收敛慢。 RT-DETR则提出要给这两部分&#xff08;conten…

fpgafor循环语句使用

genvar i;//循环变量名称 generate for(i0;i<4;ii1)begin:tx//自己定义名称 //循环内容 end endgenerate12位的16进制乘以4就是48位位宽的2进制 因为 222*2(2^4)16