卷积神经网络实战

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致

1.首先读取数据

 - 分别构建训练集和测试集(验证集)
- DataLoader来迭代取数据

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片

# 训练集
train_dataset = datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 

# 测试集
test_dataset = datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

2.卷积网络模块构建

- 一般卷积层,relu层,池化层可以写成一个套餐
- 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # 灰度图
                out_channels=16,            # 要得到几多少个特征图
                kernel_size=5,              # 卷积核大小
                stride=1,                   # 步长
                padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
            ),                              # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),                      # relu层
            nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),                      # relu层
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),                # 输出 (32, 7, 7)
        )
        
        self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)
            nn.ReLU(),             # 输出 (32, 7, 7)
        )
        
        self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

 

3.准确率作为评估标准

def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] 
    rights = pred.eq(labels.data.view_as(pred)).sum() 
    return rights, len(labels) 

 

4训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = [] 
    
    for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环
        net.train()                             
        output = net(data) 
        loss = criterion(output, target) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        right = accuracy(output, target) 
        train_rights.append(right) 

    
        if batch_idx % 100 == 0: 
            
            net.eval() 
            val_rights = [] 
            
            for (data, target) in test_loader:
                output = net(data) 
                right = accuracy(output, target) 
                val_rights.append(right)
                
            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1], 
                100. * val_r[0].numpy() / val_r[1]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/522315.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

优雅强大的前端管理模板——Soybean Admin

公众号:【可乐前端】,每天3分钟学习一个优秀的开源项目,分享web面试与实战知识,也有全栈交流学习摸鱼群,期待您的关注! 每天3分钟开源 hi,这里是每天3分钟开源,很高兴又跟大家见面了&#xff0…

python 03序列(列表和元组)

列表 1.创建 x[1,2,3,4,5,6,7,8,9,10] print(x) 或者是 y[a,b,c,d,e,f,g,h] print(y) 2.访问 (1)取出一个元素 x[0] #取出第0号,即List里第一个元素 (2)取出多个连续元素 通过两个索引值实现,第一…

专题【双指针】【学习题】刷题日记

题目列表 11. 盛最多水的容器 42. 接雨水 15. 三数之和 16. 最接近的三数之和 18. 四数之和 26. 删除有序数组中的重复项 27. 移除元素 75. 颜色分类 167. 两数之和 II - 输入有序数组 2024.04.06 11. 盛最多水的容器 题目 给定一个长度为 n 的整数数组 height 。有 n 条垂…

MATLAB - mpcobj = mpc(model,ts,P,M,W,MV,OV,DV) 函数

系列文章目录 前言 模型预测控制器使用线性工厂、干扰和噪声模型来估计控制器状态并预测未来的工厂输出。控制器利用预测的设备输出,解决二次规划优化问题,以确定控制动作。 有关模型预测控制器结构的更多信息,请参阅 MPC 预测模型。 一、语法…

SpringMVC--概述 / 入门

目录 1. SpringMVC简介 2. 配置&入门 2.1. 开发环境 2.2. 创建maven工程 2.3. 手动创建 web.xml 2.4. 配置web.xml 2.4.1. 默认配置方式 2.4.2. 扩展配置方式 2.5. 创建请求控制器 2.6. 创建springMVC的配置文件 2.7. 测试 HelloWorld 2.7.1. 实现对首页的访问…

OJ在线比赛系统(人员管理、赛题发布、在线提交、题目审核、成绩录入)

