YOLOv5 分类模型 预处理 OpenCV实现

YOLOv5 分类模型 预处理 OpenCV实现

flyfish

YOLOv5 分类模型 预处理 PIL 实现
YOLOv5 分类模型 OpenCV和PIL两者实现预处理的差异

YOLOv5 分类模型 数据集加载 1 样本处理
YOLOv5 分类模型 数据集加载 2 切片处理
YOLOv5 分类模型 数据集加载 3 自定义类别

YOLOv5 分类模型的预处理(1) Resize 和 CenterCrop
YOLOv5 分类模型的预处理(2)ToTensor 和 Normalize

YOLOv5 分类模型 Top 1和Top 5 指标说明
YOLOv5 分类模型 Top 1和Top 5 指标实现

判断图像是否是np.ndarray类型和维度

OpenCV读取一张图像时,类型类型就是<class 'numpy.ndarray'>,这里判断图像是否是np.ndarray类型
dim是dimension维度的缩写,shape属性的长度也是它的ndim
灰度图的shape为HW,二个维度
RGB图的shape为HWC,三个维度
在这里插入图片描述

def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})

实现ToTensor和Normalize

def totensor_normalize(img):
    print("preprocess:",img.shape)
    images = (img/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    images = np.ascontiguousarray(images)
    return images

实现Resize

插值可以是以下参数

# 'nearest': cv2.INTER_NEAREST,
# 'bilinear': cv2.INTER_LINEAR,
# 'area': cv2.INTER_AREA,
# 'bicubic': cv2.INTER_CUBIC,
# 'lanczos': cv2.INTER_LANCZOS4
def resize(img, size, interpolation=cv2.INTER_LINEAR):
    r"""Resize the input numpy ndarray to the given size.
    Args:
        img (numpy ndarray): Image to be resized.
        size: like pytroch about size interpretation flyfish.
        interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  
    Returns:
        numpy Image: Resized image.like opencv
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    h, w = img.shape[0], img.shape[1]

    if isinstance(size, int):
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        ow, oh = size[1], size[0]
    output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
    if img.shape[2] == 1:
        return output[:, :, np.newaxis]
    else:
        return output

实现CenterCrop

def crop(img, i, j, h, w):
    """Crop the given Image flyfish.
    Args:
        img (numpy ndarray): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
    Returns:
        numpy ndarray: Cropped image.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))

    return img[i:i + h, j:j + w, :]


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    h, w = img.shape[0:2]
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)

完整

import time
from models.common import DetectMultiBackend
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
import collections
import torch
import numbers


classes_name=['n02086240', 'n02087394', 'n02088364', 'n02089973', 'n02093754', 'n02096294', 'n02099601', 'n02105641', 'n02111889', 'n02115641']

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def totensor_normalize(img):
    print("preprocess:",img.shape)
    images = (img/255-mean)/std
    images = images.transpose((2, 0, 1))# HWC to CHW
    images = np.ascontiguousarray(images)
    return images

