PyTorch官网demo解读——第一个神经网络(1)

神经网络如此神奇,feel the magic

今天分享一下学习PyTorch官网demo的心得,原来实现一个神经网络可以如此简单/简洁/高效,同时也感慨PyTorch如此强大。

这个demo的目的是训练一个识别手写数字的模型!

先上源码:
from pathlib import Path
import requests   # http请求库
import pickle
import gzip

from matplotlib import pyplot   # 显示图像库

import math
import numpy as np
import torch

###########下载训练/验证数据######################################################
# 这里加载的是mnist数据集
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
PATH.mkdir(parents=True, exist_ok=True)

URL = "https://github.com/pytorch/tutorials/raw/main/_static/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)


###########解压并加载训练数据######################################################
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")


# 通过pyplot显示数据集中的第一张图片
# 显示过程会中断运行,看到效果之后可以屏蔽掉,让调试更顺畅
#print("x_train[0]: ", x_train[0])
#pyplot.imshow(x_train[0].reshape((28, 28)), cmap="gray")
#pyplot.show()


# 将加载的数据转成tensor
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape   # n是函数,c是列数
print("x_train.shape: ", x_train.shape)
print("y_train.min: {0}, y_train.max: {1}".format(y_train.min(), y_train.max()))


# 初始化权重和偏差值,权重是随机出来的784*10的矩阵,偏差初始化为0
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

# 激活函数
def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

# 定义模型:y = wx + b
# 实际上就是单层的Linear模型
def model(xb):
    return log_softmax(xb @ weights + bias)


# 丢失函数 loss function
def nll(input, target):
    return -input[range(target.shape[0]), target].mean()
loss_func = nll

# 计算精度函数
def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()

###########开始训练##################################################################
bs = 64  # 每一批数据的大小
lr = 0.5  # 学习率
epochs = 2  # how many epochs to train for

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb) # 通过模型预测
        loss = loss_func(pred, yb) # 通过与实际结果比对,计算丢失值

        loss.backward() # 反向传播
        with torch.no_grad():
            weights -= weights.grad * lr  # 调整权重值
            bias -= bias.grad * lr  # 调整偏差值
            weights.grad.zero_()
            bias.grad.zero_()

##########对比一下预测结果############################################################
xb = x_train[0:bs]  # 加载一批数据,这里用的是训练的数据,在实际应用中最好使用没训练过的数据来验证
yb = y_train[0:bs]  # 训练数据对应的正确结果
preds = model(xb)  # 使用训练之后的模型进行预测
print("################## after training ###################")
print("accuracy: ", accuracy(preds, yb))   # 打印出训练之后的精度
# print(preds[0])
print("pred value: ", torch.argmax(preds, dim=1))   # 打印预测的数字
print("real value: ", yb)   # 实际正确的数据,可以直观地和上一行打印地数据进行对比
运行结果:

可以看到训练后模型地预测精度达到了0.9531,已经不错了,毕竟只使用了一个单层地Linear模型;从输出地对比数据中可以看出有三个地方预测错了(红框标记地数字)

ok,今天先到这里,下一篇再来解读代码中地细节

附:

PyTorch官方源码:https://github.com/pytorch/tutorials/blob/main/beginner_source/nn_tutorial.py

天地一逆旅,同悲万古愁!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/250862.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

crmeb v5新增一个功能的完整示例记录

首先,需求 工作中的二开需求是这样的,修改首页的装修,并新增回收报价的功能 开始动手 第一步,我们要到后面的管理界面,去装修中修改首面的展示 首页的页面配置好之后,就要在 前端的展示程序中 配置相…

105基于matlab的阶次分析算法

基于matlab的阶次分析算法,用于变转速机械故障特征提取,可运行,包含寻找脉冲时刻,等角度时刻。数据可更换自己的,程序已调通,可直接运行。 105阶次分析变转速信号处理 (xiaohongshu.com)

二十七、读写文件

二十七、读写文件 27.1 文件类QFile #include <QCoreApplication>#include<QFile> #include<QDebug>int main(int argc, char *argv[]) {QCoreApplication a(argc, argv);QFile file("D:/main.txt");if(!file.open(QIODevice::WriteOnly | QIODe…

Arrays.asList()方法:陷阱与解决之道

在Java编程中&#xff0c;Arrays类提供了一系列用于操作数组的实用方法。其中&#xff0c;​Arrays.asList()​方法是一个常用的方法&#xff0c;用于快速将数组转换为List集合。然而&#xff0c;这个方法存在一些潜在的陷阱&#xff0c;可能导致出现意外的行为。本文将介绍​A…

数据可视化---柱状图

