pytorch-MNIST测试实战

这里写目录标题

  • 1. 为什么test
  • 2. 如何做test
  • 3. 什么时候做test
  • 4. 完整代码

1. 为什么test

如下图:上下两幅图中蓝色分别表示train的accuracy和loss,黄色表示test的accuracy和loss,如果单纯看train的accuracy和loss曲线就会认为模型已经train的很好了,accuracy一直在上升接近于1了,loss一直在下降已经接近于0了,殊不知此时可能已经出现了over fitting(本数据集准确率很高,其他数据准确率很低),此时就需要test了,从图中可以看出test在红色划线右侧的accuracy已经不变甚至下降了,loss曲线波动也比较大,甚至已经上升了。
在这里插入图片描述

2. 如何做test

如下图所示:
argmax找出概率最大的数字的index
softmax在这里使用与不使用结果是一样的,因为softmax不改变单调性(大的依然大,小的依然小)
使用torch.eq计算预测值与目标值是否相当,相等返回1不等返回0
correct.sum().float().item() /4是用来计算accuracy的,其他sum()是计算正确的个数,item是tensor转bumpy; /4是除以总样本数
在这里插入图片描述

3. 什么时候做test

  • 每几个batch做一次
  • 一个epoch做一次
    注意:为什么不一个batch做一次test呢?因为test的数据可能也比较大,每个batch都test会影响train的速度

4. 完整代码

从一下代码可知,test是一个epoch做一次,首先像train一样load test数据,并搬到GPU中,然后数据输入到网络中,计算loss,最后计算准确了并打印输出

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.argmax(dim=1)
        correct += pred.eq(target).float().sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/560754.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【优质书籍推荐】Vue.js+Node.js全栈开发实战

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。…

von Mises-Fisher Distribution (代码解析)

torch.distribution 中包含了很多概率分布的实现,本文首先通过均匀分布来说明 Distribution 的具体用法, 然后再解释 von Mises-Fisher 分布的实现, 其公式推导见 von Mises-Fisher Distribution. 1. torch.distribution.Distribution 以下是 Uniform 的源码: cl…

怎么使用JMeter进行性能测试?

一、简介 JMeter是Apache软件基金会下的一款开源的性能测试工具,完全由Java开发。它专注于对我们应用程序进行负载测试和性能测量,最初设计用于web应用程序,现在已经扩展到其他测试功能,比如:FTP、Database和LDAP等。…

Vue 3 项目构建与效率提升:vite-plugin-vue-setup-extend 插件应用指南

一、Vue3项目创建 前提是已安装Node.js(点击跳转Node官网) npm create vuelatest这一指令将会安装并执行 create-vue,它是 Vue 官方的项目脚手架工具。你将会看到一些诸如 TypeScript 和测试支持之类的可选功能提示: ✔ Projec…

跟着Carl大佬学leetcode之977 有序数组的平方

来点强调,刷题是按照代码随想录的顺序进行的,链接如下https://www.programmercarl.com/本系列是记录一些刷题心得和学习过程,就看到题目自己先上手试试,然后看程序员Carl大佬的解释,自己再敲一遍修修补补,练…

electron打包dist为可执行程序后记【electron-quick-start】

文章目录 目录 文章目录 前言 一、直接看效果 二、实现步骤 1.准备dist文件夹 2.NVM管理node版本 3.准备electron容器并npm run start 4.封装成可执行程序 1.手动下载electron对应版本的zip文件,解决打包缓慢问题 2.安装packager 3.配置打包命令执行内容…

【点云语义分割】弱监督点云语义分割自适应标签分布

Adaptive Annotation Distribution for Weakly Supervised Point Cloud Semantic Segmentation 摘要: 弱监督点云语义分割因其能够减轻对点云细粒度注释的严重依赖而备受关注。然而,在实际应用中,稀疏注释通常在点云中呈现出明显的非均匀分布…

Table表格(关于个人介绍与图片)

展开行&#xff1a; <el-table :data"gainData" :border"gainParentBorder" style"width: 100%"><el-table-column type"expand"><template #default"props"><div m"4"><h3>工作经…

封装个js分页插件

// 分页插件类 class PaginationPlugin {constructor(fetchDataURL, options {}) {this.fetchDataURL fetchDataURL;this.options {containerId: options.containerId || paginationContainer,dataSizeAttr: options.dataSizeAttr || toatalsize, // 修改为实际API返回的数据…

