语义分割混淆矩阵、 mIoU、mPA计算

一、操作

需要会调试代码的人自己改,小白直接运行会出错

这是我从自己的大文件里摘取的一部分代码,可以运行,只是要改的文件地址path比较多,遇到双引号“”的地址注意一下,不然地址不对容易出错

 把 calculate.py和 utiles_metrics.py放在同一文件夹下,然后运行 calculate.py。

二、理解

test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # 执行计算mIoU的函数

gt_dir 真实标签文件夹

pred_dir 预测结果文件夹

主要是这两个变量设置,后面的可以选择性修改

image_ids 文件名称 dirList(pred_dir,path_list) saveList(path_list) 这两个函数得到

num_classes 类别数

name_classes 类别名称

weight_name 权重名称

hist为混淆矩阵,mIoU为交并比

三、代码 

 calculate.py

# -*- coding: utf-8 -*-
import torch
import os

from time import time
# from PIL import Image
from utils_metrics import compute_mIoU
def saveList(pathName):
    for file_name in pathName:
        #f=open("C:/Users/Administrator/Desktop/DeepGlobe-Road-Extraction-link34-py3/dataset/real/gt.txt", "x")
        with open("./dataset/gt.txt", "a") as f:
            f.write(file_name.split(".")[0] + "\n")
        f.close

def dirList(gt_dir,path_list):
    for i in range(0, len(path_list)):
        path = os.path.join(gt_dir, path_list[i])
    if os.path.isdir(path):
        saveList(os.listdir(path))

data_path  = './dataset/'


f=open("./dataset/gt.txt", 'w')
gt_dir      = os.path.join(data_path, "real/")
pred_dir    = "./submits/log01_Dink101_five_100/test_iou/iou_60u/"
path_list = os.listdir(pred_dir)
path_list.sort()
dirList(pred_dir,path_list)
saveList(path_list)
num_classes=2
name_classes    = ["nontarget","target"]
weight_name='log01_Dink101_five_100'
image_ids   = open(os.path.join(data_path, "gt.txt"),'r').read().splitlines() 

test_mIou,test_mPA,test_miou,test_mpa=compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes,weight_name)  # 执行计算mIoU的函数
print('  test_mIoU:  '+str(test_miou))

 utiles_metrics.py

from os.path import join

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import os
import cv2

# from matplotlib import pyplot as plt
import shutil
import numpy as np
# from matplotlib.pyplot import MultipleLocator

def f_score(inputs, target, beta=1, smooth = 1e-5, threhold = 0.5):
    n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht and w != wt:
        inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
        
    temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
    temp_target = target.view(n, -1, ct)

    #--------------------------------------------#
    #   计算dice系数
    #--------------------------------------------#
    temp_inputs = torch.gt(temp_inputs, threhold).float()
    tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
    fp = torch.sum(temp_inputs                       , axis=[0,1]) - tp
    fn = torch.sum(temp_target[...,:-1]              , axis=[0,1]) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    score = torch.mean(score)
    return score

# 设标签宽W,长H
def fast_hist(a, b, n):
    #--------------------------------------------------------------------------------#
    #   a是转化成一维数组的标签,形状(H×W,);b是转化成一维数组的预测结果,形状(H×W,)
    #--------------------------------------------------------------------------------#
    k = (a >= 0) & (a < n)
    #--------------------------------------------------------------------------------#
    #   np.bincount计算了从0到n**2-1这n**2个数中每个数出现的次数,返回值形状(n, n)
    #   返回中,写对角线上的为分类正确的像素点
    #--------------------------------------------------------------------------------#
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)  

def per_class_iu(hist):
    return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) 

def per_class_PA(hist):
    return np.diag(hist) / np.maximum(hist.sum(1), 1) 

