多层感知机

目录

一、感知机

1、相关概念介绍

2、(单层)感知机存在的问题

3、总结

二、多层感知机(MLP)

1、多层感知机思路

2、激活函数

3、常见的激活函数

4、多类分类

4、总结

三、多层感知机从零开始实现

1、读取数据集

2、初始化模型参数

3、激活函数

4、定义模型

5、定义损失函数

6、定义优化算法

7、训练

四、多层感知机的简洁实现

1、定义模型并初始化模型参数

2、定义损失函数和优化算法

3、读取数据集并训练


一、感知机

1、相关概念介绍

       感知机(perceptron)是二分类的线性分类模型,属于监督学习算法。输入为实例的特征向量,输出为实例的类别(取+1和-1)。

2、(单层)感知机存在的问题

       感知机不能拟合XOR函数,它只能产生线性分割面。

3、总结

二、多层感知机(MLP)

1、多层感知机思路

       多层感知机思路就是将一次不能完成学习的东西分多次进行学习。针对XOR问题就是先学一个x,再学一个y,然后再将这两次学习的结果组合起来。

2、激活函数

       激活函数不能是线性的,一定要使用非线性的激活函数,否则会造成模型坍塌。

例如,如果激活函数为:

$ \sigma \left( x \right) =x $

那么输出仍然是线性函数,其实仍然相当于一个单层的感知机

3、常见的激活函数

       ReLU函数是其中最常用的函数,因为其计算起来很快,相比于Sigmoid函数和Tanh函数来讲没有指数运算。因为指数运算较慢,因此有时候使用ReLU激活函数节约算力。

4、多类分类

4、总结

三、多层感知机从零开始实现

       我们已经在前面中描述了多层感知机(MLP),现在让我们尝试自己实现一个多层感知机。为了与之前softmax回归(Softmax回归-CSDN博客)获得的结果进行比较,我们将继续使用Fashion-MNIST图像分类数据集。

import torch
from torch import nn
from d2l import torch as d2l

1、读取数据集

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

2、初始化模型参数

       回想一下,Fashion-MNIST中的每个图像由$28 \times 28 = 784$个灰度像素值组成。所有图像共分为10个类别。忽略像素之间的空间结构,我们可以将每个图像视为具有784个输入特征和10个类的简单分类数据集。首先,我们将实现一个具有单隐藏层的多层感知机,它包含256个隐藏单元。注意,我们可以将这两个变量都视为超参数。通常,我们选择2的若干次幂作为层的宽度。因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

       我们用几个张量来表示我们的参数。注意,对于每一层我们都要记录一个权重矩阵和一个偏置向量。跟以前一样,我们要为损失关于这些参数的梯度分配内存。

num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens, requires_grad=True) * 0.01)    # nn.Parameter()加不加都行,主要作用是声明是torch的Parameter
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

3、激活函数

       为了确保我们对模型的细节了如指掌,我们将自己动手实现ReLU激活函数,而不是直接调用内置的`relu`函数。

def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)

4、定义模型

       因为我们忽略了空间结构,所以我们使用`reshape`将每个二维图像转换为一个长度为`num_inputs`的向量。只需几行代码就可以实现我们的模型。

def net(X):
    X = X.reshape((-1, num_inputs))     # num_inputs:784  -1:batch_size
    H = relu(X @ W1 + b1)  # 这里“@”代表矩阵乘法
    return (H @ W2 + b2)

5、定义损失函数

       由于我们已经从零实现过softmax函数,因此在这里我们直接使用高级API中的内置函数来计算softmax和交叉熵损失。

loss = nn.CrossEntropyLoss(reduction='none')    # 里面内置softmax和交叉熵损失算法

6、定义优化算法

num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)

7、训练

d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

四、多层感知机的简洁实现

       下面将介绍通过高级API更简洁地实现多层感知机。

import torch
from torch import nn
from d2l import torch as d2l

1、定义模型并初始化模型参数

       与softmax回归的简洁实现相比,唯一的区别是我们添加了2个全连接层(之前我们只添加了1个全连接层)。第一层是隐藏层,它包含256个隐藏单元,并使用了ReLU激活函数。第二层是输出层。

net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10))

def init_weights(m):    # 初始化模型参数
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

2、定义损失函数和优化算法

batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

3、读取数据集并训练

train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/245956.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Kubernetes]2. k8s集群中部署基于nodejs golang的项目以及Pod、Deployment详解

一. 创建k8s部署的镜像 1.部署nodejs项目 (1).上传nodejs项目到节点node1 (2).压缩nodejs项目 (3).构建nodejsDockerfile 1).创建nodejsDockerfile 具体可参考:[Docker]十.Docker Swarm讲解,在/root下创建nodejsDockerfile,具体代码如下: FROM node #把压缩文件COPY到镜像的…

掌握iText:轻松处理PDF文档-高级篇-添加页眉和页脚

推荐语 本文介绍了如何使用iText编程库为PDF文档添加自定义的页眉和页脚。通过指定位置、大小、字体和颜色等属性,你可以将文本、图像或其他元素添加到每一页的固定位置,实现专业、可读的自定义页眉和页脚效果。这对于需要批量处理大量PDF文档或需要更精…

modelbox线程爆满宕机bug

序 该bug的解决需要特别感谢张同学。有了大佬的帮助,这个bug才得以解决。 问题现象 modelbox可以进行模型推理,但压测一段时间后,modelbox会宕机,并发生段错误。 “libgomp: Thread creation failed: Resource temporarily una…

TCP/IP详解——ICMP协议,Ping程序,Traceroute程序,IP源站选路选项

