深度学习(17)--DataLoader自定义数据集制作

目录

DataLoader自定义数据集制作

1.从标注文件(txt文件)中读取数据和标签

2.分别把数据和标签存在两个list中

3.设置完整的图像数据路径

4.根据任务整合出一个数据处理类

5.数据预处理

6.使用定义好的类来实例化DataLoader

7.检查数据和标签是否对应

8.使用创建好的DataLoader


DataLoader工作原理:

DataLoader自定义数据集制作

1.从标注文件(txt文件)中读取数据和标签

def load_annotation(ann_file): # 参数为文本文件的路径
    # 创建一个字典结构用于保存数据,key作为图像的名字,value作为图像的标签
    data_infos = {}
    with open(ann_file) as f:
        # strip()去除一些换行符等
        # split(' ')是以空格为分隔符
        # samples是一个list,格式为图像名字,图像标签
        # eg:[['image11.jpg,'0'],['image22.jpg,'1'],['image33.jpg,'3']]
        samples = [x.strip().split(' ') for x in f.readlines()]
        for filename, gt_label in samples:
            # filename是图像名字--'image11.jpg',gt_label--'0'是标签,加载到字典data_infos中去
            # value值设置为array(gt_label,dtype=int64)类型
            data_infos[filename] = np.array(gt_label, dtype=np.int64)
        # 得到的字典格式:{'image11.jpg':array(0,dtype=int64),'image22.jpg':array(1,dtype=int64)}
    return data_infos

2.分别把数据和标签存在两个list中

img_label = load_annotation('./flower_data/train.txt')
image_name = list(img_label.keys())  # 取keys值
label = list(img_label.values())  # 取labels值

3.设置完整的图像数据路径

data_dir = './flower_data'  # 数据存放的磁盘
train_dir = data_dir + '/train_filelist'  # 训练集数据存放的磁盘
valid_dir = data_dir + '/valid_filelist'  # 验证集数据存放的磁盘
# 给设置好的图像名list设置图像路径
image_path = [os.path.join(train_dir, img) for img in image_name]

'''
os.path.join()函数:连接两个或更多的路径名组件
1.如果各组件名首字母不包含’/’,则函数会自动加上
2.第一个以”/”开头的参数开始拼接,之前的参数全部丢弃,当有多个时,从最后一个开始
3.如果最后一个组件为空,则生成的路径以一个’/’分隔符结尾 
'''

os.path.join()函数:连接两个或更多的路径名组件

  1. 如果各组件名首字母不包含’/’,则函数会自动加上
  2. 第一个以”/”开头的参数开始拼接,之前的参数全部丢弃,当有多个时,从最后一个开始
  3. 如果最后一个组件为空,则生成的路径以一个’/’分隔符结尾  

4.根据任务整合出一个数据处理类

数据处理类中构造函数__init__和数据获取函数__getitem__是必须存在的

class FlowerDataset(Dataset):  # 继承Dataset类
    # 构造函数必须存在
    def __init__(self, root_dir, ann_file, transform=None):
        self.ann_file = ann_file
        self.root_dir = root_dir
        self.img_label = self.load_annotations()  # img_label是一个字典
        self.img = [os.path.join(self.root_dir, img) for img in list(self.img_label.kesy())]
        self.label = [label for label in list(self.img_label.values())]
        self.transform = transform  # 数据需要做的预处理操作

    def __len__(self):
        return len(self.img)

    # 获取图像和标签交给模型,该函数必须存在
    # 不要修改参数,每次调用时会传入随机的idx
    # 一个batch的数据就是由__getitem__函数处理数据传入得到的
    def __getitem__(self, idx):
        image = Image.open(self.img[idx])  # img保存了图像的路径
        label = self.label[idx]
        if self.transform:
            image = self.transform(image)  # 对数据进行预处理操作
        label = torch.from_numpy(np.array(label))  # 转换label的数据类型,由list->numpy->tensor
        return image, label

    def load_annotations(self):
        data_infos = {}
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                data_infos[filename] = np.array(gt_label, dtype=np.int64)
        return data_infos