def resize(img, size, interpolation=cv2.INTER_LINEAR):
    r"""Resize the input numpy ndarray to the given size.
    Args:
        img (numpy ndarray): Image to be resized.
        size: like pytroch about size interpretation flyfish.
        interpolation (int, optional): Desired interpolation. Default is``cv2.INTER_LINEAR``  
    Returns:
        numpy Image: Resized image.like opencv
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)):
        raise TypeError('Got inappropriate size arg: {}'.format(size))
    h, w = img.shape[0], img.shape[1]

    if isinstance(size, int):
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        ow, oh = size[1], size[0]
    output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation)
    if img.shape[2] == 1:
        return output[:, :, np.newaxis]
    else:
        return output

def crop(img, i, j, h, w):
    """Crop the given Image flyfish.
    Args:
        img (numpy ndarray): Image to be cropped.
        i: Upper pixel coordinate.
        j: Left pixel coordinate.
        h: Height of the cropped image.
        w: Width of the cropped image.
    Returns:
        numpy ndarray: Cropped image.
    """
    if not _is_numpy_image(img):
        raise TypeError('img should be numpy image. Got {}'.format(type(img)))

    return img[i:i + h, j:j + w, :]


def center_crop(img, output_size):
    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    h, w = img.shape[0:2]
    th, tw = output_size
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    return crop(img, i, j, th, tw)

class DatasetFolder:

    def __init__(
        self,
        root: str,

    ) -> None:
        self.root = root

        if classes_name is None or not classes_name:
            classes, class_to_idx = self.find_classes(self.root)
            print("not classes_name")

        else:
            classes = classes_name
            class_to_idx ={cls_name: i for i, cls_name in enumerate(classes)}
            print("is classes_name")

        print("classes:",classes)
        
        print("class_to_idx:",class_to_idx)
        samples = self.make_dataset(self.root, class_to_idx)

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Optional[Dict[str, int]] = None,

    ) -> List[Tuple[str, int]]:

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = self.find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if 1:  # 验证:
                        item = path, class_index
                        instances.append(item)

                        if target_class not in available_classes:
                            available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "

        return instances

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:

        classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int) -> Tuple[Any, Any]:

        path, target = self.samples[index]
        sample = self.loader(path)

        return sample, target

    def __len__(self) -> int:
        return len(self.samples)

    def loader(self, path):
        print("path:", path)
        img = cv2.imread(path)  # BGR HWC
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)#RGB
        print("type:",type(img))
        return img


def time_sync():
    return time.time()



dataset = DatasetFolder(root="/media/flyfish/datasets/imagewoof/val")
weights = "/home/classes.pt"
device = "cpu"
model = DetectMultiBackend(weights, device=device, dnn=False, fp16=False)
model.eval()



def classify_transforms(img):
    img=resize(img,224)
    img=center_crop(img,224)
    img=totensor_normalize(img)
    return img;

pred, targets, loss, dt = [], [], 0, [0.0, 0.0, 0.0]
# current batch size =1
for i, (images, labels) in enumerate(dataset):
    print("i:", i)
    print(images.shape, labels)
    im = classify_transforms(images)


    images=torch.from_numpy(im).to(torch.float32) # numpy to tensor
    images = images.unsqueeze(0).to("cpu")
 
    print(images.shape)


        
    t1 = time_sync()
    images = images.to(device, non_blocking=True)
    t2 = time_sync()
    # dt[0] += t2 - t1

    y = model(images)
    y=y.numpy()
   
    print("y:", y)
    t3 = time_sync()
    # dt[1] += t3 - t2

    tmp1=y.argsort()[:,::-1][:, :5]
   
    print("tmp1:", tmp1)
    pred.append(tmp1)

    print("labels:", labels)

    
    targets.append(labels)

    print("for pred:", pred)  # list
    print("for targets:", targets)  # list

    # dt[2] += time_sync() - t3


pred, targets = np.concatenate(pred), np.array(targets)
print("pred:", pred)
print("pred:", pred.shape)
print("targets:", targets)
print("targets:", targets.shape)
correct = ((targets[:, None] == pred)).astype(np.float32)
print("correct:", correct.shape)
print("correct:", correct)
acc = np.stack((correct[:, 0], correct.max(1)), axis=1)  # (top1, top5) accuracy
print("acc:", acc.shape)
print("acc:", acc)
top = acc.mean(0)
print("top1:", top[0])
print("top5:", top[1])

结果

pred: [[0 3 6 2 1]
 [0 7 2 9 3]
 [0 5 6 2 9]
 ...
 [9 8 7 6 1]
 [9 3 6 7 0]
 [9 5 0 2 7]]
pred: (3929, 5)
targets: [0 0 0 ... 9 9 9]
targets: (3929,)
correct: (3929, 5)
correct: [[          1           0           0           0           0]
 [          1           0           0           0           0]
 [          1           0           0           0           0]
 ...
 [          1           0           0           0           0]
 [          1           0           0           0           0]
 [          1           0           0           0           0]]
acc: (3929, 2)
acc: [[          1           1]
 [          1           1]
 [          1           1]
 ...
 [          1           1]
 [          1           1]
 [          1           1]]
top1: 0.86230594
top5: 0.98167473

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/183722.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

逻辑回归

目录 第1关&#xff1a;逻辑回归核心思想 相关知识 什么是逻辑回归 编程要求 代码文件 第2关&#xff1a;逻辑回归的损失函数 相关知识 为什么需要损失函数 逻辑回归的损失函数 题目答案 第3关&#xff1a;梯度下降 相关知识 什么是梯度 梯度下降算法原理 编程要…

电线电缆、漆包线工厂开源MES/生产管理系统/云MES

万界星空科技专业的漆包线MES系统功能介绍&#xff1a; 从原材料出入库-拉丝机等设备管理-漆包线称重打印系统自动入库&#xff08;支持多台秤同时称重&#xff09;-建立销售报价、销售订单-生产订单-支持扫码出库及自动拣货出库-应收应付账款-对接各种其他系统及财务系统。 …

(免费领源码)java#springboot#mysql流浪动物救助系统78174-计算机毕业设计项目选题推荐

摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐渐被人们所认识&#xff0c;科学化的管理&#xff0c;使信息存…

NX二次开发UF_CURVE_add_string_to_ocf_data 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_add_string_to_ocf_data Defined in: uf_curve.h int UF_CURVE_add_string_to_ocf_data(tag_t string_tag, int offset_direction, int num_offsets, UF_CURVE_ocf_values_…

芯片的测试方法

半导体的生产流程包括晶圆制造和封装测试&#xff0c;在这两个环节中分别需要完成晶圆检测(CP, Circuit Probing)和成品测试(FT, Final Test)。无论哪个环节&#xff0c;要测试芯片的各项功能指标均须完成两个步骤&#xff1a;一是将芯片的引脚与测试机的功能模块连接起来&…

一些好用的前端小插件(转自知乎)

一些好用的前端小插件&#xff08;2&#xff09; 1. cropper.js Cropper.js 2.0 是一系列用于图像裁剪的 Web 组件。 官网地址&#xff1a;https://fengyuanchen.github.io/cropperjs/v2/zh/ 2. Vditor Vditor是一款浏览器端的 Markdown 编辑器&#xff0c;支持所见即所得、…

点击按钮,按钮的文字变为倒计时,的小技巧(适用于获取验证码)

看效果图&#xff1a; 代码 <a-buttonclick"getSms":disabled"myState.smsSendFlag"v-text"(!myState.smsSendFlag && 获取验证码) || ${myState.time} s" ></a-button>data(){return {myState: {smsSendFlag: false,tim…

【【Linux 常用命令学习 之 一 】】

Linux 常用命令学习 之 一 打开终端之后的 我们会了解 所使用的 字符串含义 其中前面的 zhuxushuai 是 当前的用户名字 接下来的 zhuxushuai-virtual-machine 是 机器名字 最后的符号 $表示 当前是普通用户 输入指令 ls 是打印出当前所在目录中所有文件和文件夹 shell 操…

BUUCTF [WUSTCTF2020]find_me 1

BUUCTF:https://buuoj.cn/challenges 题目描述&#xff1a; 得到的 flag 请包上 flag{} 提交。 感谢 Iven Huang 师傅供题。 比赛平台&#xff1a;https://ctfgame.w-ais.cn/ 密文&#xff1a; 下载附件&#xff0c;得到一个.jpg图片。 解题思路&#xff1a; 1、得到一张图…

vue3-组件传参及计算属性

​&#x1f308;个人主页&#xff1a;前端青山 &#x1f525;系列专栏&#xff1a;Vue篇 &#x1f516;人终将被年少不可得之物困其一生 依旧青山,本期给大家带来vue篇专栏内容:vue3-组件传参及计算属性 目录 vue3中的组件传参 1、父传子 2、子传父 toRef 与 toRefs vue3中…

【LeetCode刷题笔记】DFSBFS(三)

图的基础知识 邻接矩阵是一个二维表,其中横纵坐标交叉的格子值为 1 的表示这两个顶点是连通的,否则是不连通的。

反爬虫机制与反爬虫技术(二)

反爬虫机制与反爬虫技术二 1、动态页面处理与验证码识别概述2、反爬虫案例:页面登录与滑块验证码处理2.1、用例简介2.2、库(模块)简介2.3、网页分析2.4、Selenium准备操作2.5、页面登录2.6、模糊移动滑块测试3、滑块验证码处理:精确移动滑块3.1、精确移动滑块的原理3.2、滑…

NX二次开发UF_CSYS_ask_wcs 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CSYS_ask_wcs Defined in: uf_csys.h int UF_CSYS_ask_wcs(tag_t * wcs_id ) overview 概述 Gets the object identifier of the coordinate system to which the work coordin…

Matlab三角剖分插值问题分析

目录 前言 一、问题引入 二、一个例子 1.生成散点图 2.对数据进行剖分 3.点法式分析 三、最后结果 前言 上一篇文章感觉对三角剖分问题没有说清楚&#xff0c;这次专门对三角剖分问题再仔细说说。 一、问题引入 实际上这个问题是用来解决二维曲面插值问题的。 二维插值问题&…

python tkinter使用(五)

python tkinter使用(五) 本篇文章讲述tkinter 中treeview的使用 Treeview是一个多列列表框&#xff0c;可以显示层次数据。 #!/usr/bin/python3 # -*- coding: UTF-8 -*- """Author: zhTime 2023/11/23 下午8:28 .Email:Describe: treeview 使用 "&quo…