def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes,weight_name):  
    # print('Num classes', num_classes)  
    #-----------------------------------------#
    #   创建一个全是0的矩阵,是一个混淆矩阵
    #-----------------------------------------#
    hist = np.zeros((num_classes, num_classes))
    
    #------------------------------------------------#
    #   获得验证集标签路径列表,方便直接读取
    #   获得验证集图像分割结果路径列表,方便直接读取
    #------------------------------------------------#
    gt_imgs     = [join(gt_dir, x + ".png") for x in png_name_list]  
    pred_imgs   = [join(pred_dir, x + ".png") for x in png_name_list]  
    # building_iou=[]
    # background_iou=[]
    m_iou=[]
    # building_pa=[]
    # background_pa=[]
    m_pa=[]

    #------------------------------------------------#
    #   读取每一个(图片-标签)对
    #------------------------------------------------#
    for ind in range(len(gt_imgs)): 
        #------------------------------------------------#
        #   读取一张图像分割结果,转化成numpy数组
        #------------------------------------------------#
        pred = np.array(Image.open(pred_imgs[ind]))
        
        #------------------------------------------------#
        #   读取一张对应的标签,转化成numpy数组
        #------------------------------------------------#
        label = np.array(Image.open(gt_imgs[ind]))  
        
        # 如果图像分割结果与标签的大小不一样,这张图片就不计算
        if len(label.flatten()) != len(pred.flatten()):  
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label.flatten()), len(pred.flatten()), gt_imgs[ind],
                    pred_imgs[ind]))
            continue

        #------------------------------------------------#
        #   对一张图片计算21×21的hist矩阵,并累加
        #------------------------------------------------#
        a=label.flatten()
        a//=254
       
        b=pred.flatten()
        b//=254
        hist += fast_hist(a, b,num_classes)  
        # # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值
        # mIoUs   = per_class_iu(hist)
        # mPA     = per_class_PA(hist)
        # m_iou.append(100 * np.nanmean(mIoUs[1]))
        # m_pa.append(100 * np.nanmean(mPA[1]))
        # # if ind > 0 and ind % 10 == 0:  
        # #     print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs),
        # #                                             100 * np.nanmean(mIoUs[1]),
        # #                                             100 * np.nanmean(mPA[1])))
    mIoUs   = per_class_iu(hist)
    mPA     = per_class_PA(hist)
    print(mIoUs)

    # plt.figure()
    # x=np.arange(len(m_iou))
    # plt.plot(x,m_iou)
    # plt.plot(x,m_pa)
    # plt.grid(True)
    # y_major_locator=MultipleLocator(10)#把y轴的刻度间隔设置为10,并存在变量里
    # ax = plt.gca()
    # ax.yaxis.set_major_locator(y_major_locator)
    # ax.set_ylim(0,100)
    # plt.xlabel('Order')
    # plt.ylabel('mIOU & mPA')
    # plt.legend(['mIOU','mPA'],loc="upper right")

    # targ=os.path.join(pred_dir,os.path.pardir)
    

    # plt.savefig(os.path.join(targ, weight_name[:-3]+"_sin_miou.png"))

    return m_iou,m_pa,str(round(mIoUs[1] * 100, 2)),str(round(mPA[1] * 100, 2))

调试

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/37443.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringCloud

SpringCloud01 为什么要学习微服务框架知识&#xff1f; 因为互联网发展迅速&#xff0c;业务更新迭代快 微服务符合敏捷开发需求 服 务 网 关&#xff08;请求路由&#xff0c;负载均衡&#xff09; 注册中心&#xff08;拉取或注册服务信息 eureka nacos&#xff09; 配…

tcp转发服务桥(windows)

目的 目的是为了在网关上转发udp数据和tcp数据。对于网络里面隔离的内网来说&#xff0c;有一台可以上网的服务器&#xff0c;那么通过两块网卡就可以转发出去&#xff0c;在服务器上进行数据的转发&#xff0c;有tcp和udp两种&#xff0c;udp已经写过了&#xff0c;这次使用了…

pycharm import的类库修改后要重启问题的解决方法

通过将以下行添加到pycharm中的settings-> Build,Excecution,Deployment-> Console-> Python Console中&#xff0c;可以指示Pycharm在更改时自动重新加载模块&#xff1a; %load_ext autoreload %autoreload 2

APP开发的未来:虚拟现实和增强现实的角色

移动应用程序越来越多地在我们的日常生活中发挥着重要作用。但是&#xff0c;随着技术的不断发展&#xff0c;未来的 APP开发会有什么新的发展方向呢&#xff1f;这是每个人都在关心的问题。在过去的几年中&#xff0c;移动应用程序领域发生了巨大变化。像 VR/AR这样的技术为人…

OpenCv (C++) 使用矩形 Rect 覆盖图像中某个区域

文章目录 1. 使用矩形将图像中某个区域置为黑色2. cv::Rect 类介绍 1. 使用矩形将图像中某个区域置为黑色 推荐参考博客&#xff1a;OpenCV实现将任意形状ROI区域置黑&#xff08;多边形区域置黑&#xff09; 比较常用的是使用 Rect 矩形实现该功能&#xff0c;代码如下&…

大模型与端到端会成为城市自动驾驶新范式吗?

摘要&#xff1a; 最近可以明显看到或者感受到第一梯队的城市自动驾驶量产已经进入快车道&#xff0c;他们背后所依靠的正是当下最热的大模型和端到端的技术。 近期&#xff0c;城市自动驾驶量产在产品和技术上都出现了新的变化。 在产品层面&#xff0c;出现了记性行车或者称…

MySQL基础篇第7章(单行函数)

文章目录 1、函数的理解1.1 什么是函数1.2 不同DBMS函数的差异1.3 MySQL的内置函数分类 2、数值函数2.1 基本函数2.2 角度与弧度互转函数2.3 三角函数2.4 指数和对数2.5 进制间的转换 3、字符串函数4、日期和时间函数4.1 获取日期、时间4.2 日期与时间戳的转换4.3 获取月份、星…

自制游戏引擎之shader预编译

shader预编译为二进制,在程序运行时候加载,可以提升性能,节省启动时间. 1. 采用google shaderc预编译与加载shader 1.1 下载代码 https://github.com/google/shaderc third_party文件里需要放依赖的第三方 因为电脑访问google的问题,无法通过shaderc-2023.4\utils\git-sync-de…