5.数据预处理

# 创建一个字典结构的数据类型来进行图像预处理操作:key - value
data_transforms = {
    # 对训练集的预处理
    'train': transforms.Compose([
        transforms.Resize([96, 96]),  # 卷积神经网络处理的数据大小必须相同,通过Resize来设置

        # 数据增强
        transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选
        transforms.CenterCrop(64),  # 从中心开始裁剪,将原本96x96大小的图片数据裁剪为64x64大小的图片数据,可以获取更多的参数
        transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率,50%的概率进行水平翻转
        transforms.RandomVerticalFlip(p=0.5),  # 随 机垂直翻转,50%的概率进行竖直翻转

        transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
        transforms.RandomGrayscale(p=0.025),  # 概率转换成灰度率,3通道就是R=G=B(三颜色通道转为单一颜色通道,很少进行此处理)

        # 将数据转为Tensor类型
        transforms.ToTensor(),

        # 标准化
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 设置均值,标准差,分别对应R、G、B三个颜色通道的三个均值和标准差值,(x-μ)/σ
    ]),

    # 对验证集的预处理(不需要进行数据增强)
    'valid': transforms.Compose([transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                 # 均值和标准差数值的设置和训练集的相同(验证集的数据对我们来说是未知的,不能利用其中的数据再计算出相关的均值和标准差)
                                 ]),
}

6.使用定义好的类来实例化DataLoader

# 训练集
train_dataset = FlowerDataset(root_dir=train_dir, ann_file='./flower_data/train.txt', transform=data_transforms['train'])
# 测试集
valid_dataset = FlowerDataset(root_dir=valid_dir, ann_file='./flower_data/valid.txt', transform=data_transforms['valid'])
# 实例化DataLoader(使用封装好的DataLoader包)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=True)

7.检查数据和标签是否对应

