pytorch 计算混淆矩阵

混淆矩阵是评估模型结果的一种指标 用来判断分类模型的好坏

 预测对了 为对角线 

还可以通过矩阵的上下角发现哪些容易出错

从这个 矩阵出发 可以得到 acc != precision recall  特异度?

 

 目标检测01笔记AP mAP recall precision是什么 查全率是什么 查准率是什么 什么是准确率 什么是召回率_:)�东东要拼命的博客-CSDN博客

 acc  是对所有类别来说的

其他三个都是 对于类别来说的

下面给出源码 

import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from prettytable import PrettyTable
from torchvision import datasets
from torchvision.models import MobileNetV2
from torchvision.transforms import transforms


class ConfusionMatrix(object):
    """
    注意版本问题,使用numpy来进行数值计算的
    """

    def __init__(self, num_classes: int, labels: list):
            self.matrix = np.zeros((num_classes, num_classes))
            self.num_classes = num_classes
            self.labels = labels

    def update(self, preds, labels):
        for p, t in zip(preds, labels):
            self.matrix[t, p] += 1

# 行代表预测标签 列表示真实标签




    def summary(self):
        # calculate accuracy
        sum_TP = 0
        for i in range(self.num_classes):
            sum_TP += self.matrix[i, i]
        acc = sum_TP / np.sum(self.matrix)
        print("acc is", acc)

        # precision, recall, specificity
        table = PrettyTable()
        table.fields_names = ["", "pre", "recall", "spec"]
        for i in range(self.num_classes):
            TP = self.matrix[i, i]
            FP = np.sum(self.matrix[i, :]) - TP
            FN = np.sum(self.matrix[:, i]) - TP
            TN = np.sum(self.matrix) - TP - FP - FN
            pre = round(TP / (TP + FP), 3)    # round 保留三位小数
            recall = round(TP / (TP + FN), 3)
            spec = round(TN / (FP + FN), 3)
            table.add_row([self.labels[i], pre, recall, spec])
        print(table)


    def plot(self):
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)  # 颜色变化从白色到蓝色

        # 设置 x  轴坐标 label
        plt.xticks(range(self.num_classes), self.labels, rotation=45)
        # 将原来的 x 轴的数字替换成我们想要的信息 self.num_classes  x 轴旋转45度
        # 设置 y  轴坐标 label
        plt.yticks(range(self.num_classes), self.labels)

        # 显示 color bar  可以通过颜色的密度看出数值的分布
        plt.colorbar()
        plt.xlabel("true_label")
        plt.ylabel("Predicted_label")
        plt.title("ConfusionMatrix")

        # 在图中标注数量 概率信息
        thresh = matrix.max() / 2
        # 设定阈值来设定数值文本的颜色 开始遍历图像的时候一般是图像的左上角
        for x in range(self.num_classes):
            for y in range(self.num_classes):
                # 这里矩阵的行列交换,因为遍历的方向 第y行 第x列
                info = int(matrix[y, x])
                plt.text(x, y, info,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if info > thresh else "black")
        plt.tight_layout()
        # 图形显示更加的紧凑
        plt.show()



if __name__ ==' __main__':
    device = torch.device("cuda:0" if torch.cuda.is_available()else "cpu")
    print(device)
    # 使用验证集的预处理方式
    data_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor()
                                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    data_loot = os.path.abspath(os.path.join(os.getcwd(), "../.."))
    # get data root path
    image_path = data_loot + "/data_set/flower_data/"
    # flower data set path

    validate_dataset = datasets.ImageFolder(root=image_path +"val",
                                            transform=data_transform)

    batch_size = 16
    validate_loader = torch.utils.data.DataLoder(validate_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=2)

    net = MobileNetV2(num_classes=5)
    #加载预训练的权重
    model_weight_path = "./MobileNetV2.pth"
    net.load_state_dict(torch.load(model_weight_path, map_location=device))
    net.to(device)

    #read class_indict
    try:
        json_file = open('./class_indicts.json', 'r')
        class_indict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)


    labels = [label for _, label in class_indict.item()]
    # 通过json文件读出来的label
    confusion = ConfusionMatrix(num_classes=5, labels=labels)
    net.eval()
    # 启动验证模式
    # 通过上下文管理器  no_grad  来停止pytorch的变量对梯度的跟踪
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            outputs = torch.softmax(outputs, dim=1)
            outputs = torch.argmax(outputs, dim=1)
            # 获取概率最大的元素
            confusion.update(outputs.numpy(), val_labels.numpy())
            # 预测值和标签值
    confusion.plot()
    # 绘制混淆矩阵
    confusion.summary()
    # 来打印各个指标信息
































