【目标检测】YOLOv5算法实现(八):模型验证

  本系列文章记录本人硕士阶段YOLO系列目标检测算法自学及其代码实现的过程。其中算法具体实现借鉴于ultralytics YOLO源码Github,删减了源码中部分内容,满足个人科研需求。
  本系列文章主要以YOLOv5为例完成算法的实现,后续修改、增加相关模块即可实现其他版本的YOLO算法。

文章地址:
YOLOv5算法实现(一):算法框架概述
YOLOv5算法实现(二):模型加搭建
YOLOv5算法实现(三):数据集加载
YOLOv5算法实现(四):损失计算
YOLOv5算法实现(五):预测结果后处理
YOLOv5算法实现(六):评价指标及实现
YOLOv5算法实现(七):模型训练
YOLOv5算法实现(八):模型验证
YOLOv5算法实现(九):模型预测(编辑中…)

本文目录

  • 0 引言
  • 1 模型验证(validation.py)

0 引言

  本篇文章综合之前文章中的功能,实现模型的验证。模型验证的逻辑如图1所示。
在这里插入图片描述

图1 模型验证流程

1 模型验证(validation.py)

def validation(parser_data):
    device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
    print("Using {} device validation.".format(device.type))

    # read class_indict
    label_json_path = './data/object.json'
    assert os.path.exists(label_json_path), "json file {} dose not exist.".format(label_json_path)
    with open(label_json_path, 'r') as f:
        class_dict = json.load(f)

    category_index = {v: k for k, v in class_dict.items()}

    data_dict = parse_data_cfg(parser_data.data)
    test_path = data_dict["valid"]

    # 注意这里的collate_fn是自定义的,因为读取的数据包括image和targets,不能直接使用默认的方法合成batch
    batch_size = parser_data.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using %g dataloader workers' % nw)

    # load validation data set
    val_dataset = LoadImagesAndLabels(test_path, parser_data.img_size, batch_size,
                                      hyp=parser_data.hyp,
                                      rect=False)  # 将每个batch的图像调整到合适大小,可减少运算量(并不是512x512标准尺寸)

    val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=nw,
                                                     pin_memory=True,
                                                     collate_fn=val_dataset.collate_fn)

    # create model
    model = Model(parser_data.cfg, ch=3, nc=parser_data.nc)
    weights_dict = torch.load(parser_data.weights, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict, strict=False)
    model.to(device)

    # evaluate on the test dataset

    # 计算PR曲线和AP
    stats = []
    iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
    niou = iouv.numel()
    # 混淆矩阵
    confusion_matrix = ConfusionMatrix(nc=3, conf=0.6)
    model.eval()

    with torch.no_grad():
        for imgs, targets, paths, shapes, img_index in tqdm(val_dataset_loader, desc="validation..."):
            imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
            nb, _, height, width = imgs.shape  # batch size, channels, height, width
            targets = targets.to(device)
            preds = model(imgs)[0]  # only get inference result
            preds = non_max_suppression(preds, conf_thres=0.3, iou_thres=0.6, multi_label=False)
            targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)
            outputs = []
            for si, pred in enumerate(preds):
                '''
                labels: [clas, x, y, w, h] (训练图像上绝对坐标)
                pred: [x,y,x,y,obj,cls] (训练图像上绝对坐标)
                predn: [x,y,x,y,obj,cls] (输入图像上绝对坐标)
                labels: [x,y,x,y,class] (输入图像上绝对坐标)
                shapes[si][0]: 输入图像大小
                shapes[si][1]
                '''
                labels = targets[targets[:, 0] == si, 1:]  # 当前图片的标签信息
                nl = labels.shape[0]  # number of labels # 当前图片标签数量
                if pred is None:
                    npr = 0
                else:
                    npr = pred.shape[0]  # 预测结果数量
                correct = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # 判断在不同IoU下预测是否预测正确
                path, shape = Path(paths[si]), shapes[si][0]  # 当前图片shape(原图大小)
                if npr == 0:  # 若没有预测结果
                    if nl:  # 没有预测结果但有实际目标
                        # 不同IoU阈值下预测准确率,目标类别置信度,预测类别,实际类别
                        stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
                        # 混淆矩阵计算(类别信息)
                        confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
                    continue
                predn = pred.clone()
                scale_boxes(imgs[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
                if nl:  # 有预测结果且有实际目标
                    tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
                    scale_boxes(imgs[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                    labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                    correct = process_batch(predn, labelsn, iouv)
                    confusion_matrix.process_batch(predn, labelsn)
                stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0]))  # 预测结果在不同IoU是否预测正确, 预测置信度, 预测类别, 实际类别
        confusion_matrix.plot(save_dir=parser_data.save_path, names=["normal", 'defect', 'leakage'])

    # 图片:预测结果在不同IoU下预测结果,预测置信度,预测类别,实际类别
    stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
    if len(stats) and stats[0].any():
        tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, names=["normal", 'defect', 'leakage'])
        ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
        mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
        print(map50)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__)

    # 使用设备类型
    parser.add_argument('--device', default='cuda', help='device')

    # 检测目标类别数
    parser.add_argument('--nc', type=int, default=3, help='number of classes')
    file = 'yolov5s'
    cfg = f'cfg/models/{file}.yaml'
    parser.add_argument('--cfg', type=str, default=cfg, help="*.cfg path")
    parser.add_argument('--data', type=str, default='data/my_data.data', help='*.data path')
    parser.add_argument('--hyp', type=str, default='cfg/hyps/hyp.scratch-med.yaml', help='hyperparameters path')
    parser.add_argument('--img-size', type=int, default=640, help='test size')

    # 训练好的权重文件
    weight_1 = f'./weights/{file}/{file}' + '-best_map.pt'
    weight_2 = f'./weights/{file}/{file}' + '.pt'
    weight = weight_1 if os.path.exists(weight_1) else weight_2
    parser.add_argument('--weights', default=weight, type=str, help='training weights')
    parser.add_argument('--save_path', default=f'results/{file}', type=str, help='result save path')

    # batch size
    parser.add_argument('--batch_size', default=2, type=int, metavar='N',
                        help='batch size when validation.')

    args = parser.parse_args()

    validation(args)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/325971.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Java SSM框架实现学生综合考评管理系统项目【项目源码+论文说明】