YB4051系列设备是高度集成的 Li-lon 和 Li-Pol 线性充电器,针对便携式应用的小容量电池。

YB4051H 300mA 单电池锂离子电池充电器0.1 mA 终端&#xff0c;45nA 电池漏电流 概述&#xff1a; YB4051系列设备是高度集成的 Li-lon 和 Li-Pol 线性充电器&#xff0c;针对便携式应用的小容量电池。它是一个完整的恒流/恒压线性充电器。不需要外部感应电阻&#xff0c;由于…

〖大前端 - 基础入门三大核心之JS篇㊷〗- DOM事件对象及它的属性

说明&#xff1a;该文属于 大前端全栈架构白宝书专栏&#xff0c;目前阶段免费&#xff0c;如需要项目实战或者是体系化资源&#xff0c;文末名片加V&#xff01;作者&#xff1a;不渴望力量的哈士奇(哈哥)&#xff0c;十余年工作经验, 从事过全栈研发、产品经理等工作&#xf…

vulnhub6

靶机地址&#xff1a;https://download.vulnhub.com/evilbox/EvilBox---One.ova 准备工作 可以先安装 kali 的字典: sudo apt install seclists ​ 或者直接输入 seclists​&#xff0c;系统会问你是否安装&#xff0c;输入 y 即可自动安装 733 x 3751414 x 723 ​ 默认路…

opencv 常用操作指南

1.通道交换 读取图像&#xff0c;然后将RGB通道替换成BGR通道&#xff0c;需要注意的是&#xff0c;opencv读取的图像默认是BGR。cv2.cvtColor函数可以参考Color Space Conversions img cv2.imread(imori.jpg) img cv2.cvtColor(img, cv2.COLOR_BGR2RGB) cv2.imwrite(answe…

HarmonyOS(五)—— 认识页面和自定义组件生命周期

前言 在前面我们通过如何创建自定义组件一文知道了如何如何自定义组件以及自定义组件的相关注意事项&#xff0c;接下来我们认识一下页面和自定义组件生命周期。 自定义组件和页面的关系 在开始之前&#xff0c;我们先明确自定义组件和页面的关系 自定义组件&#xff1a;Co…