pytorch-数据增强

目录

  • 1. Flip翻转
  • 2. Rotate旋转
  • 3. scale缩放
  • 4. crop裁剪
  • 5. 总结
  • 6. 完整代码

1. Flip翻转

在这里插入图片描述
在这里插入图片描述
上图中做了随机水平翻转和随机垂直翻转,翻转完成后转化成tensor

2. Rotate旋转

在这里插入图片描述
在这里插入图片描述

上图中作了2次旋转第一次旋转角度在-15<0<15范围内,随机出一个角度,第二次旋转角是从90,180,270中random出一个。

3. scale缩放

在这里插入图片描述

在这里插入图片描述
缩放通过Resize函数实现,注意传入参数宽高为list类型

4. crop裁剪

在这里插入图片描述
在这里插入图片描述
上图中的RandomCrop就是随机裁剪方法,一般与RandomRotation一起使用。
transforms.Compose类似nn.Sequential,是将各种操作打包成一个操作

5. 总结

数据增加理论上可以扩充出无数张图片数据,但是如果原数据集比较小的话,也不会得到很好的效果,只能改善一些,意思就是说数据增加对机器学习改善比较有限。

6. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms

from visdom import Visdom

batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.RandomHorizontalFlip(),
                       transforms.RandomVerticalFlip(),
                       transforms.RandomRotation(15),
                       transforms.RandomRotation([90, 180, 270]),
                       transforms.Resize([32, 32]),
                       transforms.RandomCrop([28, 28]),
                       transforms.ToTensor(),
                       # transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

viz = Visdom()

viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',
                                                   legend=['loss', 'acc.']))
global_step = 0

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        global_step += 1
        viz.line([loss.item()], [global_step], win='train_loss', update='append')

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    viz.line([[test_loss, correct / len(test_loader.dataset)]],
             [global_step], win='test', update='append')
    viz.images(data.view(-1, 1, 28, 28), win='x')
    viz.text(str(pred.detach().cpu().numpy()), win='pred',
             opts=dict(title='pred'))

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/694964.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java SE(Java Platform, Standard Edition)

Java SE&#xff08;Java Platform, Standard Edition&#xff09; 是Java平台的一个版本&#xff0c;面向桌面应用程序、服务器和嵌入式环境。Java SE提供了开发和运行Java应用程序的基础API&#xff08;Application Programming Interface&#xff0c;应用程序编程接口&…

Java 编译报错:找不到符号? 手把手教你排查解决!

Java 编译报错&#xff1a;找不到符号&#xff1f; 手把手教你排查解决&#xff01; 在 Java 开发过程中&#xff0c;我们经常会遇到编译器抛出 "找不到符号" 错误。这个错误提示意味着编译器无法在它所理解的范围内找到你所引用的类、变量或方法。这篇文章将带你一步…

Spring Security 应用详解

一、 集成SpringBoot 1.1 Spring Boot 介绍 Spring Boot 是一套 Spring 的快速开发框架&#xff0c;基于 Spring 4.0 设计&#xff0c;使用 Spring Boot 开发可以避免一些繁琐的工程 搭建和配置&#xff0c;同时它集成了大量的常用框架&#xff0c;快速导入依赖包&#xff0…

html--宇航员404

<!doctype html> <html> <head> <meta charset"utf-8"> <title>太空404</title><style> html {margin: 0;padding: 0;background-color: white; }body, html {width: 100%;height: 100%;overflow: hidden; }#svgContainer…

数据结构(C):二叉树前中后序和层序详解及代码实现及深度刨析

目录 &#x1f31e;0.前言 &#x1f688;1.二叉树链式结构的代码是实现 &#x1f688;2.二叉树的遍历及代码实现和深度刨析代码 &#x1f69d;2.1前序遍历 ✈️2.1.1前序遍历的理解 ✈️2.1.2前序代码的实现 ✈️2.1.3前序代码的深度解剖 &#x1f69d;2.2中序遍历 ✈…

【QT5】<总览五> QT多线程、TCP/UDP

文章目录 前言 一、QThread多线程 二、QT中的TCP编程 1. TCP简介 2. 服务端程序编写 3. 客户端程序编写 4. 服务端与客户端测试 三、QT中的UDP编程 1. UDP简介 2. UDP单播与广播程序 前言 承接【QT5】&#xff1c;总览四&#xff1e; QT常见绘图、图表及动画。若存在…

开启数字化校园解决方案,实现教育智能化

现代社会的教育面临诸多挑战&#xff0c;如何提高教育质量&#xff0c;实现教育智能化成为了当务之急。数字化校园解决方案应运而生&#xff0c;为学校提供了全新的教学模式和管理方式。本文将介绍数字化校园解决方案的重要性&#xff0c;以及如何开启数字化校园&#xff0c;实…

【端午安康,给大家讲个“网络”故事,深刻一下!】

牛马我&#x1f434;上周又挨锤了&#xff0c; 网络是不稳定的&#xff0c;博学多知的你可能知道&#xff0c;可能不知道。但假如没亲身经历过&#xff0c;知不知道都不深刻&#xff0c;牛马踩了个网络的坑&#xff0c;深刻了&#xff0c;这里分享下&#xff0c; 一个真相 无…