基于java的SSM框架实现学生学生综合考评管理系统演示 摘要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 学生综合考评管理系统,主要的模块包括查看;管理员;个…

统计学-R语言-5.1

文章目录 前言随机性和规律性概率变量的分布离散型--二项、泊松、几何二项分布几何分布泊松分布 连续型--均匀、正态均匀分布正态分布 其它统计分布--χ2分布、t分布、F分布χ2分布t分布F分布 练习 前言 从本篇文章开始介绍有关概率与分布的介绍。 随机性和规律性 当不能预测…

2024 基于 Rust 的 linter 工具速度很快

2024 年 Web 工具的一大趋势是使用 Rust 重写现有工具。Rust 是一种出色的编程语言,能生成运行速度惊人的二进制文件,且与其它 Web 工具的互操作性极佳,这得益于 WebAssembly 的帮助。swc 和 Turbopack 等工具的速度提升为快速开发体验带来了…

麒麟KYLINOS操作系统上使用五种不同方式安装软件

原文链接:麒麟KYLINOS使用五种不同方式安装软件 hello,大家好啊!今天我要给大家介绍的是在麒麟KYLINOS操作系统上使用五种不同方式安装软件的方法。在Linux系统中,有多种方式可以安装软件,每种方式都适用于不同的场景和…

Grind75第11天 | 310.最小高度树、127.单词接龙、230.二叉搜索树中第k小的元素

310.最小高度树 题目链接:https://leetcode.com/problems/minimum-height-trees 解法: 这个题类似最短路径问题,用的BFS。 从题目的例子可以看到,最小高度树的根节点,好像是入度比较大的节点,这是一个大…

MySQL基础笔记(6)函数

函数:是指一段可以直接被另一段程序调用的程序或者代码~(MySQL内置) 一.字符串函数 trim不能去除中间的空格~ select concat(jsl,1325): 执行如上的代码,返回字符串"jsl1325"。 select lower(JSL); 执行如上的代码&…

【健康小贴士】关节炎是不是冻出来的?

大家冬天肯定被父母唠叨过: 「天气这么冷,裤子穿这么短,小心得关节炎!」 ❌这种说法其实是不对的或者并不全面,答案来了👀

【PAT甲级】1175 Professional Ability Test

