【深度学习与神经网络】MNIST手写数字识别1

简单的全连接层

导入相应库

import torch
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

读入数据并转为tensor向量

# 训练集
# 转为tensor数据
train_dataset = datasets.MNIST(root='./',train=True, transform = transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./',train=False, transform = transforms.ToTensor(), download=True)

装载数据集

# 批次大小
batch_size = 64

# 装载训练集
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle = True)

定义网络结构
一层全连接网络,最后使用softmax转概率值输出

# 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 10)
        self.softmax = nn.Softmax(dim =1)
        
    def forward(self, x):
        # [64,1,28,28] ——> [64, 784]
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x   

定义模型
使用均方误差损失函数,梯度下降优化

# 定义模型
model = Net()
mes_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),0.5)

训练并测试网络:
训练时注意最后输出(64,10)
标签是(64) ,需要将其转为one-hot编码(64,10)

def train():
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # to one-hot 把数据标签变为独热编码
        labels = labels.reshape(-1,1)
        one_hot = torch.zeros(inputs.shape[0],10).scatter(1, labels, 1)
        # 计算loss
        loss = mes_loss(out, one_hot)
        # 梯度清0
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    correct = 0
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # 获取最大值和最大值所在位置
        _,predicted = torch.max(out,1)
        # 预测正确数量
        correct += (predicted == labels).sum()
        
        
    print("test ac:{0}".format(correct.item()/len(test_dataset)))
        
  

调用模型 训练10次

# 使用mse损失函数 
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()

训练结果:
在这里插入图片描述
准确率不够

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/471359.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习之本地部署大模型ChatGLM3-6B【大模型】【报错】

文章目录 0.前言1.模型下载2.配置环境2.1 下载项目文件2.2 配置环境 3.开始推理4.总结 0.前言 本博客将介绍ChatGLM3-6B大模型在Ubuntu上的本地部署教程 1.模型下载 由于毛毛张的服务器服务无法科学上网,所以模型的相关文件必须现在本地下载好,再上传到…

【ai技术】(1):发现一个大模型可视化项目,使用nodejs编写的,llm-viz,可以本地运行展示大模型结构。

1,关于项目 https://www.bilibili.com/video/BV1eF4m1c7NC/ 【ai技术】(1):发现一个大模型可视化项目,使用nodejs编写的,llm-viz,可以本地运行展示大模型结构。 https://github.com/bbycroft/l…

k8s集群架构维护k8s集群以及搭建k8s集群以及k8s集群的常见问题

一、k8s架构 Kubernetes(K8s)是一个由Google主导开发的开源容器编排平台,用于自动化部署、扩展和管理容器化应用程序。它的设计目标是简化容器化应用程序在生产环境中的部署和运营。Kubernetes的架构设计复杂且高效,主要包括以下几…

再谈EMC Unity存储系统内存DIMM问题

以前写过一篇关于EMC Unity 存储系统的DIMM的介绍文章,但是最近还是遇到很多关于内存的问题,还有一些退货,所以有必要再写一篇关于EMC Unity 内存方面的问题,供朋友们参考。如果还有疑问,可以加vx:StorageE…

v-bind 绑定 class 与 style 基础用法

使用 v-bind 指令绑定 class 和 style 时语法相对复杂一些,这两者是可以互相替代的,均用于响应更新HTML元素的属性, v-bind 绑定 class 属性可以改写成绑定 style 属性,只是 css 属性位置变了而已。 1. 绑定 class 属性 1.1 数组…

MySQL 数据库设计范式

第一范式(1NF) 每一列都是不可分割的原子数据项第二范式(2NF) 在1NF的基础上,非码属性必须完全依赖于候选码(在1NF基础上消除非主属性对主码的部分函数依赖) 1.函数依赖A->B,如果通过A属性(属性组)的值…

蓝桥杯 2023 省B 飞机降落

首先&#xff0c;这题要求的数据量比较少&#xff0c;我们可以考虑考虑暴力解法。 这题可能难在很多情况的考虑&#xff0c;比如说&#xff1a; 现在时间是10&#xff0c;有个飞机20才到&#xff0c;我们是可以干等10分钟。 #include <iostream> #include <…

编织效率之梦:Visual Studio与Windows快捷键指南

个人主页&#xff1a;日刷百题 系列专栏&#xff1a;〖C/C小游戏〗〖Linux〗〖数据结构〗 〖C语言〗 &#x1f30e;欢迎各位→点赞&#x1f44d;收藏⭐️留言&#x1f4dd; ​ ​ 前言&#xff1a; 常用快捷键整理 (用加粗标注的是我个人使用时常用的&#xff0c;其实这个…

2024-03-19 作业