【Python报错】已解决ImportError: cannot import name ‘triu’ from ‘scipy.linalg’

成功解决“ImportError: cannot import name ‘triu’ from ‘scipy.linalg’”错误的全面指南 在Python编程中&#xff0c;尤其是在使用scipy这个科学计算库时&#xff0c;可能会遇到ImportError错误&#xff0c;提示无法从scipy.linalg模块中导入名为triu的函数。这个错误通…

深入JVM:线上内存泄漏问题诊断与处理

文章目录 深入JVM&#xff1a;线上内存泄漏问题诊断与处理一、序言二、内存泄漏概念三、内存泄漏环境模拟四、内存泄漏诊断与解决1、步骤一&#xff1a;获取堆内存快照文件&#xff08;1&#xff09;获取正在运行程序dump文件&#xff08;2&#xff09;获取已终止程序dump文件 …

HP Laptop 14s-fr1xxx原厂oem预装Win11系统ISO镜像下载

惠普星青春版14s-fr1xxx笔记本电脑原装出厂Windows11系统安装包&#xff0c;恢复出厂开箱状态一模一样 链接&#xff1a;https://pan.baidu.com/s/11Qe5XgCmH3emIVEpvoKclg?pwdm1qe 提取码&#xff1a;m1qe 适用型号&#xff1a;14s-fr1xxx 14s-fr0001AU、14s-fr0002AU、…

VMware Fusion 如何增加linux硬盘空间并成功挂载

文章目录 0. 前言1. 增加硬盘空间2. 硬盘分区2.1 查看硬盘2.2 分区2.3 格式化2.4 挂载 3. 参考 0. 前言 如果发现虚拟机分配的硬盘不足&#xff0c;需要增加硬盘空间。本文教给大家如何增加硬盘空间并成功挂载。 查看当前硬盘使用情况&#xff1a; df -h可以看到&#xff0c…

sqli-labs 靶场 less-7 第七关详解:OUTFILE注入与配置

SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它&#xff0c;我们可以学习如何识别和利用不同类型的SQL注入漏洞&#xff0c;并了解如何修复和防范这些漏洞。Less 7 SQLI DUMB SERIES-7判断注入点 进入页面中&#xff0c;并输入数据查看结果。 发现空数据提…

求宇文玥在水下的浮力和赵丽颖捞他的时间

关注微信公众号 数据分析螺丝钉 免费领取价值万元的python/java/商业分析/数据结构与算法学习资料 2024年汉东省在达康书记的带领下率先实现高考试点改革。为让更多的考生能提升对他们的理解和记忆&#xff0c;把电视剧的场景融入考试题目中。确保学生看一遍就懂&#xff0c;想…

debian12安装时分区方案

一、初次尝试 一共设置了4个分区&#xff0c;其中根目录/分区46G&#xff0c;swap分区10G&#xff08;电脑内存为6G&#xff09;&#xff0c;/boot分区200M&#xff0c;/home分区55G。系统安装之后的实际占有情况为&#xff1a; 二、调整后情况 一共设置了4个分区&#xff0c…

基于R语言BIOMOD2 及机器学习方法的物种分布模拟与案例分析

原文链接&#xff1a;基于R语言BIOMOD2 及机器学习方法的物种分布模拟与案例分析https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247606139&idx4&snf94ec30bfb5fa7ac0320403d49db3b66&chksmfa821e9ccdf5978a44a9ba96f6e04a121c0bbf63beea0940b385011c0b…

Spring运维之boo项目表现层测试匹配响应执行状态响应体JSON和响应头

匹配响应执行状态 我们创建了测试环境 而且发送了虚拟的请求 我们接下来要进行验证 验证请求和预期值是否匹配 MVC结果匹配器 匹配上了 匹配失败 package com.example.demo;import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Auto…

【网络教程】Iptables官方教程-学习笔记7-简单理解IPTABLES规则的作用流程

前面学习了IPTABLES的所有功能介绍后&#xff0c;一个Linux设备里的IPTABLES规则集是如何运行的&#xff0c;这里简单做个介绍。 在Linux设备里输入"iptables -nvl",得到该设备的所有防火墙规则&#xff0c;得到的结果中可以看到这个设备防火墙里所有的链以及链里的…

2024年CKA模拟系统制作 | step-by-step | 1、基础环境准备

目录 一、软件环境 二、虚拟网络环境准备 1、编辑虚拟网络 2、网络设置 三、新建虚拟主机 1、新建目录 2、新建虚拟主机 四、系统安装 1、装载系统镜像 2、开启虚拟机 3、选择语言 4、键盘选择 5、网络配置 6、代理设置 7、设置软件源 8、存储设置 9、名称设置 …

计算机网络 —— 网络层 (路由协议)

计算机网络 —— 网络层 &#xff08;路由协议&#xff09; 什么是路由协议内部网关协议RIP关键特性 OSPF主要特点 外部网关协议BGP关键特性 我们今天来看路由协议&#xff1a; 什么是路由协议 路由协议是网络设备&#xff08;主要是路由器&#xff09;用来决定数据包在网络中…