自定义数据集 - Dataset

文章目录

    • 1. PASCAL VOC格式 划分训练集和验证集
    • 2. 自定义dataset

1. PASCAL VOC格式 划分训练集和验证集

import os
import random

def main():
    random.seed(0)  # 设置随机种子,保证随机结果可复现

    files_path = "./VOCdevkit/VOC2012/Annotations"  # 指定annotations目录
    assert os.path.exists(files_path), "path: '{}' does not exist.".format(files_path)

    val_rate = 0.5  # 定义划分比例

    files_name = sorted([file.split(".")[0] for file in os.listdir(files_path)])
    files_num = len(files_name)
    val_index = random.sample(range(0, files_num), k=int(files_num*val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        train_f = open("train.txt", "x")
        eval_f = open("val.txt", "x")
        train_f.write("\n".join(train_files))
        eval_f.write("\n".join(val_files))
    except FileExistsError as e:
        print(e)
        exit(1)

if __name__ == '__main__':
    main()

2. 自定义dataset

自定义自己的数据集,需要继承torch.utils.data.Dataset,并且实现__len____getitem__方法。

  • __len__:获取数据集的大小
  • __getitem__:返回数据信息

如果使用多GPU训练,还需要实现get_height_and_width方法,获取图像的高度和宽度,如果不实现这个方法,它就会载入所有的图片,去计算图片的高和宽,这样的话就比较耗时和占内存。

在这里插入图片描述

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree

# 继承torch.utils.data.Dataset
class VOCDataSet(Dataset):
    """读取解析PASCAL VOC2007/2012数据集"""
    """
    voc_root:数据集所在根目录
    year:数据集年份
    transforms:数据预处理方法
    text_name:读取txt文件名称 train.txt/val.txt
    """
    def __init__(self, voc_root, year="2012", transforms=None, txt_name: str = "train.txt"):
        assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
        # 增加容错能力
        if "VOCdevkit" in voc_root:
            self.root = os.path.join(voc_root, f"VOC{year}")
        else:
            self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        # 读取 train.txt 或 val.txt文件
        txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
        assert os.path.exists(txt_path), "not found {} file.".format(txt_name)

        with open(txt_path) as read:
            # strip() 方法用于移除字符串头尾指定的字符
            xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                        for line in read.readlines() if len(line.strip()) > 0]

        self.xml_list = []
        # 确认文件是否存在
        for xml_path in xml_list:
            if os.path.exists(xml_path) is False:
                print(f"Warning: not found '{xml_path}', skip this annotation file.")
                continue

            # 排除图像中没有目标的数据
            with open(xml_path) as fid:
                xml_str = fid.read()
            xml = etree.fromstring(xml_str)
            data = self.parse_xml_to_dict(xml)["annotation"]
            if "object" not in data:
                print(f"INFO: no objects in {xml_path}, skip this annotation file.")
                continue

            self.xml_list.append(xml_path)

        assert len(self.xml_list) > 0, "in '{}' file does not find any information.".format(txt_path)

        # 读取数据集类别信息
        json_file = './pascal_voc_classes.json'
        assert os.path.exists(json_file), "{} file not exist.".format(json_file)
        with open(json_file, 'r') as f:
            self.class_dict = json.load(f)

        self.transforms = transforms

    # 获取数据集大小
    def __len__(self):
        return len(self.xml_list)

    # 根据传入的索引值获取数据信息
    def __getitem__(self, idx):
        # 读取xml文件
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image '{}' format not JPEG".format(img_path))

        boxes = []
        labels = []
        iscrowd = []
        assert "object" in data, "{} lack of object information.".format(xml_path)
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])

            # 进一步检查数据,有的标注信息中可能有w或h为0的情况,这样的数据会导致计算回归loss为nan
            if xmax <= xmin or ymax <= ymin:
                print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
                continue
            
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            if "difficult" in obj:
                iscrowd.append(int(obj["difficult"]))
            else:
                iscrowd.append(0)

        # 将数据转为tensor类型
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    # 获取图像高度和宽度
    def get_height_and_width(self, idx):
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    """
    将xml文件解析成字典形式
    """
    def parse_xml_to_dict(self, xml):
        if len(xml) == 0:  # 遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    def coco_index(self, idx):
        """
        该方法是专门为pycocotools统计标签信息准备,不对图像和标签作任何处理
        由于不用去读取图片,可大幅缩减统计时间

        Args:
            idx: 输入需要获取图像的索引
        """
        # read xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        # img_path = os.path.join(self.img_root, data["filename"])
        # image = Image.open(img_path)
        # if image.format != "JPEG":
        #     raise ValueError("Image format not JPEG")
        boxes = []
        labels = []
        iscrowd = []
        for obj in data["object"]:
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        return (data_height, data_width), target

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))