作业要求&#xff1a; 1> 将白天课堂代码重新实现一遍&#xff1a; select实现的TCP并发服务器 poll实现的TCP客户端 2> 君子作业 select实现的TCP客户端 poll实现的TCP并发服务器 作业1&#xff1a; 运行代码&#xff1a; select实现的TCP并发服务器 #include<myh…

[项目设计]基于websocket实现网络对战五子棋

项目介绍 该项目旨在实现一个网页端的在线五子棋&#xff0c;将实现登陆、好友、房间、对战、观战、聊天等功能 完成该项目需要了解C、数据库MySQL、基础前端HTML/CSS/JS/Ajax、网络协议WebSocket 项目源码&#xff1a;azhe1/online_gobang - 码云 - 开源中国 (gitee.com) …

国产-高精度、可编程数字温度传感芯片-MY18E20

由工采代理的MY18E20是一款国产高精度可编程的数字模拟混合信号温度传感芯片&#xff1b;感温原理基于CMOS半导体PN节温度与带隙电压的特性关系&#xff0c;经过小信号放大、模数转换、数字校准补偿后&#xff0c;数字总线输出&#xff0c;具有精度高、一致性好、测温快、功耗低…

vue前端解析jwt

vue前端解析jwt 我们可以用在线解析看解析的结果&#xff1a;https://www.lddgo.net/encrypt/jwt-decrypt 但是如果在前端需要解析token&#xff0c;拿到其中的权限信息&#xff0c;可以这样解决。 在线的&#xff1a; 完美解决&#xff1a; 代码&#xff1a; function par…

MySQL下载和安装部署

4.1 简介 MySQL是一个关系型数据库管理系统&#xff0c;由瑞典MySQL AB 公 司开发&#xff0c;现在已经属于 Oracle 旗下产品。MySQL 是最流行的关系型数据 库管理系统之一&#xff0c;在 WEB 应用方面&#xff0c;MySQL是最好的 RDBMS(Relational Database Management System&…

AI系统性学习—LangChain入门

文章目录 1、LangChain入门1.1 简介1.2 架构1.3 核心概念1.2 快速入门1.3 安装 2、LangChain Prompt Template2.1 什么是提示词模版2.1 创建一个提示词模版2.2 聊天消息提示词模版2.3 模版追加示例 3、语言模型3.1 LLM基础模型3.2 LangChain聊天模型3.3 自定义模型3.4 输出解析…

linux网络服务学习(2):vsftp

1.什么是vsftp vsftp是linux服务器上的一款使用ftp协议的软件&#xff0c;是linux上使用最广泛的ftp服务端软件 ftp协议是使用明文传输的&#xff0c;很不安全&#xff0c;一般用于局域网内的文件上传、下载 2.vsftp连接类型 ftp连接要用到2个端口&#xff1a;21、20端口。…

高效备考2024年AMC10:吃透2000-2023年1250道真题(限时免费送)

我们今天继续来随机看5道AMC10真题&#xff0c;以及详细解析&#xff0c;这些题目来自1250道完整的官方历年AMC10真题库。通过系统研究和吃透AMC10的历年真题&#xff0c;参加AMC10的竞赛就能拿到好名次。即使不参加AMC10竞赛&#xff0c;掌握了这些知识和解题思路后初中和高中…

阿里巴巴求职者必看:@SpringMVC?面试准备全攻略!

如有疑问或者更多的技术分享,欢迎关注我的微信公众号“知其然亦知其所以然”! 大家好,我是小米!今天我们来聊聊阿里巴巴面试中常见的一个热门话题:@SpringMVC!如果你对这个话题感兴趣,那就跟着我一起来揭开这个技术的神秘面纱吧! @Controller 在SpringMVC中,我们经…

Java-seata 头参数透传问题步骤详解-arthas

seata分布式事物下游不能回滚的问题; 初步分析headers中TX_XID 没有传给下游系统 通过拦截器打印上游服务日志和下游服务日志打印&#xff0c;影响上游服务不能传header 中自定意义参数的地方是启用线程的熔断策略。 feign:hystrix:enabled: false #不启用client: config:def…

C数据类型(C语言)---变量的类型决定了什么?

目录 数据类型&#xff08;Data Type&#xff09; 变量的类型决定了什么&#xff1f; &#xff08;1&#xff09;不同类型数据占用的内存大小不同 如何计算变量或类型占内存的大小 &#xff08;2&#xff09;不同数据类型的表数范围不同 &#xff08;3&#xff09;不同类型…

Jmeter-基础元件使用(二)

一、Jmeter属性 当我们想要在不同线程组中使用某变量&#xff0c;就需要使用属&#xff0c;此时Jmeter属性的设置需要函数来进行set和get操作 1.创建set函数 2.然后采用Beanshell取样器进行函数执行 3.调用全局变量pro_id 4.将上面生成的函数字符串粘贴到另一个线程组即可…