将图像增广应用于Mnist数据集

将图像增广应用于Mnist数据集

不用到cifar-10的原因是要下载好久。。我就直接用在Mnist上了,先学会用

首先我们得了解一下图像增广的基本内容,这是我的一张猫图片,以下为先导入需要的包和展示图片

import time
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from PIL import Image
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
d2l.set_figsize()
img = Image.open('cat.png')
d2l.plt.imshow(img)

在这里插入图片描述
之后呢,我们先定义几个函数,以后方便调用,第一个函数show_images,他是用来展示多张图片的

def show_images(imgs, num_rows, num_cols, scale=2):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize = figsize)
    for i in range(num_rows):
        for j in range(num_cols):
            axes[i][j].imshow(imgs[i * num_cols + j])
            axes[i][j].axes.get_xaxis().set_visible(False)
            axes[i][j].axes.get_yaxis().set_visible(False)
    return axes

然后将图像展示函数和图像增广函数结合起来展示,也用一个函数来集成

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale)

接下来,就可以开始我们的图像增广之路啦

左右翻转

torchvision.transforms.RandomHorizontalFlip()这个函数有百分之五十的概率实现左右翻转

apply(img, torchvision.transforms.RandomHorizontalFlip()) # torchvision.transforms.RandomHorizontalFlip() 百分之50的概率左右翻转

在这里插入图片描述

上下翻转

torchvision.transforms.RandomVerticalFlip() 百分之50的概率上下翻转
在这里插入图片描述

随机裁剪

随机裁剪出一块面积为原面积10%100%的区域,且该区域的宽和高之比随机取自0.52,然后将该区域的宽高缩放到200像素

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

在这里插入图片描述
自然,我们也可以变换颜色,有亮度(brightness),对比度(contrast),饱和度(saturation),色调(hue)
我就直接一起写了,也可以只变单个
0.5的意思是比如对于亮度来说,他会在50%的范围内随机选择,即亮度为原来的0.5~1.5

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, hue=0.5, saturation=0.5, contrast=0.5) 
apply(img, color_aug)

在这里插入图片描述
那么当然,我们也可以把上述的那些进行叠加
用到torchvision.transforms.Compose

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug
])
apply(img, augs)

在这里插入图片描述
之后呢,就可以用增广后的图像进行训练啦,这里给大家一个例子用Resnet18进行训练Mnist数据集,Resnet18就不带着大家写了,直接调用别人写好的函数,写网络并不是本节的重点,如果以后有时间或者大家有需要我可以再来写~
(为什么是Mnist数据集,其实他在Mnist数据集上的效果并没有很明显,比较比较简单,最好是在cifar上,但是cifar要下太久了,懒,大家可以在cifar上测一下)
先写两个augs,训练集我就将他随机翻转,测试集就不动了

flip_aug = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()           # 记得转换成tensor 以便训练
])
no_aug = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

然后是load_mnist,就加一个transform就好啦,集成到一个函数里

def load_mnist(is_train, augs, batch_size, root="~/Datasets/MNIST"):
    dataset = torchvision.datasets.MNIST(train=is_train, root=root, transform=augs, download=True)
    return DataLoader(dataset, batch_size = batch_size, shuffle=is_train)

再之后就是模型的训练了,这个大家应该都写腻了,我也不多说什么了,反正就是模型前向传播+反向传播,然后再记录点值

def train(train_iter, test_iter, net, loss, optimizer, device, num_epochs):
    net = net.to(device)
    print("training on ", device)
    batch_count = 0
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l_sum += l.cpu().item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = d2l.evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'
              % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

再最后,就定义一个函数,把前面的都用上啦!

def train_with_data_aug(train_augs, test_augs, lr=0.001):
    batch_size, net = 256, d2l.resnet18(output=10, in_channels=1)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = torch.nn.CrossEntropyLoss()
    train_iter = load_mnist(True, train_augs, batch_size)
    test_iter = load_mnist(False, test_augs, batch_size)
    train(train_iter, test_iter, net, loss, optimizer, device, num_epochs=10)

