基于安卓的虫害识别软件设计--(2)模型性能可视化|混淆矩阵、热力图

1.混淆矩阵(Confusion Matrix)

1.1基础理论

(1)在机器学习、深度学习领域中,混淆矩阵常用于监督学习,匹配矩阵常用于无监督学习。主要用来比较分类结果和实际预测值。

(2)图中表达的含义:混淆矩阵的每一列代表了预测类别,每一行代表了数据的真实类别。

1.2 实现代码

import torch
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torchvision import transforms

classes = ['bai_xing_hua_jin_gui', 'beetle', 'chui_mian_jie', 'ci_e_ke', 'da_qing_ye_chan','dou_yuan_jing','fan_qie_qian_ye_ying_larva','fan_qie_qian_ye_ying_mature','hong_zhi_zhu','huang_zong_ke']

# classes = ['白星化金龟', '甲虫', '吹绵蚧', '刺蛾科', '大青叶蝉','豆芫菁','番茄潜叶蛾幼虫','番茄潜叶蛾成虫','红蜘蛛','蝗总科']

def predict_image(model, image_path, true_label):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    img = Image.open(image_path)
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    tensor_img = val_transform(img)
    tensor_img = tensor_img.to(device)
    tensor_img = tensor_img.unsqueeze(0)
    output = model(tensor_img)
    _, pred = output.max(1)
    pred_label = classes[pred.item()]
    return pred_label, true_label

if __name__ == '__main__':
    # 1. 加载模型
    model_path = r"/kaggle/input/mymodel3/resnet101_final.pth"
    model = torch.load(model_path)
    model.eval()

    # 2. 预测多张图片并记录真实标签和预测结果
    true_labels = []
    pred_labels = []
    images_dir = r"/kaggle/input/insects-new/validation"
    for label in os.listdir(images_dir):
        label_dir = os.path.join(images_dir, label)
        if not os.path.isdir(label_dir):
            continue
        for img_name in os.listdir(label_dir):
            img_path = os.path.join(label_dir, img_name)
            true_labels.append(label)
            pred_label, _ = predict_image(model, img_path, label)
            pred_labels.append(pred_label)

    # 3. 计算混淆矩阵
    cm = confusion_matrix(true_labels, pred_labels, labels=classes)

    # 4. 计算归一化的混淆矩阵
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    
    # 5. 绘制混淆矩阵
    save_path = "/kaggle/working/confusion_matrix.png"
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm_normalized, annot=True, cmap='Blues', xticklabels=classes, yticklabels=classes, fmt='.2f')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    plt.tight_layout()  # 自动调整子图参数
    plt.savefig(save_path)
    plt.show()

注意:以下数值需要和训练时的数值一样!


2.热力图

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
from torchcam.methods import GradCAMpp
# CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM
from torchvision import transforms
from torchcam.utils import overlay_mask


# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)


model = torch.load('/kaggle/input/mymodel3/resnet101_final.pth')
model = model.eval().to(device)

cam_extractor = GradCAMpp(model)

# 要与训练集保持一致
test_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                             transforms.RandomHorizontalFlip(),
                             transforms.RandomVerticalFlip(),
                             transforms.RandomGrayscale(),
                             transforms.ToTensor(),
                             transforms.RandomErasing(),
                             transforms.Normalize([0.460, 0.483, 0.396], [0.171, 0.150, 0.190])])

# 载入目标图像
img_path = '/kaggle/input/insects-new/train/hong_zhi_zhu/13845.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
# 预测标签
pred_logits = model(input_tensor)
pred_id = torch.topk(pred_logits, 1)[1].detach().cpu().numpy().squeeze().item()

activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
# 矩阵热力图
plt.imshow(activation_map)
plt.show()
plt.savefig('/kaggle/working/activation_map.png')

# 将原图重合
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)
result.save('/kaggle/working/result.png')

result

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/672862.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

物理模拟技术在AI绘画中的革新作用

引言: 随着人工智能(AI)技术的飞速发展,艺术领域也迎来了一场创新的革命。AI绘画,作为这场革命的重要组成部分,不仅改变了传统艺术创作的模式,而且为艺术家提供了前所未有的创作工具。在这一过程…

Linux基础1-基本指令1

1.Linux学习前言 Linux的学习非常重要,我们学习Linux的第一步是在电脑中搭建Linux环境。 对于没有搭建过的可以看这阿伟t的一篇文章 【Linux入门】Linux环境配置-CSDN博客 我的环境为XShell,运行的云服务器是阿里云 2.本章重点 1.显示当前目录下的所有文件…

软件杯 题目:基于卷积神经网络的手写字符识别 - 深度学习

文章目录 0 前言1 简介2 LeNet-5 模型的介绍2.1 结构解析2.2 C1层2.3 S2层S2层和C3层连接 2.4 F6与C5层 3 写数字识别算法模型的构建3.1 输入层设计3.2 激活函数的选取3.3 卷积层设计3.4 降采样层3.5 输出层设计 4 网络模型的总体结构5 部分实现代码6 在线手写识别7 最后 0 前言…

展现市场布局雄心,ATFX再度亮相非洲峰会,开启区域市场新篇章

自2023年全球市场营销战略部署实施以来,ATFX在全球各区域市场取得了丰硕成果,其品牌实力、知名度、影响力均有大幅提升。在这场全球扩张的征程中,非洲市场日益成为集团关注的焦点。自2023年首次踏上这片充满潜力的市场以来,ATFX持…