注意:在进行数据预处理的时候,与进行图像分类数据预处理操作有些不同,如进行图像的翻转时,bbox的坐标信息也应该随之进行改变。

在这里插入图片描述

所以,我们自己定义了一个transform.py

import random
from torchvision.transforms import functional as F

class Compose(object):
    """组合多个transform函数"""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class ToTensor(object):
    """将PIL图像转为Tensor"""
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target

class RandomHorizontalFlip(object):
    """随机水平翻转图像以及bboxes"""
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            height, width = image.shape[-2:]
            image = image.flip(-1)  # 水平翻转图片
            bbox = target["boxes"]
            # bbox: xmin, ymin, xmax, ymax
            bbox[:, [0, 2]] = width - bbox[:, [2, 0]]  # 翻转对应bbox坐标信息
            target["boxes"] = bbox
        return image, target

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/339601.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Sentinel限流规则支持流控效果

流控效果是指请求达到流控阈值时应该采取的措施&#xff0c;包括三种&#xff1a; 1.快速失败&#xff1a;达到阈值后&#xff0c;新的请求会被立即拒绝并抛出FlowException异常。是默认的处理方式。 2.warm up&#xff1a;预热模式&#xff0c;对超出阈值的请求同样是拒绝并抛…

网络安全的信息收集方法有哪些?

网络安全攻击中的信息收集是攻击者为了了解目标系统的弱点、配置、环境和潜在的防御措施而进行的活动。以下是一些常见的信息收集手段&#xff1a; 开放网络资源查询&#xff1a; 使用搜索引擎查找关于目标组织的信息&#xff0c;包括新闻稿、社交媒体帖子、官方网站等。通过W…

答案之书程序

答案之书程序 需求&#xff1a;用户输入手机号码后4位或者生日&#xff0c;自动生成答案之书对应答案 效果图 C#代码实现过程 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq;…

新买电脑配置不低却卡顿?

目录 前言&#xff1a; 电脑卡顿的原因 Windows 10必做的系统优化 禁用 IP Helper 关闭系统通知 机械硬盘开启优化驱动器功能 开启存储感知 前言&#xff1a; 新买的电脑配置不低&#xff0c;但却卡顿甚至程序不反应&#xff0c;这是怎么回事儿&#xff1f; 其实并不…

获取主流电商平台商品价格,库存信息,数据分析,SKU详情

要接入API接口以采集电商平台上的商品数据&#xff0c;可以按照以下步骤进行&#xff1a; 1、找到可用的API接口&#xff1a;首先&#xff0c;需要找到支持查询商品信息的API接口。这些信息通常可以在电商平台的官方文档或开发者门户网站上找到。 2、注册并获取API密钥&#x…

【代码随想录算法训练营第二十四天|回溯算法的理论基础、77. 组合】

代码随想录算法训练营第二十四天|回溯算法的理论基础、77. 组合 回溯算法的理论基础77. 组合 回溯算法的理论基础 这里我觉得《代码随想录》和y总的课都比较好了 《代码随想录》 &#xff1a; https://programmercarl.com/0077.%E7%BB%84%E5%90%88%E4%BC%98%E5%8C%96.html#%E5…

成人高考和自考到底应该选哪个呢?

在成人学历提升的各项方式之中 成人高考与自学考试经常会被人拿来对比 但它们之间的差别在哪里 又分别去适合什么类型的考生 成考自考报名一般8月底开始&#xff0c;要准备考试的考生需要提前做好准备了哦 成考自考报名都需要上传证件照&#xff0c;而且都很严格 大家可使用小程…

react 页签(自行封装)

思路&#xff1a;封装一个页签组件&#xff0c;包裹页面组件&#xff0c;页面渲染之后把数据缓存到全局状态实现页面缓存。 浏览本博客之前先看一下我的博客实现的功能是否满足需求&#xff0c;实现功能&#xff1a; - 页面缓存 - 关闭当前页 - 鼠标右键>关闭当前 - 鼠标右…

启发式教学是什么

学生们在上课时看似认真听讲&#xff0c;但是在下课后却一片茫然&#xff0c;不知道你讲了什么内容&#xff1f;这是因为你可能使用了传统的教学方法&#xff0c;而不是启发式教学。 启发式教学是指老师在教育教学中&#xff0c;采用引导、启示、激发等手段&#xff0c;调动学…