Flutter 的 showDialog 和 showCupertinoDialog 有什么区别?

我将我的 App 里用的 Flutter 升级到了 3.19&#xff0c;没想到&#xff0c;以前我用 showDialog 和 AlertDialog 组合创建的二次确认框&#xff0c;变得无敌难看了&#xff0c;大幅度增加了整个框的圆角和里面默认按钮的圆角。不得已&#xff0c;我必须修改一下&#xff0c;以…

51.基于SpringBoot + Vue实现的前后端分离-校园志愿者管理系统(项目 + 论文)

项目介绍 本站是一个B/S模式系统&#xff0c;采用SpringBoot Vue框架&#xff0c;MYSQL数据库设计开发&#xff0c;充分保证系统的稳定性。系统具有界面清晰、操作简单&#xff0c;功能齐全的特点&#xff0c;使得基于SpringBoot Vue技术的校园志愿者管理系统设计与实现管理工…

偏微分方程算法之二维初边值问题(紧交替方向隐格式)

目录 一、研究对象 二、理论推导 2.1 二维紧差分格式 2.2 紧交替方向格式 2.2.1 紧Peaceman-Rachford格式 2.2.2 紧D’Yakonov格式 2.2.3 紧Douglas格式 三、算例实现 四、结论 一、研究对象 继续以二维抛物型方程初边值问题为研究对象: 为了确保连续性,公式…

Vue.extend()和我的两米大砍刀

Vue.extends是什么&#xff1f; 一个全局API,用于注册并挂载组件。 传统的引用组件的方式是使用import直接引入&#xff0c;但是使用Vue.extends()也可以实现。 使用规则 <div id"mount-point"></div>// 创建构造器 var Profile Vue.extend({templat…

PyCharm,终端conda环境无法切换的问题(二个解决方案)

问题 PyCharm终端&#xff0c;环境切换无效&#xff0c;默认始终为base 解决一 Settings->Tools->Terminal->ShellPath&#xff0c;将powershell修改为cmd.exe 解决二 conda config --show在输出中找到 auto_activate_base 的行&#xff0c;发现被设置为 true&#x…

R语言中的execl数据转plink

文章目录 带出外部连接的方式添加列的方式从列表中选出对应的数据信息查看变量信息没有成功 带出外部连接的方式 点击这个黄色的按钮就可以弹出外部链接的方式 添加列的方式 创建一个数据框的方式 我们创建一个三行三列的数据方式 df <- data.frame(name c("Alice&…

嵌入式linux中利用QT控制蜂鸣器方法

大家好,今天给大家分享一下,如何控制开发板上的蜂鸣器。 第一:开发板原理图 从原理图中可以得出,当引脚输出低电平的时候,对应的蜂鸣器发出响声。 第二:QT代码详细实现 设置一个按钮,点击即可控制BEEP状态发生反转。 #ifndef MAINWINDOW_H #define MAINWINDOW_H#in…

【网络编程】TCP流套接字编程(TCP实现回显服务器)

一.TCP流套字节相关API. Socket(既能给客户端使用,也能给服务器使用) 构造方法 基本方法: ServerSocket(只能给服务器使用) 构造方法: 基本方法: 二.TCP实现回显服务器. 客户端代码示例: package Demo2;import java.io.IOException; import java.io.InputStream; import j…

【飞桨AI实战】人体姿态估计:零基础入门,从模型训练到应用开发

前言 本次分享将带领大家从 0 到 1 完成一个人体姿态估计任务&#xff0c;覆盖数据准备、模型训练、推理部署和应用开发的全流程&#xff0c;项目将采用以PaddlePaddle为核心的飞桨深度学习框架进行开发&#xff0c;并总结开发过程中踩过的一些坑&#xff0c;希望能为有类似项…

(C语言)sscanf 与 sprintf详解

目录 1.sprintf函数详解 2. sscanf函数详解 1.sprintf函数详解 头文件&#xff1a;stdio.h 作用&#xff1a;将格式化的数据写入字符串里&#xff0c;也就是将格式化的数据转变为字符串。 演示&#xff1a; #include <stdio.h> struct S {char name[10];int height;…

【Gradle如何安装配置及使用的教程】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…