Chapter 5: CNN

Convolutional Neural Network (CNN)

Problem description:

Use a convolutional neural network to classify the MNIST dataset.

Dataset:

The MNIST dataset contains 60,000 training images and 10,000 test images. That is enough samples to train a fairly complex model (for example, a deep CNN). It is frequently used as a test case for new pattern-recognition models, and it is a convenient dataset for students and researchers to run experiments on. In addition, MNIST is relatively small, so it can be trained directly on a laptop CPU.
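
As a quick sanity check of the figures above, the dataset can be downloaded and inspected with torchvision. The sketch below is only illustrative (it assumes torchvision is installed and that ./mnist/ is writable); the assignment code later performs the same download.

import torchvision

# Minimal check: download MNIST and confirm the 60,000 / 10,000 train/test split.
train_set = torchvision.datasets.MNIST(root='./mnist/', train=True, download=True)
test_set = torchvision.datasets.MNIST(root='./mnist/', train=False, download=True)
print(len(train_set), len(test_set))   # 60000 10000
img, label = train_set[0]              # without a transform, each item is (PIL image, int label)
print(img.size, label)                 # (28, 28) and the digit label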

Requirements

In the PyTorch version of the CNN, you need to fill in the parameters of nn.Conv2d() in self.conv1 and of self.conv2, and complete the argument of x = x.view().
Training accuracy should exceed 96%.
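
Before filling in the parameters, it helps to trace how the spatial size evolves: with "same" padding (padding = (kernel_size - 1) // 2) each convolution keeps the 28x28 input size, and each 2x2 pooling halves it, so two conv/pool stages end at 7x7 with 64 channels, which is where the 7*7*64 in x.view() comes from. A minimal shape check, using the same layer parameters as the solution below (not part of the required code):

import torch
import torch.nn as nn

# Trace the spatial sizes through the two conv/pool stages.
x = torch.zeros(1, 1, 28, 28)   # one dummy MNIST image
conv1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=7, stride=1, padding=3),
                      nn.ReLU(), nn.MaxPool2d(2))
conv2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
                      nn.ReLU(), nn.AvgPool2d(2))
print(conv1(x).shape)            # torch.Size([1, 32, 14, 14])
print(conv2(conv1(x)).shape)     # torch.Size([1, 64, 7, 7])  -> flatten to 7*7*64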

import os
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np
learning_rate = 1e-4
keep_prob_rate = 0.7
max_epoch = 3
BATCH_SIZE = 50

DOWNLOAD_MNIST = False
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
    # no ./mnist/ directory, or the directory is empty
    DOWNLOAD_MNIST = True


train_data = torchvision.datasets.MNIST(root='./mnist/', train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=DOWNLOAD_MNIST)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# Fixed evaluation batch: the first 500 test images, scaled to [0, 1].
# (.data / .targets replace the deprecated .test_data / .test_labels attributes.)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:500] / 255.
test_y = test_data.targets[:500].numpy()

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                # 7 * 7 kernel; 1 in channel; 32 out channels; stride 1;
                # "same" padding, so the convolution's input and output have the same spatial size
                in_channels=1,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
            nn.ReLU(),        # activation function
            nn.MaxPool2d(2),  # pooling operation: 28x28 -> 14x14
        )
        self.conv2 = nn.Sequential(
            # 5 * 5 kernel; 32 in channels; 64 out channels; "same" padding; stride 1
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),        # activation function
            nn.AvgPool2d(2),  # pooling operation: 14x14 -> 7x7
        )
        self.out1 = nn.Linear(7 * 7 * 64, 1024, bias=True)  # first fully connected layer

        # Note: nn.Dropout takes the drop probability, so this drops 70% of activations;
        # pass 1 - keep_prob_rate instead if a 70% keep rate is intended.
        self.dropout = nn.Dropout(keep_prob_rate)
        self.out2 = nn.Linear(1024, 10, bias=True)  # second fully connected layer (10 classes)

    def forward(self, x):
        x = self.conv1(x)               # (batch, 1, 28, 28) -> (batch, 32, 14, 14)
        x = self.conv2(x)               # (batch, 32, 14, 14) -> (batch, 64, 7, 7)
        x = x.view(-1, 7 * 7 * 64)      # flatten the output of conv2 to (batch_size, 7 * 7 * 64)
        out1 = self.out1(x)
        out1 = F.relu(out1)
        out1 = self.dropout(out1)
        out2 = self.out2(out1)
        # Note: nn.CrossEntropyLoss expects raw logits, so returning out2 directly is more common.
        output = F.softmax(out2, dim=1)
        return output


def test(cnn):
    global prediction
    with torch.no_grad():                   # no gradients are needed for evaluation
        y_pre = cnn(test_x)
    _, pre_index = torch.max(y_pre, 1)      # index of the largest class score = predicted digit
    pre_index = pre_index.view(-1)
    prediction = pre_index.numpy()
    correct = np.sum(prediction == test_y)
    return correct / 500.0                  # accuracy over the 500 held-out test images


def train(cnn):
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()
    for epoch in range(max_epoch):
        for step, (x, y) in enumerate(train_loader):   # DataLoader already yields tensors
            output = cnn(x)
            loss = loss_func(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step != 0 and step % 20 == 0:
                print("=" * 10, step, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)

if __name__ == '__main__':
    cnn = CNN()
    train(cnn)
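
One detail worth noting: nn.CrossEntropyLoss applies log-softmax internally, so the more common pattern is to return the raw logits (out2) from forward() and apply softmax only when probabilities are actually needed; argmax predictions are the same either way. The standalone sketch below only illustrates this point and uses made-up tensors, not the model above.

import torch
import torch.nn as nn
import torch.nn.functional as F

# CrossEntropyLoss expects raw, unnormalized scores; argmax predictions are
# unchanged by softmax because softmax is monotonic within each row.
logits = torch.randn(4, 10)                    # fake scores for a batch of 4 images
labels = torch.randint(0, 10, (4,))
loss = nn.CrossEntropyLoss()(logits, labels)
same = (logits.argmax(dim=1) == F.softmax(logits, dim=1).argmax(dim=1)).all()
print(loss.item(), same.item())                # some positive loss value, True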

Training results (accuracy on the first 500 test images, printed every 20 steps; by the third epoch it stays above the required 96%):

========== 20 ===== ===== test accuracy is  0.224 ==========
========== 40 ===== ===== test accuracy is  0.362 ==========
========== 60 ===== ===== test accuracy is  0.402 ==========
========== 80 ===== ===== test accuracy is  0.51 ==========
========== 100 ===== ===== test accuracy is  0.608 ==========
========== 120 ===== ===== test accuracy is  0.624 ==========
========== 140 ===== ===== test accuracy is  0.708 ==========
========== 160 ===== ===== test accuracy is  0.684 ==========
========== 180 ===== ===== test accuracy is  0.738 ==========
========== 200 ===== ===== test accuracy is  0.766 ==========
========== 220 ===== ===== test accuracy is  0.778 ==========
========== 240 ===== ===== test accuracy is  0.796 ==========
========== 260 ===== ===== test accuracy is  0.802 ==========
========== 280 ===== ===== test accuracy is  0.81 ==========
========== 300 ===== ===== test accuracy is  0.812 ==========
========== 320 ===== ===== test accuracy is  0.82 ==========
========== 340 ===== ===== test accuracy is  0.848 ==========
========== 360 ===== ===== test accuracy is  0.83 ==========
========== 380 ===== ===== test accuracy is  0.852 ==========
========== 400 ===== ===== test accuracy is  0.852 ==========
========== 420 ===== ===== test accuracy is  0.856 ==========
========== 440 ===== ===== test accuracy is  0.874 ==========
========== 460 ===== ===== test accuracy is  0.85 ==========
========== 480 ===== ===== test accuracy is  0.874 ==========
========== 500 ===== ===== test accuracy is  0.864 ==========
========== 520 ===== ===== test accuracy is  0.858 ==========
========== 540 ===== ===== test accuracy is  0.884 ==========
========== 560 ===== ===== test accuracy is  0.872 ==========
========== 580 ===== ===== test accuracy is  0.9 ==========
========== 600 ===== ===== test accuracy is  0.88 ==========
========== 620 ===== ===== test accuracy is  0.886 ==========
========== 640 ===== ===== test accuracy is  0.882 ==========
========== 660 ===== ===== test accuracy is  0.886 ==========
========== 680 ===== ===== test accuracy is  0.876 ==========
========== 700 ===== ===== test accuracy is  0.882 ==========
========== 720 ===== ===== test accuracy is  0.886 ==========
========== 740 ===== ===== test accuracy is  0.894 ==========
========== 760 ===== ===== test accuracy is  0.894 ==========
========== 780 ===== ===== test accuracy is  0.9 ==========
========== 800 ===== ===== test accuracy is  0.898 ==========
========== 820 ===== ===== test accuracy is  0.912 ==========
========== 840 ===== ===== test accuracy is  0.894 ==========
========== 860 ===== ===== test accuracy is  0.898 ==========
========== 880 ===== ===== test accuracy is  0.888 ==========
========== 900 ===== ===== test accuracy is  0.896 ==========
========== 920 ===== ===== test accuracy is  0.888 ==========
========== 940 ===== ===== test accuracy is  0.91 ==========
========== 960 ===== ===== test accuracy is  0.908 ==========
========== 980 ===== ===== test accuracy is  0.918 ==========
========== 1000 ===== ===== test accuracy is  0.906 ==========
========== 1020 ===== ===== test accuracy is  0.908 ==========
========== 1040 ===== ===== test accuracy is  0.906 ==========
========== 1060 ===== ===== test accuracy is  0.914 ==========
========== 1080 ===== ===== test accuracy is  0.908 ==========
========== 1100 ===== ===== test accuracy is  0.906 ==========
========== 1120 ===== ===== test accuracy is  0.906 ==========
========== 1140 ===== ===== test accuracy is  0.924 ==========
========== 1160 ===== ===== test accuracy is  0.918 ==========
========== 1180 ===== ===== test accuracy is  0.904 ==========
========== 20 ===== ===== test accuracy is  0.924 ==========
========== 40 ===== ===== test accuracy is  0.908 ==========
========== 60 ===== ===== test accuracy is  0.92 ==========
========== 80 ===== ===== test accuracy is  0.91 ==========
========== 100 ===== ===== test accuracy is  0.926 ==========
========== 120 ===== ===== test accuracy is  0.91 ==========
========== 140 ===== ===== test accuracy is  0.922 ==========
========== 160 ===== ===== test accuracy is  0.932 ==========
========== 180 ===== ===== test accuracy is  0.932 ==========
========== 200 ===== ===== test accuracy is  0.93 ==========
========== 220 ===== ===== test accuracy is  0.94 ==========
========== 240 ===== ===== test accuracy is  0.918 ==========
========== 260 ===== ===== test accuracy is  0.934 ==========
========== 280 ===== ===== test accuracy is  0.93 ==========
========== 300 ===== ===== test accuracy is  0.934 ==========
========== 320 ===== ===== test accuracy is  0.934 ==========
========== 340 ===== ===== test accuracy is  0.93 ==========
========== 360 ===== ===== test accuracy is  0.944 ==========
========== 380 ===== ===== test accuracy is  0.938 ==========
========== 400 ===== ===== test accuracy is  0.92 ==========
========== 420 ===== ===== test accuracy is  0.936 ==========
========== 440 ===== ===== test accuracy is  0.948 ==========
========== 460 ===== ===== test accuracy is  0.934 ==========
========== 480 ===== ===== test accuracy is  0.938 ==========
========== 500 ===== ===== test accuracy is  0.916 ==========
========== 520 ===== ===== test accuracy is  0.916 ==========
========== 540 ===== ===== test accuracy is  0.928 ==========
========== 560 ===== ===== test accuracy is  0.936 ==========
========== 580 ===== ===== test accuracy is  0.942 ==========
========== 600 ===== ===== test accuracy is  0.922 ==========
========== 620 ===== ===== test accuracy is  0.94 ==========
========== 640 ===== ===== test accuracy is  0.94 ==========
========== 660 ===== ===== test accuracy is  0.96 ==========
========== 680 ===== ===== test accuracy is  0.938 ==========
========== 700 ===== ===== test accuracy is  0.936 ==========
========== 720 ===== ===== test accuracy is  0.94 ==========
========== 740 ===== ===== test accuracy is  0.946 ==========
========== 760 ===== ===== test accuracy is  0.946 ==========
========== 780 ===== ===== test accuracy is  0.948 ==========
========== 800 ===== ===== test accuracy is  0.95 ==========
========== 820 ===== ===== test accuracy is  0.948 ==========
========== 840 ===== ===== test accuracy is  0.95 ==========
========== 860 ===== ===== test accuracy is  0.94 ==========
========== 880 ===== ===== test accuracy is  0.956 ==========
========== 900 ===== ===== test accuracy is  0.944 ==========
========== 920 ===== ===== test accuracy is  0.948 ==========
========== 940 ===== ===== test accuracy is  0.95 ==========
========== 960 ===== ===== test accuracy is  0.944 ==========
========== 980 ===== ===== test accuracy is  0.94 ==========
========== 1000 ===== ===== test accuracy is  0.946 ==========
========== 1020 ===== ===== test accuracy is  0.952 ==========
========== 1040 ===== ===== test accuracy is  0.952 ==========
========== 1060 ===== ===== test accuracy is  0.944 ==========
========== 1080 ===== ===== test accuracy is  0.956 ==========
========== 1100 ===== ===== test accuracy is  0.96 ==========
========== 1120 ===== ===== test accuracy is  0.948 ==========
========== 1140 ===== ===== test accuracy is  0.942 ==========
========== 1160 ===== ===== test accuracy is  0.948 ==========
========== 1180 ===== ===== test accuracy is  0.944 ==========
========== 20 ===== ===== test accuracy is  0.952 ==========
========== 40 ===== ===== test accuracy is  0.96 ==========
========== 60 ===== ===== test accuracy is  0.948 ==========
========== 80 ===== ===== test accuracy is  0.954 ==========
========== 100 ===== ===== test accuracy is  0.948 ==========
========== 120 ===== ===== test accuracy is  0.948 ==========
========== 140 ===== ===== test accuracy is  0.958 ==========
========== 160 ===== ===== test accuracy is  0.942 ==========
========== 180 ===== ===== test accuracy is  0.948 ==========
========== 200 ===== ===== test accuracy is  0.952 ==========
========== 220 ===== ===== test accuracy is  0.952 ==========
========== 240 ===== ===== test accuracy is  0.95 ==========
========== 260 ===== ===== test accuracy is  0.966 ==========
========== 280 ===== ===== test accuracy is  0.96 ==========
========== 300 ===== ===== test accuracy is  0.956 ==========
========== 320 ===== ===== test accuracy is  0.96 ==========
========== 340 ===== ===== test accuracy is  0.956 ==========
========== 360 ===== ===== test accuracy is  0.956 ==========
========== 380 ===== ===== test accuracy is  0.954 ==========
========== 400 ===== ===== test accuracy is  0.96 ==========
========== 420 ===== ===== test accuracy is  0.966 ==========
========== 440 ===== ===== test accuracy is  0.96 ==========
========== 460 ===== ===== test accuracy is  0.954 ==========
========== 480 ===== ===== test accuracy is  0.968 ==========
========== 500 ===== ===== test accuracy is  0.958 ==========
========== 520 ===== ===== test accuracy is  0.958 ==========
========== 540 ===== ===== test accuracy is  0.962 ==========
========== 560 ===== ===== test accuracy is  0.968 ==========
========== 580 ===== ===== test accuracy is  0.958 ==========
========== 600 ===== ===== test accuracy is  0.952 ==========
========== 620 ===== ===== test accuracy is  0.95 ==========
========== 640 ===== ===== test accuracy is  0.964 ==========
========== 660 ===== ===== test accuracy is  0.962 ==========
========== 680 ===== ===== test accuracy is  0.96 ==========
========== 700 ===== ===== test accuracy is  0.962 ==========
========== 720 ===== ===== test accuracy is  0.964 ==========
========== 740 ===== ===== test accuracy is  0.958 ==========
========== 760 ===== ===== test accuracy is  0.96 ==========
========== 780 ===== ===== test accuracy is  0.972 ==========
========== 800 ===== ===== test accuracy is  0.962 ==========
========== 820 ===== ===== test accuracy is  0.968 ==========
========== 840 ===== ===== test accuracy is  0.964 ==========
========== 860 ===== ===== test accuracy is  0.96 ==========
========== 880 ===== ===== test accuracy is  0.964 ==========
========== 900 ===== ===== test accuracy is  0.96 ==========
========== 920 ===== ===== test accuracy is  0.96 ==========
========== 940 ===== ===== test accuracy is  0.97 ==========
========== 960 ===== ===== test accuracy is  0.956 ==========
========== 980 ===== ===== test accuracy is  0.966 ==========
========== 1000 ===== ===== test accuracy is  0.964 ==========
========== 1020 ===== ===== test accuracy is  0.964 ==========
========== 1040 ===== ===== test accuracy is  0.97 ==========
========== 1060 ===== ===== test accuracy is  0.974 ==========
========== 1080 ===== ===== test accuracy is  0.962 ==========
========== 1100 ===== ===== test accuracy is  0.97 ==========
========== 1120 ===== ===== test accuracy is  0.974 ==========
========== 1140 ===== ===== test accuracy is  0.978 ==========
========== 1160 ===== ===== test accuracy is  0.976 ==========
========== 1180 ===== ===== test accuracy is  0.974 ==========
