最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码

大家好,我是微学AI,今天给大家介绍一下最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码。随着人工智能的快速发展,去年有强大的大模型ChatGPT横空出世,国内的大模型也紧追其后的发布,主要包括:文心一言、ChatGLM、通义千问、百川大模型等,他们可以帮助我们编写代码,但是在实际中,高度依赖于大模型则会缺乏思考的能力,缺乏编写代码的感觉,在别人问的时候,缺乏熟练度。坚持多写代码反复进行,可以提高熟练程度,提高开发效率,锻炼记忆力。本文尝试利用最短的代码实现数据集、卷积神经网络的搭建、模型的训练,模型的评估的整个流程代码,快速熟练手打出来。

在这里插入图片描述

一、坚持手撕默写代码的意义:

关于坚持手撕默写代码的意义,我总结一下几点:

1.提高熟练程度:

通过手撕默写代码,我能够更加深入地理解代码的逻辑和工作原理,加深对代码的理解,并提高对编程语言和算法的熟练程度。

2.培养思维逻辑与开发效率:
手撕默写代码需要你对算法和语法有较为全面的理解,同时需要你将思路转化为具体的代码实现。这种过程能够培养我的思维逻辑能力,提高问题解决能力,提高模型库包的快速调用与开发效率。

3.探索学习新知识:
通过手撕默写代码,你会遇到各种问题和挑战,需要不断查阅资料、学习和探索,从中获得新的知识和技能。

4.锻炼记忆力:
反复手写代码可以加强对语法和细节的记忆,提高记忆力和代码的熟悉程度。

二、卷积神经网络的快速搭建

关于pytorch框架,我们经常用到的第三方库有torch,torch.nn,torchvision,这些我们要烂熟于心。

torch:torch是PyTorch的核心库,提供了张量操作、数学函数、自动求导等功能。它是一个多维数组的库,类似于NumPy,但具有GPU加速和用于深度学习的其他扩展功能。

torch.nn:torch.nn模块是PyTorch中用于构建神经网络模型的模块。它提供了各种层(如全连接层、卷积层、循环层等)和损失函数(如交叉熵损失、均方误差损失等),以及优化算法(如随机梯度下降等)的实现。

torchvision.transforms:torchvision.transforms模块提供了一系列用于图像预处理和数据增强的函数。通过该模块,可以对输入图像进行常见的操作,如裁剪、缩放、旋转、归一化等,以便更好地适应模型的输入要求。

torch.utils.data.DataLoader:torch.utils.data.DataLoader是PyTorch中用于加载和迭代数据集的工具。它可以将数据集封装成可迭代的数据加载器,支持批量加载、多线程加载和数据打乱等功能。

torchvision.datasets.FakeData:torchvision.datasets.FakeData是用于生成虚拟数据集的类。它可以根据指定的数据样式和大小生成虚拟的图像数据集,用于模型调试和测试。本文利用FakeData进行快速训练

第三方库的导入与卷积神经网络搭建:

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData

class CNNnet(nn.Module):
    def __init__(self):
        super(CNNnet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3,32,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,3,1,1),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.linear = nn.Linear(int((32/4)*(32/4)*64),2)

    def forward(self, x):
        x = self.conv1(x)
        x =x.view(x.size(0),-1)
        x = self.linear(x)
        return x

在上述的CNNnet网络模型中,nn.Linear(int((32/4)*(32/4)64),2)中的int((32/4)(32/4)*64)是指线性层的输入特征数。在该模型中,线性层的输入来自于卷积层输出的特征图,经过reshape处理后得到的一维向量。具体地,假设输入图像的大小为 W x H,卷积核大小为 k x k,卷积层的输出通道数为 n,则经过两次最大池化后,卷积层的输出特征图的大小为 (W/4) x (H/4) x n。因此,线性层的输入特征数 num = (W/4) x (H/4) x n。
我们这里设置输入图像的大小为 32x32,卷积核大小为 3x3,卷积层的输出通道数为 64,则经过两次最大池化后,卷积层的输出特征图的大小为 (32/4)x(32/4)x64=8x8x64=4096。因此,线性层的输入特征数 num=4096。

三、模型训练代码快速编写

model = CNNnet()  # 实例化模型

criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立Adam优化器

dataset = FakeData(size=1000,image_size=(3,32,32),num_classes=2,transform=transforms.ToTensor())
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)

for epoch in range(25):
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        if (i + 1) % 10 == 0:
            print('[Epoch %d, Batch %5d] Loss: %.3f | Accuracy: %.3f%%' %
                  (epoch + 1, i + 1, running_loss / 5, 100 * correct / total))
            running_loss = 0.0
            correct = 0
            total = 0

四、模型评估代码快速编写

# 模型评估
model.eval()
total = 0
correct = 0

with torch.no_grad():
    for data in train_loader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy on the training dataset: %.3f%%' % (100 * correct / total))