值得注意的是,这边调用别人的d2l.resnet18,要注意in_channels=1记得写,他默认是3通道的,改成1通道对于我们的mnist,如果你要是cifar-10就不用变了,把in_channel=1给删掉就好~,至此,调用我们的函数就行
在这里插入图片描述

训练还是很快滴

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/226397.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++ STL容器与常用库函数

STL是提高C编写效率的一个利器 STL容器&#xff1a; 一、#include <vector> 英文翻译&#xff1a;vector &#xff1a;向量 vector是变长数组(动态变化)&#xff0c;支持随机访问&#xff0c;不支持在任意位置O(1)插入。为了保证效率&#xff0c;元素的增删一般应该在末尾…

测开速成考点

用户登录界面需要做哪些测试&#xff1f; ⑴功能&#xff1a;  正确的用户名&#xff0c;正确的密码&#xff0c;点击提交按钮&#xff0c;验证能否正确登录 是否跳转正确页面。  错误的用户名 或 密码&#xff0c;点击提交&#xff0c;提示相应的错误信息。  记住用户名…

JavaScript实现手写签名,可触屏手写,支持移动端与PC端双端保存

目录 1.HTML模板 2.获取DOM元素和定义变量 3.创建两个canvas元素&#xff0c;并设置它们的宽度和高度 4.绑定触摸事件&#xff1a;touchstart, touchmove, touchend和click 5.实现触摸事件回调函数&#xff1a;startDrawing, draw和stopDrawing 6.实现绘制线段的函数&…

C语言——输出菱形

法一&#xff1a; #include<stdio.h> #define N 7 //假设输出7层菱形 int main(){int i;//i控制第几行 int j;//j控制每一行空格的循环个数 int k;//k控制每一行*的循环次数 for(i1;i<4;i){//将图形分为两部分,前四行(第一部分) for(j1;j<4-i;j){//输出第i行的…

记录一下本地源码安装部署ThingsBoard可能踩到的坑

使用git下载源码后, 必须运行 mvn clean install -DskipTests这一步很重要, 有很多文件需要初始化, 如果直接放入idea可能存在各种问题, 最好是用命令行执行 初始化时, 可能报错停止, 这个一般是网络问题, 可以尝试修改maven镜像, 这是我成功构建的镜像 <!--阿里云仓库--…

膜结构建筑:未来体育可持续发展的绿色引擎

随着城市化的飞速发展&#xff0c;现代建筑迫切需要创新性的解决方案&#xff0c;而膜结构建筑以其独特的设计理念和可持续性特点&#xff0c;正在成为未来城市发展的重要引擎。本文将深入探讨膜结构建筑在可持续城市发展中的关键作用&#xff0c;包括其在节能减排、资源有效利…

“分割“安卓用户,对标iOS,鸿蒙崛起~

近期关于**“华为于明年推出不兼容安卓的鸿蒙版本”**的消息传出&#xff0c;引起了业界的热议关注。自从2019年8月&#xff0c;美国制裁下&#xff0c;华为不再能够获得谷歌安卓操作系统相关付费服务&#xff0c;如此情况下&#xff0c;华为“备胎”鸿蒙操作系统一夜转正。 华…

Swift 中 User Defaults 的读取和写入

文章目录 前言介绍 User Defaults共享 User DefaultsUser Defaults 存储数据类型响应更改监控 User Defaults 更改覆盖User Defaults 设置考虑的替代方案Keychain 用于安全性用于跨平台的 CloudKit 结论 前言 User Defaults 是 Swift 应用程序存储在应用启动之间保持的首选项的…

音视频学习(二十)——rtsp收流(udp方式)

前言 本文主要介绍通过udp方式实现rtsp拉流。 流程图 流程说明&#xff1a; 相较于tcp方式“信令数据”复用同一连接拉流&#xff0c;udp方式拉流“信令数据”采用不同的连接&#xff0c;信令传输采用tcp&#xff0c;流数据传输采用udp&#xff1b;客户端向服务端&#xff0…

Android平板还能编程?Ubuntu本地安装code-server远程编程写代码

