深度学习入门:自建数据集完成花鸟二分类任务

自建数据集完成二分类任务(参考文章)

1 图片预处理

1 .1 统一图片格式

找到的图片需要首先做相同尺寸的裁剪,归一化,否则会因为图片大小不同报错

RuntimeError: stack expects each tensor to be equal size,
but got [3, 667, 406] at entry 0 and [3, 600, 400] at entry 1

pytorch的torchvision.transforms模块提供了许多用于图片变换/增强的函数。

1.1.1 把图片不等比例压缩为固定大小
transforms.Resize((600,600)),
1.1.2 裁剪保留核心区

因为主体要识别的图像一般在中心位置,所以使用CenterCrop,这里设置为(400, 400)

transforms.CenterCrop((400,400)),
1.1.3 处理成统一数据类型

这里统一成torch.float64方便神经网络计算,也可以统一成其他比如uint32等类型

transforms.ConvertImageDtype(torch.float64),
1.1.4 归一化进一步缩小图片范围

对于图片来说0~255的范围有点大,并不利于模型梯度计算,我们应该进行归一化。pytorch当中也提供了归一化的函数torchvision.transforms.Normalize(mean,std)

  • 我们可以使用[0.5,0.5,0.5]mean,std来把数据归一化至[-1,1]
  • 也可以手动计算出所有的图片mean,std来归一化至均值为0,标准差为1的正态分布,
  • 一些深度学习代码常常使用mean=[0.485, 0.456, 0.406] ,std=[0.229, 0.224, 0.225]的归一化数据,这是在ImageNet的几百万张图片数据计算得出的结果
  • BN等方法也具有很出色的归一化表现,我们也会使用到

Juliuszh:详解深度学习中的Normalization,BN/LN/WN
Algernon:【基础算法】六问透彻理解BN(Batch Normalization)

我们这里使用简单的[0.5,0.5,0.5]归一化方法,更新cls_dataset,加入transform操作 ,作为图片裁剪的预处理。

transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])

关于transforms的操作大体分为裁剪/翻转和旋转/图像变换/transform自身操作,具体见余霆嵩:PyTorch 学习笔记(三):transforms的二十二个方法,这里不进行详细展开。

1.2 数据增强

当数据集较小时,可以通过对已有图片做数据增强,利用之前提到的transforms中的函数 ,也可以混合使用来根据已有数据创造新数据

        self.data_enhancement = transforms.Compose([
            transforms.RandomHorizontalFlip(p=1),
            transforms.RandomRotation(30)
        ])

2 创建自制数据集

2.1 以Dataset类接口为模版

class cls_dataset(Dataset):
    def __init__(self) -> None:
       # initialization
        
    def __getitem__(self, index):
        # return data,label in set 
    
    def __len__(self):
        # return the length of the dataset

2.2 创建set

2.2.1定义两个空列表data_list和target_list
2.2.2遍历文件夹
2.2.3读取图片对象,对每一个图片对象预处理后,分别将图片对象和对应的标签加入data_list和target_list中
2.2.4将data_list和target_list加入h5df_ile中
import os
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import h5py
from torchvision.io import read_image

train_pic_path = 'test-set'
test_pic_path = 'training-set'

def create_h5_file(file_name):
    all_type = ['flower', 'bird']
    h5df_file = h5py.File(file_name, "w") #file_name指向比如"train.hdf5"这种文件路径,但这句话之前file_name指向路径为空

    #图片统一化处理
    transform = transforms.Compose([
        transforms.Resize((600, 600)),
        transforms.CenterCrop((400, 400)),
        transforms.ConvertImageDtype(torch.float64),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]
    )
    #数据增强

    data_list = []   #建立一个保存图片张量的空列表
    target_list = [] #建立一个保存图片标签的空列表

    #遍历文件夹建立数据集
    '''
    文件夹组成
    | —— train
    |   | —— flower
    |   |   | —— 图片1
    |   | —— bird
    |   | —— | —— 图片2
    | —— test
    |   | —— flower
    |   | —— bird
    '''

    dataset_kind = file_name.split('.')[0]
    #先判断缺失的文件是训练集还是测试集
    if dataset_kind == 'train':
        pic_file_name = train_pic_path
    else:
        pic_file_name = test_pic_path

    #再循环遍历文件夹
    for file_name_dir, _, files in tqdm(os.walk(pic_file_name)):
        target = file_name_dir.split('/')[-1]
        if target in all_type:
            for file in files:
                pic = read_image(os.path.join(file_name_dir, file))  #以张量形式读取图片对象
                pic = transform(pic)    #预处理图片
                pic = np.array(pic).astype(np.float64)
                data_list.append(pic)   #将pic对象添加到列表里
                target_list.append(target.encode()) #将target编码后添加到列表里

    h5df_file.create_dataset("image", data=data_list)
    h5df_file.create_dataset("target", data=target_list)
    h5df_file.close()