系统功能设计 技术栈:springboot,jdk8,vue3,element-plus,mybatis-plus 1.java后端系统 首先需要学生通过前端注册页面和java后端系统将个人信息写入数据库,包含学号、姓名、班级以及需要爬取网站的相关信息(例如AtCoder账号信…

智谱清言 HTTP调用 + postman使用

官方教程 接口鉴权 非SDK用户鉴权 官方网站 第一步 获取您的 API Key 第二步 使用 JWT 组装 用户端需引入对应 JWT 相关工具类,并按以下方式组装 JWT 中 header、payload 部分 1、header 具体示例 {“alg”:“HS256”,“sign_type”:“SIGN”} alg : 属性表…

批量导入svg文件作为图标使用(vue3)vite-plugin-svg-icons插件的具体应用

目录 需求svg使用简述插件使用简述实现安装插件1、配置vite.config.ts2、src/main.ts引入注册脚本3、写个icon组件4、使用组件 需求 在vue3项目中,需要批量导入某个文件夹内数量不确定的svg文件用来作为图标,开发完成后能够通过增减文件夹内的svg文件&a…

ICLR 2024 | 联邦学习后门攻击的模型关键层

ChatGPT狂飙160天,世界已经不是之前的样子。 新建了免费的人工智能中文站https://ai.weoknow.com 新建了收费的人工智能中文站https://ai.hzytsoft.cn/ 更多资源欢迎关注 联邦学习使多个参与方可以在数据隐私得到保护的情况下训练机器学习模型。但是由于服务器无法…

论文阅读——MVDiffusion

MVDiffusion: Enabling Holistic Multi-view Image Generation with Correspondence-Aware Diffusion 文生图模型 用于根据给定像素到像素对应关系的文本提示生成一致的多视图图像。 MVDiffusion 会在给定任意每个视图文本的情况下合成高分辨率真实感全景图像,或将…

亚信安慧AntDB:开启数据洞察的新视野

AntDB一直秉承着“技术生态”的理念,不断进行技术创新和功能增强,以保持与先进数据库系统的竞争力。作为一款致力于提升数据库处理性能和稳定性的系统,AntDB在技术上始终保持敏锐的洞察力,不断汲取国内外先进技术的精华&#xff0…

Scala大数据开发

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl Scala简述 在此,简要介绍 Scala 的基本信息和情况。 Scala释义 Scala 源自于英语单词scalable,表示可伸缩的、可扩展的含义。 Scala作者 Scala编…

tomcat 结构目录

bin 启动,关闭和其他脚本。这些 .sh文件(对于Unix系统)是这些.bat文件的功能副本(对于Windows系统)。由于Win32命令行缺少某些功能,因此此处包含一些其他文件。比如说:windows下启动tomcat用的是…

基于开源IM即时通讯框架MobileIMSDK:RainbowChat-iOS端v9.0版已发布

关于MobileIMSDK MobileIMSDK 是一套专门为移动端开发的开源IM即时通讯框架,超轻量级、高度提炼,一套API优雅支持 UDP 、TCP 、WebSocket 三种协议,支持 iOS、Android、H5、标准Java、小程序、Uniapp,服务端基于Netty编写。 工程…

如何不编程用 ChatGPT 爬取网站数据?

敢于大胆设想,才能在 AI 时代提出好问题。 需求 很多小伙伴,都需要为研究获取数据。从网上爬取数据,是其中关键一环。以往,这都需要编程来实现。 可最近,一位星友在知识星球提问: 这里涉及到一些个人隐私&a…

【VMware Workstation】启动虚拟机报错“此主机支持 AMD-V,但 AMD-V 处于禁用状态”

问题出现步骤: 打开虚拟机: 然后报错: “此主机支持 AMD-V,但 AMD-V 处于禁用状态。 如果已在 BIOS/固件设置中禁用 AMD-V,或主机自更改此设置后从未重新启动,则 AMD-V 可能被禁用。 (1) 确认 BIOS/固件设…

吴恩达2022机器学习专项课程(一) 第二周课程实验:多元线性回归(Lab_02)

1.训练集 使用Numpy数组存储数据集。 2.打印数组 打印两个数组的形状和数据。 3.初始化w,b 为了演示,w,b预设出接近最优解的值。w是一个一维数组,w个数对应特征个数。 4.非向量化计算多元线性回归函数 使用for循环&…

泰迪·南通师范大数据智能工作室挂牌签约仪式圆满结束

为促进毕业生高质量就业,拓宽就业渠道,加强校企合作,4月2日,泰迪智能科技股份有限公司上海分公司总经理彭艳昆一行来校出席南通师范高等专科学校“泰迪科技南通师范大数据智能工作室”签约揭牌仪式。学校党委副书记陈玉君、科技处…

LabVIEW数控磨床振动分析及监控系统

LabVIEW数控磨床振动分析及监控系统 在现代精密加工中,数控磨床作为关键设备之一,其加工质量直接影响到产品的精度与性能。然而,磨削过程中的振动是影响加工质量的主要因素之一,不仅会导致工件表面质量下降,还可能缩短…

41.基于SpringBoot + Vue实现的前后端分离-校园网上店铺管理系统(项目 + 论文PPT)

项目介绍 二十一世纪互联网的出现,改变了几千年以来人们的生活,不仅仅是生活物资的丰富,还有精神层次的丰富。本课题研究和开发校园网上店铺,让安装在计算机上的该系统变成管理人员的小帮手,提高校园店铺商品销售信息处…