问题思考: 首先,若所有的计划(plan)中的节点都可达,则输出 Okay,否则输出 Impossible。注意:这里的“plan”判断的是整个图(这里是有向图)上的节点,而不只是那K个queries节点。若存…

基于Java SSM框架实现学生寝室管理系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架实现学生寝室管理系统演示 摘要 寝室管理设计是高校为学生提供第二课堂,而我们所在学院多采用半手工管理学生寝室的方式,所以有必要开发寝室管理系统来对进行数字化管理。既可减轻学院宿舍长工作压力,比较系统地对宿舍通告…

如何将千亿文件放进一个文件系统,EuroSys‘23 CFS 论文背后的故事

1. 引言 本文的主要目的是解读百度沧海存储团队发表于 EuroSys 2023 的论文《CFS: Scaling Metadata Service for Distributed File System via Pruned Scope of Critical Sections》,论文全文可以在 CFS: Scaling Metadata Service for Distributed File System v…

存内计算技术打破常规算力局限性

目录 前言 关于存内计算 1、常规算力局限性 2、存内计算诞生记 3、存内计算核心 存内计算芯片研发历程及商业化 1、存内计算芯片研发历程 2、存内计算先驱出道 3、存内计算商业化落地 基于知存科技存内计算开发板ZT1的降噪验证 (一)任务目标以…

Linux上新部署的项目jar包没有生效

今天公司新安排了一个项目,这里简称项目A,需要新增两个功能,我这边完成之后,跟前端对接好了,调试也没有问题。 然后把项目打包上传到测试服务器上,重新启动项目,发现项目A新增的接口没有生效&a…

QT属性动画

时间记录:2024/1/15 一、介绍 属性动画类为QPropertyAnimation,类似于CSS的keyframes关键帧 二、分类及使用步骤 1.几何动画 (1)创建QPropertyAnimation对象 (2)setPropertyName方法设置属性名称&#…

MetaGPT入门(二)

接着MetaGPT入门(一),在文件里再添加一个role类 class SimpleCoder(Role):def __init__(self,name:str"Alice",profile:str"SimpleCoder",**kwargs):super().__init__(name,profile,**kwargs)self._init_actions([Write…

【设计模式-3.3】结构型——享元模式

说明:说明:本文介绍设计模式中结构型设计模式中的,享元模式; 游戏地图 在一些闯关类的游戏,如超级玛丽、坦克大战里面,游戏的背景每一个关卡都不相同,但仔细观察可以发现,其都是用…

第十五讲_css水平垂直居中的技巧

css水平垂直居中的技巧 1. 水平垂直居中(场景一)2. 水平垂直居中(场景二)3. 水平垂直居中(场景三)4. 水平垂直居中(场景四) 1. 水平垂直居中(场景一) 条件&a…

Python UI框架库之kivy使用详解

概要 Python是一种广泛使用的编程语言,而Kivy是一个用于创建跨平台移动应用和多点触控应用的开源Python框架。Kivy的设计目标是提供一种简单而强大的方式来构建富有创意的用户界面和交互体验。本文将详细介绍Kivy的基本概念、核心特性、布局系统、用户界面设计和实…

【服务器数据恢复】服务器迁移数据时lun数据丢失的数据恢复案例

服务器数据恢复环境&服务器故障: 一台安装Windows操作系统的服务器。工作人员在迁移该服务器中数据时突然无法读取数据,服务器管理界面出现报错。经过检查发现服务器中一个lun的数据丢失。 服务器数据恢复过程: 1、将故障服务器中所有磁盘…

macOS 13(本机)golang程序交叉编译成 ARM架构

## 背景 golang程序(JuiceFS)需要支持ARM64架构,重新编译; 本地环境:macOS:13 ## 操作 安装交叉编译工具: brew install FiloSottile/musl-cross/musl-cross --with-aarch64 可以在 /usr/l…

【MATLAB随笔】遗传算法优化的BP神经网络(随笔,不是很详细)

文章目录 一、算法思想1.1 BP神经网络1.2 遗传算法1.3 遗传算法优化的BP神经网络 二、代码解读2.1 数据预处理2.2 GABP2.3 部分函数说明 一、算法思想 1.1 BP神经网络 BP神经网络(Backpropagation Neural Network,反向传播神经网络)是一种监…