是这样的 这篇算是一个学习笔记,其中的基础图都源于我的导师

 霹雳吧啦Wz的个人空间_哔哩哔哩_bilibili

欢迎无依无靠的CV同学加入 

讲的非常好 代码其实也是导师给的 

我能做的就是读懂每一行加点注释

给不想看视频的同学留点时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/2206.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【K8S系列】深入解析Pod对象(一)

目录 序言 1.问题引入 1.1 问题描述 2 问题解答 2.1 pod 属性 2.1.1 NodeSelector 2.1.2 HostAliases 2.1.3 shareProcessNamespace 2.1.4 NodeName 2.1.5 其他pod属性 2.2 容器属性 2.2.1 ImagePullPolicy 2.2.2 Lifecycle 3 总结 4. 投票 序言 任何一件事情&am…

一文读懂强化学习!

一.了解强化学习1.1基本概念强化学习是考虑智能体(Agent)与环境(Environment)的交互问题:智能体处在一个环境中,每个状态为智能体对当前环境的感知;智能体只能通过动作来影响环境,当…

空间信息智能应用团队研究成果介绍及人才引进

目录1、多平台移动测量技术1.1 车载移动测量系统1.2 机载移动测量系统2、数据处理与应用技术研究2.1 点云与影像融合2.2 点云配准与拼接2.3 点云滤波与分类2.4 道路矢量地图提取2.5 道路三维自动建模2.6 道路路面三维病害分析2.7 多期点云三维变形分析2.8 地表覆盖遥感监测分析…

ChatGPT在安全研究领域的应用实践

引言ChatGPT是一个人工智能技术驱动的自然语言处理工具,它能够通过理解和学习人类的语言来进行对话,并能进行连续对话。目前ChatGPT已经官方已经更新模型到4.0版本,宣称它是“最先进的系统,能生产更安全和更有用的回复”。当前使用…

wandb:可视化和超参数寻优

参考博客:https://zhuanlan.zhihu.com/p/591047340 1、注册账号 首先,去wandb官网注册一个账号,选择个人使用即可(根据个人需要) 然后,登录得到一个API key 2、wandb使用 (1)命令…

Spring框架学习--xml和Annotation方式实现IOC

AnnotationXml的spring-IOC和全Annotation的spring-IOC 文章目录AnnotationXml的spring-IOC和全Annotation的spring-IOC学习目标第二章 基于AnnotationXml的spring-IOC【重点】1、annotationxml【入门案例】(5)【1】目标【2】实现【2.1】创建项目【2.3】改写AccountDaoImpl【2.…

刷题记录(2023.3.14 - 2023.3.18)

[第五空间 2021]EasyCleanup 临时文件包含考点 分析源码,两个特殊的点,一个是 eval,另一个是 include eval 经过了 strlen filter checkNums 三个函数 include 经过了 strlen filter 两个函数 filter 检测是否包含特定的关键字或字符 fun…

【数据结构与算法】用栈实现队列

文章目录&#x1f63b;前言如何用栈实现队列&#xff1f;用栈实现队列整体的实现代码&#x1f63c;写在最后&#x1f63b;前言 &#x1f61d;上一章我们用队列实现了一个栈&#xff08;-> 传送门 <-&#xff09;&#xff0c;而这一章就带大家用栈实现一个队列。 &#x1…

< 每日算法:在排序数组中查找元素的第一个和最后一个位置 >

每日算法 - JavaScript解析&#xff1a;在排序数组中查找元素的第一个和最后一个位置 一、任务描述&#xff1a;> 示例 1> 示例 2> 示例 3二、题意解析三、解决方案&#xff1a;往期内容 &#x1f4a8;一、任务描述&#xff1a; 给你一个按照非递减顺序排列的整数数组…

C++基础算法③——排序算法(选择、冒泡附完整代码)

排序算法 1、选择排序 2、冒泡排序 1、选择排序 基本思想&#xff1a;从头至尾扫描序列&#xff0c;每一趟从待排序元素中找出最小(最大)的一个元素值&#xff0c;然后与第一个元素交换值&#xff0c;接着从剩下的元素中继续这种选择和交换方式&#xff0c;最终得到一个有序…

