【代码整理】基于COCO格式的pytorch Dataset类实现

import模块

import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt

基于albumentations库自定义数据预处理/数据增强

class Transform():
    '''数据预处理/数据增强(基于albumentations库)
    '''
    def __init__(self, imgSize):
        maxSize = max(imgSize[0], imgSize[1])
        # 训练时增强
        self.trainTF = A.Compose([
                A.BBoxSafeRandomCrop(p=0.5),
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                A.HorizontalFlip(p=0.5),
                # 参数:随机色调、饱和度、值变化
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
                # 随机明亮对比度
                A.RandomBrightnessContrast(p=0.2),   
                # 高斯噪声
                A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),     
                A.OneOf([
                    # 使用随机大小的内核将运动模糊应用于输入图像
                    A.MotionBlur(p=0.2),   
                    # 中值滤波
                    A.MedianBlur(blur_limit=3, p=0.1),    
                    # 使用随机大小的内核模糊输入图像
                    A.Blur(blur_limit=3, p=0.1),  
                ], p=0.2),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )
        # 验证时增强
        self.validTF = A.Compose([
                # 最长边限制为imgSize
                A.LongestMaxSize(max_size=maxSize),
                # 较短的边做padding
                A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
            bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
            )

自定义数据集读取类COCODataset实现


class COCODataset(Dataset):

    def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):
        '''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径

        Args:
            :param annPath:     COCO annotation 文件路径
            :param imgDir:      图像的根目录
            :param inputShape: 网络要求输入的图像尺寸
            :param trainMode:   训练集/测试集

        Returns:
            FRCNNDataset
        '''      
        self.mode = trainMode
        self.tf = Transform(imgSize=inputShape)
        self.imgDir = imgDir
        self.annPath = annPath
        self.DataNums = len(os.listdir(imgDir))
        # 为实例注释初始化COCO的API
        self.coco=COCO(annPath)
        # 获取数据集中所有图像对应的imgId
        self.imgIds = list(self.coco.imgs.keys())

    def __len__(self):
        '''重载data.Dataset父类方法, 返回数据集大小
        '''
        return len(self.imgIds)

    def __getitem__(self, index):
        '''重载data.Dataset父类方法, 获取数据集中数据内容
           这里通过pycocotools来读取图像和标签
        '''   
        # 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
        imgId = self.imgIds[index]
        imgInfo = self.coco.loadImgs(imgId)[0]
        # 载入图像 (通过imgInfo获取图像名,得到图像路径)               
        image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
        image = np.array(image.convert('RGB'))
        # 得到图像里包含的BBox的所有id
        imgAnnIds = self.coco.getAnnIds(imgIds=imgId)   
        # 通过BBox的id找到对应的BBox信息
        anns = self.coco.loadAnns(imgAnnIds) 
        # 获取BBox的坐标和类别
        labels, boxes = [], []
        for ann in anns:
            labelName = ann['category_id']
            labels.append(labelName)
            boxes.append(ann['bbox'])
        labels = np.array(labels)
        boxes = np.array(boxes)
        
        # 训练/验证时的数据增强各不相同
        if(self.mode):
            # albumentation的图像维度得是[W,H,C]
            transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
        else:
            transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
        # 这里的box是coco格式(xywh)
        image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
        return image.transpose(2,0,1), np.array(box), np.array(label)

其他

# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
    images = []
    bboxes = []
    labels = []
    for img, box, label in batch:
        images.append(img)
        bboxes.append(box)
        labels.append(label)
    images = torch.from_numpy(np.array(images))
    return images, bboxes, labels



# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
    worker_seed = worker_id + seed
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    torch.manual_seed(worker_seed)


# 固定全局随机数种子
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

batch数据集可视化

def visBatch(dataLoader:DataLoader):
    '''可视化训练集一个batch
    Args:
        dataLoader: torch的data.DataLoader
    Retuens:
        None     
    '''
    catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',
               7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',
               13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',
               19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',
               27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',
               35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',
               40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',
               44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',
               52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',
               58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',
               64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',
               74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',
               79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',
               86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}
    
    for step, batch in enumerate(dataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # 只可视化一个batch的图像:
        if step > 0: break
        # 图像均值
        mean = np.array([0.485, 0.456, 0.406]) 
        # 标准差
        std = np.array([[0.229, 0.224, 0.225]]) 
        plt.figure(figsize = (8,8))
        for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
            img, box, label = imgBoxLabel
            ax = plt.subplot(4,4,idx+1)
            img = img.numpy().transpose((1,2,0))
            # 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
            img = img * std + mean
            for instBox, instLabel in zip(box, label):
                x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
                # 显示框
                ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))
                # 显示类别
                ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
            plt.imshow(img)
            # 在图像上方展示对应的标签
            # 取消坐标轴
            plt.axis("off")
             # 微调行间距
            plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
        plt.show()