# 检查训练集
image1, label1 = next(iter(train_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
sample = image1[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label1[0].numpy()))


# 检查训练集
image2, label2 = next(iter(valid_loader))  # iter表示train_loader进行迭代,next取一个batch的数据
sample = image2[0].squeeze()  # 通过squeeze()压缩一个维度,有时候维度为1x3x64x64,去除这个1
# 此时的sample是3x64x64的结构,而需要图像展示则需要转换结构为64X64X3,同时需要转换为numpy数据结构
sample = sample.permute((1, 2, 0)).numpy()
# 标准化还原 x = (x-μ) / σ -> x = x*σ + μ (预处理中进行了标准化,需要还原)
sample *= [0.229, 0.224, 0.225]
sample += [0.485, 0.456, 0.406]
plt.imshow(sample)
plt.show()
print('Label is: {}'.format(label2[0].numpy()))


'''
plt.imshow():

1.plt.imshow()用于显示图像数据或二维数组(也可以是三维数组,表示RGB图像)。
2.当你有一个二维数组或图像数据时,你可以使用plt.imshow()将其可视化为图像。
3.它将数组中的每个元素的值映射为一个颜色,并将这些颜色排列成图像的形式。
4.plt.imshow()可以接受许多参数,用于控制图像的外观,例如颜色映射(colormap)、插值方法等。

plt.show():

1.plt.show()用于显示所有已创建的图形。
2.在使用Matplotlib绘制图形时,图形被存储在内存中,但不会自动显示在屏幕上。为了在屏幕上显示图形,你需要调用plt.show()函数。
3.通常,在你创建完所有的图形之后,调用plt.show()一次,它会同时显示所有的图形窗口。

'''

plt.imshow()和plt.show()的区别: 

plt.imshow():

  1. plt.imshow()用于显示图像数据或二维数组(也可以是三维数组,表示RGB图像)。
  2. 当你有一个二维数组或图像数据时,你可以使用plt.imshow()将其可视化为图像。
  3. 它将数组中的每个元素的值映射为一个颜色,并将这些颜色排列成图像的形式。
  4. plt.imshow()可以接受许多参数,用于控制图像的外观,例如颜色映射(colormap)、插值方法等。

plt.show():

  1. plt.show()用于显示所有已创建的图形。
  2. 在使用Matplotlib绘制图形时,图形被存储在内存中,但不会自动显示在屏幕上。为了在屏幕上显示图形,你需要调用plt.show()函数。
  3. 通常,在你创建完所有的图形之后,调用plt.show()一次,它会同时显示所有的图形窗口。

8.使用创建好的DataLoader

dataloaders = {'train': train_loader, "valid": valid_loader}
for inputs, labels in dataloaders['train']:
    print("处理训练集")
for inputs, labels in dataloaders['valid']:
    print("处理验证集")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/406727.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【行业会议】优积科技应邀参加住建部模块建筑企业2023年工作座谈会

2023年3月2日,优积建筑科技发展(上海)有限公司(以下简称“优积科技”)应邀参加由住房和城乡建设部科技与产业化发展中心(以下简称“住建部科技与产业化中心”)组织召开的模块建筑企业2023年工作…

OpenCV 4基础篇| OpenCV图像基本操作

目录 1. 图像读取1.1 cv2.imread() 不能读取中文路径和中文名称1.2 cv2.imdecode() 可以读取中文路径和中文名称 2. 图像的显示2.1 openCV显示图像 cv2.imshow()2.2 matplotlib显示图像 plt.imshow() 3. 图像的保存 cv2.imwrite()4. 图像的复制4.1 img.copy()4.2 np.copy()4.3 …

基于java springboot的图书管理系统设计和实现

基于java springboot的图书管理系统设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源码联…

Ansible 简介及部署 基础模块学习 ansible部署rsync 及时监控远程同步

Ansible介绍: Ansible 是一个配置管理系统,当下最流行的批量自动化运维工具之一,它是一款开源的自动化工具,基于Python开发的配置管理和应用部署的工具。 Ansible 是基于模块工作的,它只是提供了一种运行框架&#xff…

【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理

文章目录 一、前言二、实验环境三、PyTorch数据结构1、Tensor(张量)1. 维度(Dimensions)2. 数据类型(Data Types)3. GPU加速(GPU Acceleration) 2、张量的数学运算1. 向量运算2. 矩阵…

书生·浦语大模型实战营第四节课作业

基础作业 fintune过程 这里要注意下。 合并完参数的模型再进行网页部署时,需要用到InternLM源码,教程里面忽略了需要commit版本。通过以下命令转到所需版本,然后就可以看到web_demo.py。 cd InternLM git checkout 3028f07cb79e5b1d7342f4…

Servlet实现图片的上传和显示

本篇文章是在上一篇文章上改进而来 一、图片上传需要引用的jar包 链接:https://pan.baidu.com/s/17FLjlWlNEG5YnS_dl3C8WA 提取码:wbis 二、最后的结果 三、更改数据库增加图片路径字段path 四、前端页面增加图片上传按钮,和上传的复选框 代码 上传…

ChatGPT 4.0 升级指南

1.ChatGPT 是什么? ChatGPT 是由 OpenAI 开发的一种基于人工智能的聊天机器人,它基于强大的语言处理模型 GPT(Generative Pre-trained Transformer)构建。它能够理解人类语言,可以为我们解决实际的问题。 1.模型规模…

vue+node.js美食分享推荐管理系统 io551

,本系统采用了 MySQL数据库的架构,在开始这项工作前,首先要设计好要用到的数据库表。该系统的使用者有二类:管理员和用户,主要功能包括个人信息修改,用户、美食类型、美食信息、订单信息、美食分享、课程大…

Camunda7.18流程引擎启动出现Table ‘camunda_platform_docker.ACT_GE_PROPERTY‘的解决方案

文章目录 1、问题描述2、原因分析3、解决方案3.1、方案一:降低mysql版本3.2、方案二:增加nullCatalogMeansCurrent参数(推荐) 4、总结 1、问题描述 需要在docker中,部署Camunda流程引擎。通过启动脚本camunda-platfor…

【LeetCode-337】打家劫舍III(动态规划)

目录 题目描述 解法1:动态规划 代码实现 题目链接 题目描述 在上次打劫完一条街道之后和一圈房屋后,小偷又发现了一个新的可行窃的地区。这个地区只有一个入口,我们称之为“根”。 除了“根”之外,每栋房子有且只有一个“父“…

JVM内存随着服务器内存的升高而升高问题排查

一、故障描述 公司测试环境和线上环境,都会有:JVM内存随着服务器内存的升高而升高 这种问题 二、排查 1、linux服务器上使用htop查看java项目内存占比,给最大最小推内存300m,但是实际上超出一倍 2、排查方案 a、通过后面的学习…

Games 103 作业四

Games 103 作业四 第四次作业就是流体模拟了,作业中给了若干的实现步骤,以及一些模板代码。 首先第一步,在update函数的开头,加载水面mesh的高度,然后在update的结束时,把计算后的高度更新到mesh中。这个很…

CSDN原力值怎么提升?

文章目录 前言一、原力值怎么看二、提升原力值的方法1.原力值↑2.原力值↓提示!!!禁止在csdn网站内进行违规行为!!! 结束语 前言 在前面一篇文章中,我讲了付费收看的条件,有需要的先把网址收藏起来! https://blog.csdn.net/m0_69481332/arti…

【坑】Spring Boot整合MyBatis,一级缓存失效

一、Spring Boot整合MyBatis,一级缓存失效 1.1、概述 MyBatis一级缓存的作用域是同一个SqlSession,在同一个SqlSession中执行两次相同的查询,第一次执行完毕后,Mybatis会将查询到的数据缓存起来(缓存到内存中&#xf…

【Java面试系列】Nginx

目录 为什么要用Nginx?为什么Nginx性能这么高?Nginx 是如何实现高并发的? Nginx怎么处理请求的?Nginx的工作流程 给 favicon.ico 和 robots.txt 设置过期时间; 这里为 favicon.ico 为 99 天,robots.txt 为 7 天并不记录 404 错误日…

前沿科技速递——YOLOv9

随着YOLO系列的不断迭代更新,前几天,YOLO系列也迎来了第九个大型号的更新!YOLOv9正式推出了!附上原论文链接。 arxiv.org/pdf/2402.13616.pdf 同样是使用MS COCO数据集进行对比比较,通过折线图可看出AP曲线在全方面都…

2024比较赚钱的项目是什么?亲身经历,月入过万!

我是电商珠珠 年后找项目这件事,成为了部分人所焦虑的一点,有的想要兼职,有的在考虑全职。至于做什么还没有一丝头绪。大家都知道短视频很火,于是有直播能力的人就吃上了流量红利,开始做达人带货,拍视频接…

Linux下“一切皆文件”

“Linux下一切皆文件” Linux 下一切皆文件这个说法是指 Linux 系统中的一种设计理念,即将所有设备、资源和进程等抽象为文件或文件夹的形式。这种设计理念的好处在于统一了对待不同类型资源的方式,提供了统一的接口和工具来进行管理和操作。 Linux 下…

Flutter Slider自定义滑块样式 Slider的label标签框常显示

1、自定义Slider滑块样式 Flutter Slider控件的滑块系统样式是一个圆点,thumbShape默认样式是RoundSliderThumbShape,如果想要使用其它的样式就需要自定义一下thumbShape; 例如需要一个上图样式的(圆点半透明圆形边框&#xff09…