Java设计模式之行为型-状态模式(UML类图+案例分析)

目录 一、基础概念 二、UML类图 三、角色设计 四、案例分析 五、总结 一、基础概念 状态模式允许一个对象在其内部状态改变时改变它的行为&#xff0c;对象看起来似乎修改了它的类&#xff0c;状态模式主要解决的是当控制一个对象状态转换的条件表达式过于复杂时的情况&a…

运输层(TCP运输协议相关)

运输层 1. 运输层概述2. 端口号3. 运输层复用和分用4. 应用层常见协议使用的运输层熟知端口号5. TCP协议对比UDP协议6. TCP的流量控制7. TCP的拥塞控制7.1 慢开始算法、拥塞避免算法7.2 快重传算法7.3 快恢复算法 8. TCP超时重传时间的选择8.1 超时重传时间计算 9. TCP可靠传输…

计算机视觉 - 理论 - 从卷积到识别

计算机视觉 - 理论入门 前言一&#xff0c;导论&#xff1a;二&#xff0c;卷积&#xff1a;图像去噪&#xff1a;常值卷积&#xff1a;高斯卷积&#xff1a;椒盐去噪&#xff1a;锐化程度&#xff1a; 三&#xff0c;边缘检测&#xff1a;图像信号导数&#xff1a;求导算子:图…

Java分布式项目常用技术栈简介

Spring-Cloud-Gateway : 微服务之前架设的网关服务&#xff0c;实现服务注册中的API请求路由&#xff0c;以及控制流速控制和熔断处理都是常用的架构手段&#xff0c;而这些功能Gateway天然支持 运用Spring Boot快速开发框架&#xff0c;构建项目工程&#xff1b;并结合Spring…

【模式识别目标检测】——基于机器视觉的无人机避障RP-YOLOv3实例

目录 引入 一、YOLOv3模型 1、实时目标检测YOLOv3简介 2、改进的实时目标检测模型 二、数据集建立&结果分析 1、数据集建立 2、模型结果分析 三、无人机避障实现 参考文献&#xff1a; 引入 目前对于障碍物的检测整体分为&#xff1a;激光、红外线、超声波、雷达、…

【算法基础】搜索与图论

DFS 全排列问题 842. 排列数字 - AcWing题库 #include<bits/stdc.h> using namespace std; const int N10; int n; int path[N]; bool st[N]; void dfs(int x) {if(x>n){for(int i1;i<n;i) cout<<path[i]<<" ";cout<<endl;return ;…

[微信小程序] movable-view 可移动视图容器 - 范围问题

movable-view 可移动视图容器 可移动视图容器&#xff0c;在页面中可以拖拽滑动。movable-view必须在 movable-area 组件中&#xff0c;并且必须是直接子节点 <view><movable-area style"width: 750rpx;height: 200rpx;background-color: gainsboro;">&l…

Linux 批量杀掉进程(包含某个关键字)

一、场景说明 现场环境有十多个包含 ”celery” 关键字的进程在运行&#xff0c;每次重启服务&#xff0c;需要将这些进行kill掉&#xff0c;然后重新启动。 可以用如下命令批量kill掉这些进程&#xff1a; kill -9 PID1 PID2 PID3 PID4.....其中&#xff0c;PID是查询到的进…

(论文精读)PRUNING FILTER IN FILTER《滤波器中的剪枝滤波器》

论文地址&#xff1a;原文 代码实现 中文翻译 一、精读论文 论文题目 PRUNING FILTER IN FILTER 论文作者 Fanxu Meng 孟繁续 刊物名称 NeurIPS 2020 出版日期 2020 摘要 剪枝已成为现代神经网络压缩和加速的一种非常有效的技术。现有的剪枝方法可分为两大类:滤波器…

科技赋能企业,实现数字化转型

科技是第一生产力&#xff0c;数字技术即科技&#xff0c;可以改变传统的商业模式&#xff0c;为各行各业注入新的活力。 推动企业数字化转型&#xff0c;可是实现行业的效率提升&#xff0c;实现跨界重组&#xff0c;重构产业模式&#xff0c;为产业格局重新赋能&#xff0c;最…

go-zero微服务实战——服务构建

目录介绍 接上一节go-zero微服务实战——基本环境搭建。搭建好了微服务的基本环境&#xff0c;开始构建整个微服务体系了&#xff0c;将其他服务也搭建起来。 order的目录结构&#xff0c;如下 根目录 api服务rpc服务自定义逻辑层logic自定义参数层models自定义工具层util …

使用Yfinance和Plotly分析金融数据

大家好&#xff0c;今天我们用Python分析金融数据&#xff0c;使用Yfinance和Plotly绘制图表&#xff0c;带你了解在Python中使用Plotly制作图表&#xff0c;利用Plotly强大的图表功能来分析和可视化金融数据。 导语 在本文中&#xff0c;我们将深入研究Plotly&#xff0c;从…