文章目录 一、ICMP 协议1. ICMP 概念2. ICMP 重定向3. ICMP 差错检测4. ICMP 错误报告/差错报文5. ICMP 差错报文的结构6. ICMP 源站抑制差错7. ICMP 数据包格式8. ICMP 消息类型和编码类型9. ICMP 应用-Ping10. ICMP 应用-Tracert11. BSD 对 ICMP 报文的处理12. 总结 PING 程序…

数据结构:队列

数据结构:队列 文章目录 数据结构:队列1.队列常用操作:2.队列的实现3.队列典型应用 ***「队列 queue」是一种遵循先入先出规则的线性数据结构。***队列模拟了排队现象,即新来的人不断加入队列尾部,而位于队列头部的人逐…

Visual studio+Qt开发环境搭建以及注意事项和打开qt的.pro项目

下载qt-然后安装5.14.2_msvc2017 不知道安装那个就全选5.14.2的父级按钮 https://download.qt.io/archive/qt/5.14/5.14.2/ 安装Visual studio,下载直接下一步就行 配置Visual studio的qt环境 在线安装-重启Visual studio会自动安装 离线安装-关闭Visual studio点击安装 关闭…

a16z:加密行业2024趋势“无缝用户体验”

近日,知名加密投资机构a16z发布了“Big ideas 2024”,列出了加密行业在 2024 年几个具备趋势的“大想法”,其中 Seamless UX(无缝用户体验)赫然在列。 从最为直观的理解上,Seamless UX 是在强调用户在使用产…

路由器原理

目录 一.路由器 1.路由器的转发原理 2.路由器的工作原理 二.路由表 1.路由表的形成 2.路由表表头含义 直连: 非直连: 静态 静态路由的配置 负载均衡(浮动路由) 默认路由 动态 三.交换与路由对比 一.路由器 1.路由器…

独立完成软件的功能的测试(4)

独立完成软件的功能的测试(4) (12.14)(功能测试>头条项目实战) 项目总体概述 项目背景和定位:一款汇聚科技咨询,技术文章和问答交流的用户移动终端产品,用户可以通过…

STM32在CTF中的应用和快速解题

题目给的是bin文件,基本上就是需要我们手动修复的固件逆向。 如果给的是hex文件,我们可能需要使用MKD进行动态调试 主要还是以做题为目的 详细的可以去看文档:https://pdf1.alldatasheet.com/datasheet-pdf/view/201596/STMICROELECTRONIC…

微服务学习:Gateway服务网关

一,Gateway服务网关的作用: 路由请求:Gateway服务网关可以根据请求的URL或其他标识符将请求路由到特定的微服务。 负载均衡:Gateway服务网关可以通过负载均衡算法分配请求到多个实例中,从而平衡各个微服务的负载压力。…

一入二出热电阻温度信号隔离变送器

一入二出热电阻温度信号隔离变送器 用于测量铂热电阻Pt10,Pt100,Pt1000,Cu50,Cu100的热电阻传感器的小型仪器设备。广泛应用于工业测量温度系统,是降低成本且有效的测量方式。 型号:JSD TARZ-1002系列 我们来看下有什么特点: ◆小体积&#x…

天猫数据分析平台-天猫销售数据查询软件-11月天猫平台冲锋衣市场销售运营数据分析

随着气温逐渐下降,保暖服饰迎来热销,冲锋衣的需求大增。如今冲锋衣已经不仅仅是户外运动的装备,还成为很多年轻人的日常穿搭和时尚的追求。 新的穿搭趋势也带来了巨大的市场机会。据公开数据显示,中国有冲锋衣生产及经营企业超过8…

竞赛保研 python+opencv+机器学习车牌识别

0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 基于机器学习的车牌识别系统 🥇学长这里给一个题目综合评分(每项满分5分) 难度系数:4分工作量:4分创新点:3分 该项目较为新颖,适…

什么是纯净IP?如何判断IP地址的纯净度?有哪些干净IP推荐?

您是否想知道什么使代理“干净”或如何确保您的代理不会将您列入网站的黑名单?对于通过代理访问网络的人来说,干净的代理是无缝在线体验的重要组成部分。在这篇文章中,我们将深入研究干净代理的世界,并探讨决定其质量的因素。 一、…

k8s常用命令及示例(三):apply 、edit、delete

k8s常用命令及示例(三):apply 、edit、delete 1. kubectl apply -f 命令:从yaml文件中创建资源对象。 -f 参数为强制执行。kubectl apply和kubectl create的区别如下:kubectl create 和 kubectl apply 是 Kubernetes 中两个常用的命令&…

加速数据采集:用OkHttp和Kotlin构建Amazon图片爬虫

引言 曾想过轻松获取亚马逊上的商品图片用于项目或研究吗?是否曾面对网络速度慢或被网站反爬虫机制拦截而无法完成数据采集任务?如果是,那么本文将为您介绍如何用OkHttp和Kotlin构建一个高效的Amazon图片爬虫解决方案。 背景介绍 亚马逊&a…

Spring Boot之自定义starter

🥳🥳Welcome Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于Spring Boot的相关操作吧 目录 🥳🥳Welcome Huihuis Code World ! !🥳🥳 一. starter是什么 二.为什么要使…

万界星空科技AI低代码云MES系统

在企业生产管理过程中,从市场、生产现场到产品交付,生产制造行业都面临着诸多挑战,比如: 订单排产难度大:订单混乱,常漏排产、错排产;产能不明晰,无法承诺交期,常丢单&a…

智慧工地源码(微服务+Java+Springcloud+Vue+MySQL)

智慧工地系统是依托物联网、互联网、AI、可视化建立的大数据管理平台,是一种全新的管理模式,能够实现劳务管理、安全施工、绿色施工的智能化和互联网化。围绕施工现场管理的人、机、料、法、环五大维度,以及施工过程管理的进度、质量、安全三…