文章目录 1.ubuntu本地安装code-server2. 安装cpolar内网穿透3. 创建隧道映射本地端口4. 安卓平板测试访问5.固定域名公网地址6.结语 1.ubuntu本地安装code-server 准备一台虚拟机,Ubuntu或者centos都可以&#xff0c;这里以VMwhere ubuntu系统为例 下载code server服务,浏览器…

毕设:《基于hive的音乐数据分析系统的设计与实现》

文章目录 环境启动一、爬取数据1.1、歌单信息1.2、每首歌前20条评论1.3、排行榜 二、搭建环境1.1、搭建JAVA1.2、配置hadoop1.3、配置Hadoop环境&#xff1a;YARN1.4、MYSQL1.5、HIVE(数据仓库)1.6、Sqoop&#xff08;关系数据库数据迁移&#xff09; 三、hadoop配置内存四、导…

原生cesium、mars3d、supermap-cesium在vue3+vite中引入

1. 原生cesium 需要下载 yarn add cesiumyarn add vite-plugin-cesium2. mars3d 需要下载 yarn add mars3d mars3d-cesiumyarn add vite-plugin-mars3d3. supermap-cesium 只需要引入官网下载的包&#xff0c;build文件夹下的cesium&#xff0c;以及项目中引入的其他cesiu…

ChatGPT 应用开发(一)ChatGPT OpenAI API 免代理调用方式(通过 Cloudflare 的 AI Gateway)

前言 开发 ChatGPT 应用&#xff0c;我觉得最前置的点就是能使用 ChatGPT API 接口。首先我自己要能成功访问&#xff0c;这没问题&#xff0c;会魔法就可以本地调用。 那用户如何调用到我的应用 API 呢&#xff0c;我的理解是通过用户能访问到的中转服务器向 OpenAI 发起访问…

Linux篇:进程间通信

一、进程间通信原理&#xff1a; 1、通信是有成本的&#xff1a;两个或者多个进程&#xff0c;实现数据层面的交互&#xff0c;因为进程独立性的存在&#xff0c;导致进程通信的成本比较高。 2、进程间通信的方式&#xff1a; ①基本数据 ②发送命令 ③某种协同 ④通知 .....…

C++新经典模板与泛型编程:用成员函数重载实现std::is_class

用成员函数重载实现is_class std::is_class功能&#xff0c;是一个C11标准中用于判断某个类型是否为一个类类型&#xff08;但不是联合类型&#xff09;的类模板。当时在讲解的时候并没有涉及std::is_class的实现代码&#xff0c;在这里实现一下。简单地书写一个IsClass类模板…

Linux学习笔记7-IIC的应用和AP3216C

接下来进入其他两种串行通信方式&#xff1a;SPI和I2C的学习&#xff0c;因为以后的项目中会用到这些通信方式&#xff0c;而且正点原子的开发板里面也有用I2C和SPI通信的传感器来做实例&#xff0c;分别是一个距离传感器和六轴陀螺仪&#xff0c;这样就可以很好的通过实例来学…

TCP通讯

第二十一章 网络通信 本章节主要讲解的是TCP和UDP两种通信方式它们都有着自己的优点和缺点 这两种通讯方式不通的地方就是TCP是一对一通信 UDP是一对多的通信方式 接下来会一一讲解 TCP通信 TCP通信方式呢 主要的通讯方式是一对一的通讯方式&#xff0c;也有着优点和缺点…

maven-assembly-plugin 自定义打包

我想把input文件夹给打包进去 pom文件 <build><plugins><plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-assembly-plugin</artifactId><executions><execution><!-- 配置执行器 --><i…

黑苹果之主板篇

一、什么是主板 主板&#xff0c;又叫主机板&#xff08;mainboard&#xff09;、系统板&#xff08;systemboard&#xff09;、或母板&#xff08;motherboard&#xff09;&#xff0c;是计算机最基本的同时也是最重要的部件之一。主板一般为矩形电路板&#xff0c;上面安装了…

如何使用HadSky搭配内网穿透工具搭建个人论坛并发布至公网随时随地可访问

文章目录 前言1. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2 Cpolar稳定隧道&#xff08;云端设置&#xff09;2.3 Cpolar稳定隧道&#xff08;本地设置&#xff09;2.4 公网访问测试 总结 前言 经过多年的基础…