冲击蓝桥杯-时间问题(必考)

目录 前言&#xff1a; 一、时间问题 二、使用步骤 1、考察小时&#xff0c;分以及秒的使用、 2、判断日期是否合法 3、遍历日期 4、推算星期几 总结 前言&#xff1a; 时间问题可以说是蓝桥杯&#xff0c;最喜欢考的问题了,因为时间问题不涉及到算法和一些复杂的知识&#xf…

第十四届蓝桥杯三月真题刷题训练——第 18 天

目录 第 1 题&#xff1a;排列字母 问题描述 运行限制 代码&#xff1a; 第 2 题&#xff1a;GCD_数论 问题描述 输入格式 输出格式 样例输入 样例输出 评测用例规模与约定 运行限制 第 3 题&#xff1a;选数异或 第 4 题&#xff1a;背包与魔法 第 1 题&#x…

1649_Excel中删除重复的数据

全部学习汇总&#xff1a; GreyZhang/windows_skills: some skills when using windows system. (github.com) 长久时间的开发工作性质导致我对Windows生态下的很多工具没有一个深度的掌握。有时候&#xff0c;别说深度&#xff0c;可能最为浅显的操作都是不熟悉的。这个不仅仅…

JVM学习.02 内存分配和回收策略

1、前言《JVM学习.01 内存模型》篇讲述了JVM的内存布局&#xff0c;其中每个区域是作用&#xff0c;以及创建实例对象的时候内存区域的工作流程。上文还讲到了关于对象存货后&#xff0c;会被回收清理的过程。今天这里就着重讲一下对象实例是如何被清理回收的&#xff0c;以及清…

5.方法(最全C#方法攻略)

目录 5.1 方法的结构 5.2 方法体内部的代码执行 5.3.1 类型推断和Var关键字 5.3.2 嵌套块中的本地变量 5.4 本地常量 5.5 控制流 5.6 方法调用 5.7 返回值 5.8 返回语句和void 方法 5.9 参数 5.9.1 形参 5.9.2 实参 位置参数示例 5.10 值参数 5.11 引用参数 5.12…

【vm虚拟机】vmware虚拟机下载安装

vmware虚拟机下载安装&#x1f6a9; vmware虚拟机下载&#x1f6a9; 安装虚拟机程序&#x1f6a9; 创建一个CentOS虚拟机&#x1f6a9; 异常情况&#x1f6a9; vmware虚拟机下载 vmware官网下载地址 &#x1f6a9; 安装虚拟机程序 双击安装包exe程序&#xff0c;无脑下一步即…

来到CSDN的一些感想

之所以会写下今天这篇博客&#xff0c;是因为心中实在是有很多话想说&#xff01;&#xff01;&#xff01; 认识我的人应该都知道&#xff0c;我是才来CSDN不久的&#xff0c;也可以很清楚地看见我的码龄&#xff0c;直到今天&#xff1a;清清楚楚地写着&#xff1a;134天&…

完美日记母公司再度携手中国妇基会,以“创美人生”助力女性成长

撰稿 | 多客 来源 | 贝多财经 当春时节&#xff0c;梦想花开。和煦的三月暖阳&#xff0c;唤醒的不止是满城春意&#xff0c;更有逸仙电商“创美人生”公益项目播撒的一份希望。 3月8日“国际妇女节”当日&#xff0c;为积极响应我国促进共同富裕的政策倡导&#xff0c;助力相…

C语言--自定义类型详解

目录结构体结构体的声明特殊的声明结构的自引用typedef的使用结构体变量的定义和初始化结构体的内存对齐为什么存在内存对齐&#xff1f;修改默认对齐数结构体传参位段位段的内存分配位段的跨平台问题枚举联合联合类型的定义联合在内存中开辟空间联合大小的计算结构体 结构体的…

Linux之磁盘分区、挂载

文章目录一、Linux分区●原理介绍●硬盘说明查看所有设备挂载情况挂载的经典案例二、磁盘情况查询基本语法应用实例磁盘情况-工作实用指令一、Linux分区 ●原理介绍 Linux来说无论有几个分区&#xff0c;分给哪一目录使用&#xff0c;它归根结底就只有一个根目录&#xff0c;…