3-2 多层感知机的从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256 # 批量大小为256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
# load进来训练集和测试集

初始化模型参数

回想一下,Fashion-MNIST中的每个图像由 28 × 28 = 784 28 \times 28 = 784 28×28=784个灰度像素值组成。 所有图像共分为 10 10 10个类别。 忽略像素之间的空间结构, 我们可以将每个图像视为具有 784 784 784个输入特征 和 10 10 10个类的简单分类数据集。
首先,我们将实现一个具有单隐藏层的多层感知机, 它包含 256 256 256个隐藏单元。 注意,我们可以将这两个变量都视为超参数。 通常,我们选择 2 2 2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效
我们用几个张量来表示我们的参数。 注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。 跟以前一样,我们要为损失关于这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(
    num_inputs, num_hiddens, requires_grad=True) * 0.01)
# 初始成一个随机的,行数是num_inputs输入的个数,列数是num_hiddens,requires_grad=True需要算梯度
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
# 偏差是一个隐藏层的个数,是一个向量,初始化为全0
W2 = nn.Parameter(torch.randn(
    num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]
# 第一层参数 w1和b1
# 第二层参数 w2和b2

激活函数

为了确保我们对模型的细节了如指掌, 我们将实现ReLU激活函数, 而不是直接调用内置的relu函数。

def relu(X):
    a = torch.zeros_like(X)
    # 产生一个矩阵a,a的shape与X一模一样,但是元素全为0
    return torch.max(X, a)

模型

因为我们忽略了空间结构, 所以我们使用reshape将每个二维图像转换为一个长度为num_inputs的向量。 只需几行代码就可以实现我们的模型。

def net(X):
    X = X.reshape((-1, num_inputs))
    # 把X拉成一个二维的矩阵
    # -1 让它自己算,实际是batch size
    H = relu(X@W1 + b1)  # 这里“@”代表矩阵乘法
    return (H@W2 + b2)

损失函数

由于我们已经从零实现过softmax函数, 因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。

loss = nn.CrossEntropyLoss(reduction='none')

训练

多层感知机的训练过程与softmax回归的训练过程完全相同。 可以直接调用d2l包的train_ch3函数, 将迭代周期数设置为10,并将学习率设置为0.1.

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

请添加图片描述
相比softmax,我的损失更低了!我的精度其实没有发生太大变化。
因为我的模型更大了,所以我的数据拟合性更好,所以我的损失在下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/796823.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一文了解5G新通话技术演进与业务模型

5G新通话简介 5G新通话,也被称为VoNR,是基于R16及后续协议产生的一种增强型语音通话业务。 它在IMS网络里新增数据通道(Data Channel),承载通话时的文本、图片、涂鸦、菜单等信息。它能在传统话音业务基础上提供更多服…

红日靶场----(三)1.漏洞利用

上期已经信息收集阶段已经完成,接下来是漏洞利用。 靶场思路 通过信息收集得到两个吧靶场的思路 1、http://192.168.195.33/phpmyadmin/(数据库的管理界面) root/root 2、http://192.168.195.33/yxcms/index.php?radmin/index/login&am…

练习 6.7:⼈们 在为练习 6.1 编写的程序中,再创建两个表⽰⼈的字典,然后将这三个字典都存储在⼀个名为 people 的列表中。

练习 6.7:⼈们 在为练习 6.1 编写的程序中,再创建两个表⽰⼈的字典,然后将这三个字典都存储在⼀个名为 people 的列表中。 要求 遍历这个列表,将其中每个⼈的所有信息都打印出来。 代码 human {shuicc: {first_name: shui,la…

linux下安装cutecom串口助手;centos安装cutecom串口助手;rpm安装包安装cutecom串口助手

在支持apt-get的系统下安装 在终端命令行中输入: sudo apt-get install cutecom 安装好后输入 sudo cutecom 就可以了 关于如何使用,可以看这个https://www.cnblogs.com/xingboy/p/14388610.html 如果你的电脑不支持apt-get。 那我们就通过安装包…

druid(德鲁伊)数据线程池连接MySQL数据库

文章目录 1、druid连接MySQL2、编写JDBCUtils 工具类 1、druid连接MySQL 初学JDBC时,连接数据库是先建立连接,用完直接关闭。这就需要不断的创建和销毁连接,会消耗系统的资源。 借鉴线程池的思想,数据连接池就这么被设计出来了。…

【Unity2D 2022:NPC】制作任务系统

一、接受任务 1. 编辑NPC对话脚本: (1)创建静态布尔变量用来判断ruby是否接受到任务 public class NPCDialog : MonoBehaviour {// 创建全局变量用来判断ruby是否接到任务public static bool receiveTask false; } (2&#xff…

git自动pull同步远程若干分支与本地若干分支

git自动pull同步远程若干分支与本地若干分支 假设远程代码仓库有100个分支,而本地只有10个本地分支与远程分支一一对应,现在要保持本地的这个10个分支与远程一致,最笨的方法是checkout到每个分支,然后一个一个的 git pull origin…

展望未来:在【PyCharm】中结合【机器学习】实现高效的图形化处理

欢迎来到 破晓的历程的 博客 ⛺️不负时光,不负己✈️ 文章目录 引言一、PyCharm简介与配置1.1 PyCharm基础1.2 图形化库配置 二、机器学习项目中的数据可视化2.1 数据加载与预处理2.2 数据探索性可视化2.3 模型训练与结果可视化 三、PyCharm中的图形化调试四、高级…

【香橙派 Orange pi AIpro】| 开发板深入使用体验

目录 一. 🦁 写在前面二. 🦁 愉快的安装流程2.1 安装前准备2.2 流程准备2.2.1 烧录镜像2.2.2 开机2.2.3 连网2.2.4 SSH远程连接开发板 2.3 体验 AI 应用样例 三. 🦁 写在最后 一. 🦁 写在前面 大家好,我是狮子呀&…

【kubernetes】Helm包管理器基本概念与Chart实战

概念:基础架构与常用命令 三个重要概念: 1.chart 创建Kukernetes应用程序所必需的一组信息。 2.config 包含了可以合并到打包的chart中的配置信息,用于创建一个可发布的对象。 3.release 是一个与特走配置相结合的chart的运行实例。 常用命…

二叉搜索树大冒险:寻找-插入-删除

OK,看我们题目就可知道啦,今天要分享学习的一种数据结构就是二叉搜索树。 内容题目也说了三个大概的,分别是寻找、插入、删除。 讲这个之前呢,那么就先讲讲这个二叉搜索树是何方神圣呢? 二叉搜索树: 又…

包成功安装tiny-cuda-nn,记录安装过程中的问题解决,附带pytorch3d安装【踩坑指南】

tiny-cuda-nn安装过程中的问题解决,附带pytorch3d安装【踩坑指南】 前言tiny-cuda-nn第一种下载方法:命令行安装tiny-cuda-nn第二种下载方法:本地编译 pytorch3d安装 前言 official repo: https://github.com/NVlabs/tiny-cuda-nn 该包可以显…

酷克数据亮相第13届PostgreSQL中国技术大会,获数据库杰出贡献奖

7 月 12 日,第 13 届 PostgreSQL 中国技术大会在杭州盛大开幕。本次大会以“聚焦云端创新,汇聚智慧共享”为主题,邀请了国内外 PG 领域众多行业大咖、学术精英及技术专家,共同探讨数据库领域的发展趋势、技术创新和实践经验。酷克…

计算机的错误计算(二十九)

摘要 (1)讨论近似值的错误数字个数。有时,遇到数字9或0, 不太好确认近似值的错误数字个数。(2)并进一步解释确认计算机的错误计算(二十八)中一个函数值的错误数字个数。 理论上,我…

《Python零基础入门》——关于PyCharm使用技巧及python基本概念

从本次文章开始,我们将学习一门新的编程语言——Python。作为最热门的编程语言,Python相对比较清晰、简单。 python主要的编译工具就是pycharm,关于pycharm的安装及python配置环境,大家可自行参考网络上的教程,本文不…

借人工智能之手,编织美妙歌词篇章

在音乐的领域中,歌词宛如璀璨的明珠,为旋律增添了无尽的魅力和情感深度。然而,对于许多创作者来说,编织出美妙动人的歌词并非易事。但如今,随着科技的飞速发展,人工智能为我们带来了全新的创作可能。 “妙…

【C++深度探索】全面解析多态性机制(二)

🔥 个人主页:大耳朵土土垚 🔥 所属专栏:C从入门至进阶 这里将会不定期更新有关C/C的内容,欢迎大家点赞,收藏,评论🥳🥳🎉🎉🎉 前言 我…

TEB局部路径规划算法代码及原理解读

TEB(Timed Elastic Band) 是一个基于图优化的局部路径规划算法,具有较好的动态避障能力,在ROS1/ROS2的导航框架中均被采用。该图优化以g2o优化框架实现,以机器人在各个离散时刻的位姿和离散时刻之间的时间间隔为顶点,约束其中的加…

MUR2060CTR-ASEMI无人机专用MUR2060CTR

编辑:ll MUR2060CTR-ASEMI无人机专用MUR2060CTR 型号:MUR2060CTR 品牌:ASEMI 封装:TO-220 批号:最新 最大平均正向电流(IF):20A 最大循环峰值反向电压(VRRM&#…

tkinter-TinUI-xml实战(12)pip可视化管理器

引言 pip命令行工具在平常使用方面确实足够简单,本项目只是作为TinUI多界面开发的示例。 当然,总有人想用GUI版pip,实际上也有。不过现在,我们就来手搓一个基于python和TinUI(tkinter)的pip可视化管理器。…