example

# for test only:
if __name__ == "__main__":
    # 固定随机种子
    seed = 23
    seed_everything(seed)
    # BatcchSize
    BS = 16
    # 图像尺寸
    imgSize = [800, 800]

    trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"
    testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"
    imgDir =  "E:/datasets/Universal/COCO2017/train2017"
    # 自定义数据集读取类
    trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)
    trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,
                                    collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
    # validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
    # validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True, 
                                  # collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))



    print(f'训练集大小 : {trainDataset.__len__()}')
    visBatch(trainDataLoader)
    for step, batch in enumerate(trainDataLoader):
        images, boxes, labels = batch[0], batch[1], batch[2]
        # torch.Size([bs, 3, 800, 800])
        print(f'images.shape : {images.shape}')   
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
        print(f'len(boxes) : {len(boxes)}')     
        # 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一  
        print(f'len(labels) : {len(labels)}')     
        break

输出

在这里插入图片描述

images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/336997.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习300问】14、什么是特征工程?

当我学习到这个知识点的时候十分困惑,因为从名字中我完全无法理解这个什么东西。于是呢我就去问了一下维基百科,下面是他的回答: 特征工程(英语:feature engineering)又称特征提取(英语&#xf…

Jetbrains Writerside 使用教程

系列文章目录 前言 一、入门 Writerside 是基于 IntelliJ 平台的 JetBrains 集成开发环境。使用它可以编写、构建、测试和发布技术文档。 如果你想将 Writerside 作为另一个 JetBrains IDE 的插件,请参阅 Writerside 作为插件。 1.1 安装 Writerside…

cetos7搭建部署k8s 版本1.28

主机分配 内存最少是4G cpu个数最少两个 IP内存CPU主机名192.168.231.12044K1 192.168.231.12144K2192.168.231.12244K3 关闭防火墙 systemctl stop firewalled 关闭swap vim /etc/fstab 设置主机名称 hostnameset 安装docker 三个主机 初始化集群 在mas…

相关系数(皮尔逊相关系数和斯皮尔曼相关系数)

本文借鉴了数学建模清风老师的课件与思路,可以点击查看链接查看清风老师视频讲解:5.1 对数据进行描述性统计以及皮尔逊相关系数的计算方法_哔哩哔哩_bilibili 注:直接先看 ( 三、两个相关系数系数的比较 ) 部分&#x…

VC++中使用OpenCV进行颜色检测

VC中使用OpenCV进行颜色检测 在VC中使用OpenCV进行颜色检测非常简单,首选读取一张彩色图像,并调用函数cvtColor(img, imgHSV, COLOR_BGR2HSV);函数将原图img转换成HSV图像imgHSV,再设置好HSV三个分量的上限和下限值,调用inRange函…

自动化测试:5分钟了解Selenium以及如何提升自动化测试的效果

在快节奏的技术世界里,自动化测试已经成为确保 Web 应用程序质量和性能的重要手段。自动化测试不仅加快了测试过程,还提高了测试的重复性和准确性。Selenium,作为领先的自动化测试工具之一,为测试人员提供了强大的功能来模拟用户在…

C++-类和对象(3)

1. 再谈构造函数 1.1 构造函数体赋值 我们在创建一个对象时,编译器会调用该对象的构造函数对该对象的成员进行初始化。 class Date { public:Date(int year, int month, int day){_year year;_month month;_day day;} private:int _year;int _month;int _day…

Linux系统安装Samba服务器

在实际开发中,我们经常会有跨系统之间文件传递的需求,Samba 便是能够在 Windows 和 Linux 之间传递文件的服务,功能也是非常强大和好用,本篇文章将介绍如何在 Linux 系统上安装 Samba 服务,以 CentOS7 系统为例。 一、…

SpringBoot:详解Bean生命周期和作用域

🏡浩泽学编程:个人主页 🔥 推荐专栏:《深入浅出SpringBoot》《java项目分享》 《RabbitMQ》《Spring》《SpringMVC》 🛸学无止境,不骄不躁,知行合一 文章目录 前言一、生命周期二…

大数据平台的硬件规划、网络调优、架构设计、节点规划

1.大数据平台硬件选型 要对Hadoop大数据平台进行硬件选型,首先需要了解Hadoop的运行架构以及每个角色的功能。在一个典型的Hadoop架构中,通常有5个角色,分别是NameNode、Standby NameNode、ResourceManager、NodeManager、DataNode以及外围机。 其中 NameNode 负责协调集群…

手把手教你购买阿里云服务器以及Ubuntu环境下宝塔搭建网站

阿里云服务器Ubuntu通过宝塔搭建网站详细教程 前言一、阿里云服务器的购买二、进入控制面板2.1 修改密码2.2 开放端口号 三、 测试服务器是否可以连接四、 安装nginx搭建网站(选做)五、安装宝塔5.1 登录宝塔官网5.2 卸载预装的mysql和nginx5.3 安装宝塔5.4 访问宝塔控制台5.5 修…

CSS:backdrop-filter实现毛玻璃的效果

实现效果 实现代码 /* 关键属性 */ background-color: rgba(255, 255, 255, 0.4); backdrop-filter: blur(10px); -webkit-backdrop-filter: blur(10px);完整代码 <style>/* 遮罩层 */.mo-mask {position: fixed;top: 0;bottom: 0;left: 0;right: 0;width: 100%;height…

Hadoop3完全分布式搭建

一、第一台的操作搭建 修改主机名 使用hostnamectl set-hostname 修改当前主机名 关闭防火墙和SELlinux 1&#xff0c;使用 systemctl stop firewalld systemctl disable firewalld 关闭防火墙 2&#xff0c;使用 vim /etc/selinux/config 修改为 SELINUXdisabled 使用N…

考研C语言刷题基础篇之分支循环结构基础(二)

目录 第一题分数求和 第二题&#xff1a;求10 个整数中最大值 第三题&#xff1a;在屏幕上输出9*9乘法口诀表 第四题&#xff1a;写一个代码&#xff1a;打印100~200之间的素数 第五题&#xff1a;求斐波那契数的第N个数 斐波那契数的概念&#xff1a;前两个数相加等于第三…

爬虫进阶之selenium模拟浏览器

爬虫进阶之selenium模拟浏览器 简介环境配置1、建议先安装conda2、创建虚拟环境并安装对应的包3、下载对应的谷歌驱动以及与驱动对应的浏览器 代码setting.py配置scrapy脚本参考中间件middlewares.py 附录&#xff1a;selenium教程 简介 Selenium是一个用于自动化浏览器操作的…

继电器开关电路图大全

继电器是一种电控制器件&#xff0c;是当输入量&#xff08;激励量&#xff09;的变化达到规定要求时&#xff0c;在电气输出电路中使被控量发生预定的阶跃变化的一种电器。它具有控制系统&#xff08;又称输入回路&#xff09;和被控制系统&#xff08;又称输出回路&#xff0…

USB-C接口给显示器带来怎样的变化?

随着科技的不断发展&#xff0c;Type-C接口已经成为现代电子设备中常见的接口标准。它不仅可以提供高速的数据传输&#xff0c;还可以实现快速充电和视频传输等功能。因此&#xff0c;使用Type-C接口的显示器方案也受到了广泛的关注。本文将介绍Type-C接口显示器的优势、应用场…

基于C++11的数据库连接池【C++/数据库/多线程/MySQL】

一、概述 概述&#xff1a;数据库连接池可提前把多个数据库连接建立起来&#xff0c;然后把它放到一个池子里边&#xff0c;就是放到一个容器里边进行维护。这样的话就能够避免数据库连接的频繁的创建和销毁&#xff0c;从而提高程序的效率。线程池其实也是同样的思路&#xf…

二叉树基础oj题目

二叉树基础oj题目及思路总结 前文中&#xff0c;介绍了二叉树的基本概念及基础操作&#xff0c;进一步对于二叉树的递归遍历及子问题的处理思想有了一定的了解。本文将带来几道二叉树经典的oj题目。 目录 二叉树基础oj题目 对称二叉树平衡二叉树二叉树的层序遍历 二叉树基…

基于一次应用卡死问题所做的前端性能评估与优化尝试

问题背景 在上个月&#xff0c;由于客户反馈客户端卡死现象但我们远程却难以复现此现象&#xff0c;于是我们组织了一次现场上门故障排查&#xff0c;并希望基于此次观察与优化&#xff0c;为客户端开发提供一些整体的优化升级。当然&#xff0c;在尝试过程中&#xff0c;也发…