import matplotlib.pyplot as plt import numpy as npdef plot_bar_chart(data, labels, colorsNone, title"Bar Chart", xlabel"X-Axis", ylabel"Y-Axis"):"""绘制柱状图&#xff0c;并在柱子上显示数量和比例。:param data: 包…

【C++11特性篇】C++11中新增的initializer_list——初始化的小利器(2)

前言 大家好吖&#xff0c;欢迎来到 YY 滴C11系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; 目录 一.探究std::initializer_list是什么…

OceanBase 4.2.1社区版 最小资源需求安装方式

OceanBase 4.2.1社区版 最小资源需求安装方式 资源需求 资源需求分析 observer Memory 控制参数&#xff1a; memory_limit_percentage 默认80% memory_limit 直接设定observer Memory 大小 System memory 可设为1G 租户内存:sys租户内存设为1G&#xff0c;OCP需要的租户oc…

Mac brew install youtube-dl 【 youtube 下载工具:youtube-dl 安装】

文章目录 1. 简介2. 预备3. 安装4. 命令5. 测试 1. 简介 youtube-dl - 从youtube.com或其他视频平台下载视频 https://github.com/ytdl-org/youtube-dl 2. 预备 安装并配置 git安装 brew 3. 安装 MacBook-Pro ~ % brew install youtube-dl Warning: youtube-dl has been …

【Java系列】详解多线程(三)—— 线程安全(上篇)

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【Java系列专栏】【JaveEE学习专栏】 本专栏旨在分享学习Java的一点学习心得&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 一…

基于EasyExcel的数据导入导出

前言&#xff1a; 代码复制粘贴即可用&#xff0c;主要包含的功能有Excel模板下载、基于Excel数据导入、Excel数据导出。 根据实际情况修改一些细节即可&#xff0c;最后有结果展示&#xff0c;可以先看下结果&#xff0c;是否是您想要的。 台上一分钟&#xff0c;台下60秒&a…

OSG中几何体的绘制(二)

5. 几何体操作 在本章的前言中就讲到&#xff0c;场景都是由基本的绘图基元构成的&#xff0c;基本的绘图基元构成简单的几何体,简单的几何体构成复杂的几何体&#xff0c;复杂的几何体最终构造成复杂的场景。当多个几何体组合时&#xff0c;可能存在多种降低场景渲染效率的原因…

SQL进阶理论篇(七):B+树的查询及存储机制

文章目录 简介数据库中的存储结构数据库中的页结构从数据页来看B树的查询过程总结参考文献 简介 我们之前已经了解过数据库的B树索引和Hash索引&#xff0c;这些索引信息以及数据记录都是保存在文件里的&#xff0c;确切的说是存储在页结构中。 本节&#xff0c;从我们将了解…

【LeetCode刷题】-- 161.相隔为1的编辑距离

161.相隔为1的编辑距离 方法&#xff1a;一次遍历 首先&#xff0c;我们要确认字符串的长度不会相差太远。如果长度差了2个或更多字符&#xff0c;那么 s 和 t 就不可能是一次编辑之差的字符串。 接下来&#xff0c;我们假设 s 的长度总是短于或等于 t 的长度。如果不是这样&…

cesium 自定义贴图,shadertoy移植教程。

1.前言 cesium中提供了一些高级的api&#xff0c;可以自己写一些shader来制作炫酷的效果。 ShaderToy 是一个可以在线编写、测试和分享图形渲染着色器的网站。它提供了一个图形化的编辑器&#xff0c;可以让用户编写基于 WebGL 的 GLSL 着色器代码&#xff0c;并实时预览渲染结…

大数据存储技术(3)—— HBase分布式数据库

目录 一、HBase简介 &#xff08;一&#xff09;概念 &#xff08;二&#xff09;特点 &#xff08;三&#xff09;HBase架构 二、HBase原理 &#xff08;一&#xff09;读流程 &#xff08;二&#xff09;写流程 &#xff08;三&#xff09;数据 flush 过程 &#xf…

配置802.1x认证

实验目的&#xff1a; 某公司拥有两个部门&#xff0c;市场部和人事部门&#xff0c;市场部和人事部的IP地址分别为10.1.11.0/24、10.1.21.0/24两个IP网段。市场部属于vlan11&#xff0c;人事部属于vlan21。现在需要在SW2上配置802.1x认证&#xff0c;实现终端用于只有认证成功…

BKP 备份寄存器 RTC 实时时钟-stm32入门

这一章节我们要讲的主要内容是 RTC 实时时钟&#xff0c;对应手册&#xff0c;是第 16 章的位置。 实时时钟这个东西&#xff0c;本质上是一个定时器&#xff0c;但是这个定时器&#xff0c;是专门用来产生年月日时分秒&#xff0c;这种日期和时间信息的。所以学会了 STM32 的…

Flink系列之:SQL提示

Flink系列之&#xff1a;SQL提示 一、动态表选项二、语法三、例子四、查询提示五、句法六、加入提示七、播送八、随机散列九、随机合并十、嵌套循环十一、LOOKUP十二、进一步说明十三、故障排除十四、连接提示中的冲突案例十五、什么是查询块 SQL 提示可以与 SQL 语句一起使用来…

【卡塔尔世界杯数据可视化与新闻展示】

卡塔尔世界杯数据可视化与新闻展示 前言数据获取与处理可视化页面搭建功能实现新闻信息显示详情查看登录注册评论信息管理 创新点结语 前言 随着卡塔尔世界杯的临近&#xff0c;对于足球爱好者来说&#xff0c;对比赛的数据分析和新闻报道将成为关注的焦点。本文将介绍如何使用…

腾讯云优惠全站搜——云服务器配置大全精准

腾讯云推出优惠全站搜页面 https://curl.qcloud.com/PPrF9NFe 在这个页面可以一键查询所需云服务器、轻量应用服务器、数据库、存储、CDN、网络、安全、大数据等云产品优惠活动大全&#xff0c;活动打开如下图&#xff1a; 腾讯云优惠全站搜——优惠合集 如上图&#xff0c;在这…