class h5py_dataset(Dataset):
    def __init__(self, file_name) -> None:
        super().__init__()
        self.file_name = file_name    #指向文件的路径名
        #如果file_name指向的h5文件不存在,就新建一个
        if not os.path.exists(file_name):
            create_h5_file(file_name)

        
    def __getitem__(self, index):
        with h5py.File(self.file_name, 'r') as f:
            if f['target'][index].decode() == 'bird':   #如果在f文件的target列表中查找到index下标对应的标签是bird
                target = torch.tensor(0)
            else:
                target = torch.tensor(1)
        return f['image'][index], target

    def __len__(self):
        with h5py.File(self.file_name, 'r') as f:
            return len(f['target'])

def h5py_loader():
    train_file = 'train.hdf5'
    test_file = 'test.hdf5'

    train_dataset = h5py_dataset(train_file)
    test_dataset = h5py_dataset(test_file)

    train_data_loader = DataLoader(train_dataset, batch_size=4)
    test_data_loader = DataLoader(test_dataset, batch_size=4)

    return train_data_loader, test_data_loader


2.3 创建loader

实例化set对象后利用torch.utils.data.DataLoader

3 搭建网络

3.1 网络结构

在这里插入图片描述

3.2 参数计算

卷积后,池化后尺寸计算公式:
(图像尺寸-卷积核尺寸 + 2*填充值)/步长+1
(图像尺寸-池化窗尺寸 + 2*填充值)/步长+1

参考文章

3.3 不成文规定

池化参数一般就是(2, 2)

中间的channel数量都是自己设定的,二的次方就行

kernelsize一般3或者5之类的

4 训练

加深对前面数据集组成理解

    for _, data in enumerate(train_loader):
        if isinstance(data, list):
            image = data[0].type(torch.FloatTensor).to(device)
            target = data[1].to(device)
        elif isinstance(data, dict):
            image = data['image'].type(torch.FloatTensor).to(device)
            target = data['target'].to(device)
        else:
            print(type(data))
            raise TypeError

for 循环中data的组成来源于构建set时,

    h5df_file.create_dataset("image", data=data_list)
    h5df_file.create_dataset("target", data=target_list)

写入了h5df文件中两个dataset,但在文件中是以嵌套列表形式保存,其中data[0]等价于引用image这个dataset,data[1]等价于引用target这个集合

在这里插入图片描述

5 测试

6 保存模型

改进

投影概率放到网络里面

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/175902.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何使用环境变量运行bat脚本(开启数据库db)

文章目录 1.拥有一个bat脚本2. 右击本电脑,点击属性,找到高级设置3.新建一个环境变量命名为你想要的名字4. 找到Path,进入新增%m8%即可5.确认所有刚刚的操作,并关闭开始配置后的所有页面6.notice 1.拥有一个bat脚本 内容是执行mys…

【开源】基于Vue.js的康复中心管理系统

项目编号: S 056 ,文末获取源码。 \color{red}{项目编号:S056,文末获取源码。} 项目编号:S056,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 普通用户模块2.2 护工模块2.3 管理员…

精彩预告 | OpenHarmony即将亮相MTSC 2023

MTSC 2023 第12届中国互联网测试开发大会(深圳站)即将于2023年11月25日,在深圳登喜路国际大酒店举办,大会将以“1个主会场4个平行分会场”的形式呈现,聚集一众顶尖技术专家和行业领袖,围绕如今备受关注的行…

2023年中国油墨树脂主要环节、产量及市场规模分析[图]

油墨树脂是指用于油墨制造中的一种高分子材料,主要用于改善油墨的粘性、流动性、光泽度和耐磨性等性能。其主要成分为合成树脂,如聚酯、聚酰胺、聚丙烯酸酯等。油墨树脂在油墨制造中的应用非常广泛,可以用于各种类型的油墨,包括印…

(内部资料)收下这几个人工智能学习秘籍!

秘籍一:练好基本功 学习基础知识:人工智能涉及多个学科领域,包括数学、计算机科学、统计学等。因此,学习基础知识是非常重要的。您可以通过学习线性代数、概率论和微积分等数学基础知识,以及掌握Python编程语言和常用…

Linux学习笔记-Ubuntu下使用Crontab设置定时任务

文章目录 一、概述二、基于crontab的设置2.1 基本命令说明2.2 使用-e指令编辑命令2.2.1 进入编辑模式2.2.2 指令信息格式2.2.4 开启日志1) 修改rsyslog配置文件2) 重启rsyslog3) 查看日志 2.2.3 设置后之后重启服务 三、示例3.1 每隔一分钟往文件中日期3.2 使用-l查看任务列表3…

腾讯云服务器99元一年?假的,阿里云是99元