上面的代码将模型设置为评估模式(model.eval()),然后使用torch.no_grad()上下文管理器来禁用梯度计算,以提高运行效率。在遍历训练集数据进行预测时,统计正确预测的样本数,并计算准确率。
该评估代码是在训练集上进行评估,如果需要在测试集上评估模型,需要使用测试集的数据进行评估。这里没有做扩展。

运行结果:

...
[Epoch 18, Batch    10] Loss: 0.292 | Accuracy: 99.062%
[Epoch 18, Batch    20] Loss: 0.264 | Accuracy: 100.000%
[Epoch 18, Batch    30] Loss: 0.245 | Accuracy: 100.000%
[Epoch 19, Batch    10] Loss: 0.208 | Accuracy: 100.000%
[Epoch 19, Batch    20] Loss: 0.218 | Accuracy: 100.000%
[Epoch 19, Batch    30] Loss: 0.215 | Accuracy: 99.688%
[Epoch 20, Batch    10] Loss: 0.201 | Accuracy: 100.000%
[Epoch 20, Batch    20] Loss: 0.183 | Accuracy: 100.000%
[Epoch 20, Batch    30] Loss: 0.165 | Accuracy: 100.000%
[Epoch 21, Batch    10] Loss: 0.136 | Accuracy: 100.000%
[Epoch 21, Batch    20] Loss: 0.137 | Accuracy: 100.000%
[Epoch 21, Batch    30] Loss: 0.119 | Accuracy: 100.000%
[Epoch 22, Batch    10] Loss: 0.108 | Accuracy: 100.000%
[Epoch 22, Batch    20] Loss: 0.102 | Accuracy: 100.000%
[Epoch 22, Batch    30] Loss: 0.098 | Accuracy: 100.000%
[Epoch 23, Batch    10] Loss: 0.087 | Accuracy: 100.000%
[Epoch 23, Batch    20] Loss: 0.083 | Accuracy: 100.000%
[Epoch 23, Batch    30] Loss: 0.086 | Accuracy: 100.000%
[Epoch 24, Batch    10] Loss: 0.072 | Accuracy: 100.000%
[Epoch 24, Batch    20] Loss: 0.075 | Accuracy: 100.000%
[Epoch 24, Batch    30] Loss: 0.075 | Accuracy: 100.000%
[Epoch 25, Batch    10] Loss: 0.068 | Accuracy: 100.000%
[Epoch 25, Batch    20] Loss: 0.060 | Accuracy: 100.000%
[Epoch 25, Batch    30] Loss: 0.065 | Accuracy: 100.000%
Accuracy on the training dataset: 100.000%

本文只是将模型训练的过程跑通,手打快速训练卷积神经网络网络的过程。实际应用场景中还需要将数据集分为训练集、验证集、测试集,详细的过程可以看我的往期文章。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/277297.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

用于IT管理的COBIT

随着世界的不断发展和变化,企业必须像冲浪者一样乘风破浪,适应社会不断更新的浪潮,拥抱新技术。信息技术(IT)已成为大多数企业运营的支柱,对战略决策、客户互动和整体效率都起了一定的影响作用。然而&#…

三巨头对决:深入了解pnpm、yarn与npm

欢迎来到我的博客,代码的世界里,每一行都是一个故事 三巨头对决:深入了解pnpm、yarn与npm 前言包管理器简介npm(Node Package Manager):Yarn:pnpm(Performant Npm)&#…

基于ssm学生奖惩管理系统+v论文

摘 要 在如今社会上,关于信息上面的处理,没有任何一个企业或者个人会忽视,如何让信息急速传递,并且归档储存查询,采用之前的纸张记录模式已经不符合当前使用要求了。所以,对学生奖惩信息管理的提升&#x…

PowerShell Instal 一键部署TeamCity

前言 TeamCity 是一个通用的 CI/CD 软件平台,可实现灵活的工作流程、协作和开发实践。允许在您的 DevOps 流程中成功实现持续集成、持续交付和持续部署。 系统支持 Centos7,8,9/Redhat7,8,9及复刻系列系统支持 Windows 10,11,2012,2016,2019,2022高版本建议使用9系列系统…

C语言 linux文件操作(二)

文章目录 一、获取文件长度二、追加写入三、覆盖写入四、文件创建函数creat 一、获取文件长度 通过lseek函数,除了操作定位文件指针,还可以获取到文件大小,注意这里是文件大小,单位是字节。例如在file1文件中事先写入"你好世…

智慧工地云平台源码 支持二次开发、支持源码交付

智慧工地利用移动互联、物联网、云计算、大数据等新一代信息技术,彻底改变传统施工现场各参建方的交互方式、工作方式和管理模式,为建设集团、施工企业、监理单位、设计单位、政府监管部门等提供一揽子工地现场管理信息化解决方案。 通过人员管理、车辆管…

c++ 静态联编+动态联编 (多态)

静态多态 动态多态 1)静态多态和动态多态的区别就是函数地址是早绑定(静态联编)还是晚绑定(动态联编)。 如果函数的调用,在编译阶段就可以确定函数的调用地址,并产生代码,就是静态多态(编译时多态),就是说地址是早绑定…