定义类并创建类的实例

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 在Python中,类表示具有相同属性和方法的对象的集合。在使用类时,需要先定义类,然后再创建类的实例,通…

谨以此文章记录我的蓝桥杯备赛过程

以国优秀结束了蓝桥杯cb组 鄙人来自电信学院,非科班出身,在寒假,大约2024年2月份,跟着黑马程序员将c基础语法学完了,因为过年,事情较多,没在学了。 最初就是抱着拿省三的态度去打这个比赛的&a…

低代码是什么?开发系统更有什么优势?

低代码(Low-Code)是一种应用开发方法,它采用图形化界面和预构建的模块,使得开发者能够通过少量的手动编程来快速创建应用程序。这种方法显著减少了传统软件开发中的手动编码量,提高了开发效率,降低了技术门…

图形学初识--多边形剪裁算法

文章目录 前言正文为什么需要多边形剪裁算法?前置知识二维直线直线方程:距离本质:点和直线距离关系: 三维平面平面方程距离本质:点和直线距离关系: Suntherland hodgman算法基本介绍基本思想二维举例问题描…

mysql中EXPLAIN详解

大家好。众所周知,MySQL 查询优化器的各种基于成本和规则的优化会后生成一个所谓的执行计划,这个执行计划展示了接下来具体执行查询的方式。在日常工作过程中,我们可以使用EXPLAIN语句来查看某个查询语句的具体执行计划, 今天我们…

椭圆轨道的周期性运动轨道

一、背景介绍 本节将从轨道六根数的角度,探究目标星为椭圆轨道,追踪星周期性环绕目标的必要条件。根据航天动力学的原理,对于一个椭圆轨道,其轨道能量为 对于能够不产生漂移的情况,绕飞编队的能量。对于追踪星到目标星…

(2024,扩散,去噪调度,维度,误差,收敛速度)适应基于分数的扩散模型中的未知低维结构

Adapting to Unknown Low-Dimensional Structures in Score-Based Diffusion Models 公和众和号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群) 目录 0. 摘要 1. 引言 1.1 扩散模型 1.2 现有结果的不…

Xilinx RFSOC 47DR 8收8发 信号处理板卡

系统资源如图所示:  FPGA采用XCZU47DR 1156芯片,PS端搭载一组64Bit DDR4,容量为4GB,最高支持速率:2400MT/s;  PS端挂载两片QSPI X4 FLASH;  PS支持一路NVME存储;  PS端挂载SD接口,用于存储程序&…

图解大模型分布式并行各种通信原语

背景 在分布式集群上执行大模型任务时候,往往使用到数据并行,流水线并行,张量并行等技术,这些技术本质上也就是对数据进行各种方案的切分,然后放到不同的节点上运算。不同节点在计算的过程中需要对数据分发或者同步等…

LeetCode刷题之HOT100之在排序数组中查找元素的第一个和最后一个位置

下午雨变小了,但我并未去实验室,难得的一天呆在宿舍。有些无聊,看看这个,弄弄那个,听听歌,消磨时间。不知觉中时间指针蹦到了九点,做题啦!朋友推荐了 Eason 的 2010-DUO 演唱会&…

一文了解经典报童模型的扩展问题

文章目录 1 引言2 经典报童模型3 综述文章4 模型扩展4.1 扩展目标函数4.2 增加约束条件4.3 增加优化变量4.4 扩展模型参数4.5 扩展问题场景 5 总结6 相关阅读 1 引言 时间过的真快呀,已经6月份了。距离上一篇文章发表,已经过去了将近一个月,…

JS(DOM、事件)

DOM 概念:Document Object Model,文档对象模型。将标记语言的各个组成部分封装为对应的对象: Document:整个文档对象Element:元素对象Attribute:属性对象Text:文本对象Comment:注释对象 JavaScript通过DOM,就能够对HTML进行操作: 改变 HTML 元素的内…

系统操作规约(System Operation Contract)

领域建模补充 问题: 联系有方向性 属性有类型 领域模型尽量避免出现界面相关的东西 习题 问题 考察点 系统操作规约 示例 A) Operation: MakeSale() Cross References: UC:Purchase Preconditions: User has logged in Postconditions: An ProductLis…

集成算法实验与分析(软投票与硬投票)

概述 目的:让机器学习效果更好,单个不行,集成多个 集成算法 Bagging:训练多个分类器取平均 f ( x ) 1 / M ∑ m 1 M f m ( x ) f(x)1/M\sum^M_{m1}{f_m(x)} f(x)1/M∑m1M​fm​(x) Boosting:从弱学习器开始加强&am…

Fiddler抓包工具的使用

目录 1、抓包原理:👇 2、抓包结果👇 1)如何查看一个http请求的原始摸样: 2)分析数据格式: 3、请求格式分析👇 4、响应格式分析👇 官网下载:安装过程比较…

win11+vmware16.0+Ubuntu22.04+开机蓝屏

总结 本机系统 vm虚拟机下载 参考链接 1. 小白必看的Ubuntu20.04安装教程(图文讲解) 2. 软件目录【火星】——VM下载 3. Win11使用VMware15/16启动虚拟机直接蓝屏的爬坑记录 VMware16.0