腾讯云服务器99元一年是真的吗?假的,不用99元,只要88元即可购买一台2核2G3M带宽的轻量应用服务器,99元太多了,88元就够了,腾讯云百科活动 txybk.com/go/txy 活动打开如下图: 腾讯云服务器价格 腾…

【Redis篇】简述Java中操作Redis的方法

文章目录 🎄简述Jedis🎄Jedis优点🍔使用Jedis连接Redis⭐进行测试🎈进行测试 Redis(Remote Dictionary Server)是一种流行的高性能内存数据库,广泛应用于各种应用程序和系统中。作为Java开发人员…

政府采购变数大,联想还值不值得代理渠道商们“跟”?

文|新熔财经 作者|余一 “事业单位更换纯国产电脑”、“联想被排除在大订单之外”等消息下,联想硬件终端产品面临的问题日益严峻。 早在今年年初,联想集团(下称联想)掌门人杨元庆坦承,“智能…

阿里云学生认证可领300元无门槛代金券(高效计划)

阿里云高校计划学生和教师均可参与,完成学生认证和教师验证后学生可以免费领取300元无门槛代金券和3折优惠折扣,适用于云服务器等全量公共云产品,订单原价金额封顶5000元/年,阿里云百科aliyunbaike.com分享阿里云高校计划入口及学…

MySQL数据库入门到大牛_基础_11_数据处理之增删改

本章将会介绍DML中的增删改查操作,增删改泛泛来讲是针对表中数据的修改。 文章目录 1. 插入数据1.1 实际问题1.2 方式1:VALUES的方式添加1.3 方式2:将查询结果插入到表中 2. 更新数据3. 删除数据4. 小结5. MySQL8新特性:计算列6. …

微博头条文章开放接口报错 auth by Null spi

接口文档地址 https://open.weibo.com/wiki/Toutiao/api 接口说明 https://api.weibo.com/proxy/article/publish.json 请求方式 POST 请求参数 参数名称类型是否必需描述titlestring是文章标题,限定32个中英文字符以内contentstring是正文内容,限制9…

项目管理PMP6.0-五大过程组、十大知识领域、四十九个过程(记忆码:7664363734)

项目管理PMP6.0-五大过程组、十大知识领域、四十九个过程(记忆码:7664363734) 项目经理的影响力范围三者关系图(五大过程组、十大知识领域、四十九个过程)五大过程组十大知识领域十大知识领域之间联系 四十九个过程&am…

振南技术干货集:制冷设备大型IoT监测项目研发纪实(1)

注解目录 1.制冷设备的监测迫在眉睫 1.1 冷食的利润贡献 1.2 冷设监测系统的困难 (制冷设备对于便利店为何如何重要?了解一下你所不知道的便利店和新零售行业。关于电力线载波通信的论战。) 2、电路设计 2.1 防护电路 2.1.1 强电防护 …

redis的性能管理和雪崩

redis的性能管理 redis的数据是缓存在内存当中的 系统巡检: 硬件巡检、数据库、nginx、redis、docker、k8s 运维人员必须要关注的redis指标 在日常巡检中需要经常查看这些指标使用情况 info memory #查看redis使用内存的指标 used_memory:11285512 #数据占用的…

Maven中常用命令以及idea中使用maven指南

文章目录 Maven 常用命令compiletestcleanpackageinstallMaven 指令的生命周期maven 的概念模型 idea 开发maven 项目idea 的maven 配置idea 中创建一个maven 的web 工程在pom.xml 文件添加坐标坐标的来源方式依赖范围编写servlet maven 工程运行调试 Maven 常用命令 compile …

AI辅助带货直播场景源码系统 附带网站的搭建教程

互联网技术的发展和普及,直播带货行业迅速崛起。然而,直播带货在带来商机的同时,也面临着诸多挑战。如直播内容缺乏新意、转化率低等问题。针对这些问题,AI辅助带货直播场景源码系统应运而生,旨在利用人工智能技术&…

CSS实现三角形

CSS实现三角形 前言第一种:bordertransparent第二种borderrgb使用unicode字符 前言 本文讲解三种实现三角形的方式,并且配有图文以及代码解说。那么好,本文正式开始。 第一种:bordertransparent border是边框,而transparent是透明的颜色&a…

CentOS使用docker安装OpenGauss数据库

1.搜索OpenGauss docker search opengauss 2.选择其中一个源拉取 docker pull docker.io/enmotech/opengauss 3.运行OpenGauss docker run --name opengauss --privilegedtrue --restartalways -d -e GS_USERNAMEpostgres -e GS_PASSWORDmyGauss2023 -p 5432:5432 docker.…

抖音seo矩阵系统源代码部署及产品功能设计分析

一、引言 随着抖音等短视频平台的崛起,越来越多的企业和个人开始关注如何在这些平台上提升曝光量和用户流量。抖音SEO(搜索引擎优化)是一种有效的方法,通过优化短视频内容和关键词,让更多的人找到并点击你的视频。本文…