HTML+CSS+JavaScript制作电子时钟

一 效果展示 二 步骤 在网上下载0-9的jpg图片,将其复制粘贴到项目images文件中,注意,图片的命名一定是数字形式,例如:1.jpg,风景图也是自行下载然后粘贴到相应的文件。 三 代码实现…

【Docker】添加指定用户到指定用户组

运行Docker ps命令,报错:/v1.24/containers/json": dial unix /var/run/docker.sock: connect: permission denied 创建docker用户组 安装docker时默认已经创建好 sudo groupadd docker添加用户加入docker用户组 此处以用户user为例 sudo usermo…

生意不好做?不妨去“私域”找找机会

站在2023年的尾巴上向前看,零售从业者们心里都有同样的疑问:2024年消费还能好么?增长的机会又在哪里? “我会说:要有信心,消费行业永远年轻。”经济学家香帅在企业微信举办的“2023实干企业家峰会消费专场…

六、从0开始卷出一个新项目瑞萨RZN2L之loader app分离工程优化

六、loader app分离工程 6.1 概述 6.2 官方资料与不足 6.3 loader app分离工程的优化 6.3.1 自动调节合并appsection 6.3.2 loader中使用外设 6.3.3 app使用sram mirror 6.3.4 sram atcm同时使用 六、从0开始卷出一个新项目之瑞萨RZN2L loader…

决心解开软光栅的心结

最近几天离职在家,是的,还没回老家.白天周中的时候写这个软光栅化渲染器.包括在上班的最后项目大家都不干活的时候我已经开始写了.到今天上午总算是有的看了.细节还差很多,下午把透视校正插值加上,下午加不完就元旦假期之后再说(元旦我要写pbrt的读书笔记).还有摄像机裁剪,背面…

【Vue2 + ElementUI】el-table中校验表单

一. 案例 校验金额 阐述&#xff1a;校验输入的金额是否正确。如下所示&#xff0c;点击【编辑图标】会变为input输入框当&#xff0c;输入金额。当输入框失去焦点时&#xff0c;若正确则调用接口更新金额且变为不可输入状态&#xff0c;否则返回不合法金额提示 <templat…

proE各版本安装指南

下载链接 https://pan.baidu.com/s/1BSaJxvPPGeIa4YKm7xk57g?pwd0531 1.鼠标右击【Proe5.0M280(64bit)】压缩包&#xff08;win11及以上系统需先点击“显示更多选项”&#xff09;选择【解压到 Proe5.0M280(64bit)】&#xff08;解压的路径中不能有中文&#xff09;。 2.打开…

BIT-666 的 2023 年度总结

<<< 年度总结 >>> <<< 年度数据 >>> ◆ 发博情况 ◆ 学习成就 ◆ 代码提交 ◆ 博文表现 <<< 年度创作 >>> ◆ LLM - LLaMA2 <<< 年度风景 >>> ◆ 春 - 中关村软件园 - 百望山 ◆ 夏 - 乌兰…

Gamma LUT PG285笔记

1 gamma校正应用背景 探测器响应为线性亮度或RGB值&#xff0c;而显示器并非线性&#xff0c;需要算法做校正。 2 reg 可以配置3张LUT表&#xff0c;每张表最大1024个16bit参数。表中0x0800仅是第一张表的起始地址&#xff0c;地址每次加4。 3 数据输入的格式 按照RBG的顺序…

springboot参数校验常用注解及分组校验

一、使用方式添加Validated 二、常见注解 Null 被注解的元素必须为null NotNull 被注解的元素必须不为null NotBlank 只能作用在接收的 String 类型上&#xff0c;注意是只能&#xff0c;不能为 null&#xff0c;而且调用 trim() 后&#xff0c;长度必须大于 0即&#xff…

Solana 与 DePIN 的双向奔赴,会带来 DePIN 之夏吗?

作者&#xff1a;LBank Labs 研究员 F.F 编译&#xff1a;TinTinLand 原文&#xff1a;https://medium.com/lbanklabs/new-anchor-of-solana-depin-b674d04d6980 太长不看版 在过去的一年里&#xff0c;我们观察到 Solana 和 DePIN 两者都呈现出了显著的增长。这不仅是极客科…

Ubuntu22.04-安装后Terminal无法调出

参考&#xff1a; Ubuntu20.04 终端打开不了的问题排查_ubuntu终端打不开-CSDN博客 https://blog.csdn.net/u010092716/article/details/130968032 Ubuntu修改locale从而修改语言环境_ubuntu locale-CSDN博客 https://blog.csdn.net/aa1209551258/article/details/81745394 问…

为什么ChatGPT采用SSE协议而不是Websocket?

在探索ChatGPT的使用过程中&#xff0c;我们发现GPT采用了流式数据返回的方式。理论上&#xff0c;这种情况可以通过全双工通信协议实现持久化连接&#xff0c;或者依赖于基于EventStream的事件流。然而&#xff0c;ChatGPT选择了后者&#xff0c;也就是本文即将深入探讨的SSE&…