主板电路学习; 华硕ASUS K43SD笔记本安装win7X64(ventoy)

记录 老爷机 白色 华硕 K43SD 笔记本 安装 win7X64 1. MBR样式常规安装win7X64Sp1 (华硕 K43SD 安装 win7X64 ) 老爷机 白色 华硕 K43SD 笔记本 安装 win7X64 &#xff08;常规安装&#xff09; 设置&#xff1a; 禁用UEFI 启用AHCI ventoy制作MBR&#xff08;非UEFI&#…

安全防御-基础认知

目录 安全风险能见度不足&#xff1a; 常见的网络安全术语 &#xff1a; 常见安全风险 网络的基本攻击模式&#xff1a; 病毒分类&#xff1a; 病毒的特征&#xff1a; 常见病毒&#xff1a; 信息安全的五要素&#xff1a; 信息安全的五要素案例 网络空间&#xff1a…

Anything本地知识库问答系统:基于检索增强生成式应用(RAG)两阶段检索、支持海量数据、跨语种问答

QAnything本地知识库问答系统&#xff1a;基于检索增强生成式应用&#xff08;RAG&#xff09;两阶段检索、支持海量数据、跨语种问答 QAnything (Question and Answer based on Anything) 是致力于支持任意格式文件或数据库的本地知识库问答系统&#xff0c;可断网安装使用。…

Vue diff原理

✨ 专栏介绍 在当今Web开发领域中&#xff0c;构建交互性强、可复用且易于维护的用户界面是至关重要的。而Vue.js作为一款现代化且流行的JavaScript框架&#xff0c;正是为了满足这些需求而诞生。它采用了MVVM架构模式&#xff0c;并通过数据驱动和组件化的方式&#xff0c;使…

DAY15--learning English

一、积累 1.loyalty Bro had loyalty on that. 老兄对那个东西情有独钟。 2. consent its illegal to film anyone without consent in many country 。 在一些国家里面&#xff0c;没有经过别人的同意就去拍摄别人是违法的。 3. butt 4.disciplinary Welcome to our discip…

Django(九)

1. 用户登录-Cookie和Session 什么是cookie和session&#xff1f; 发送HTTP请求或者HTTPS请求(无状态&短连接) http://127.0.0.1:8000/admin/list/ https://127.0.0.1:8000/admin/list/http无状态短连接&#xff1a;一次请求响应之后断开连接&#xff0c;再发请求重新连…

CentOS 系统创建网卡bond0

很多时候在机房运维的过程中&#xff0c;我们会遇到客户要求的建立网卡光口的bond0设置&#xff0c;通俗点说就是将两个光口合并为一个口进行链接设置。创建这个设置是有两种设置&#xff0c;一是在安装系统的过程中对bond0进行创建设置&#xff0c;另一种就是通过系统里面对网…

C波段数据链和DAA技术实现BVLOS超视距飞行-乔克托族BEYOND计划获得FAA扩大批准

俄克拉荷马州的乔克托族(CNO)超越计划宣布&#xff0c;已从联邦航空管理局(FAA)获得了扩大的超视距(BVLOS)操作豁免的批准。这扩展了原有的豁免范围&#xff0c;覆盖约43英里长的区域&#xff0c;包括CNO医疗诊所、CNO新兴航空技术中心测试场以及其他设施。这是目前美国境内同类…

文件上传笔记整理

文件上传 web渗透的核心&#xff0c;内网渗透的基础 通过上传webshell文件到对方的服务器来获得对方服务器的控制权 成功条件 文件成功上传到对方的服务器&#xff08;躲过杀软&#xff09; 知道文件上传的具体路径 上传的文件可以执行成功 文件上传的流程 前端JS对上传文件进行…

美易官方《投资抄底5年期美债?》

摩根士丹利和摩根大通近期发布报告&#xff0c;建议投资者抄底5年期美债。这两家知名投行认为&#xff0c;当前5年期美债收益率已经处于历史低位&#xff0c;而经济复苏和通胀预期将使得债券价格上涨。 摩根士丹利预计&#xff0c;美国国债有反弹的空间&#xff0c;预计未来几…

【GitHub项目推荐--不错的 C 开源项目】【转载】

大学时接触的第一门语言就是 C语言&#xff0c;虽然距 C语言创立已过了40多年&#xff0c;但其经典性和可移植性任然是当今众多高级语言中不可忽视的&#xff0c;想要学好其他的高级语言&#xff0c;最好是先从掌握 C语言入手。 今天老逛盘点 GitHub 上不错的 C语言 开源项目&…