动手学深度学习——多层感知机

1. 感知机

感知机本质上是一个二分类问题。给定输入x、权重w、偏置b,感知机输出:

以猫和狗的分类问题为例,它本质上就是找到下面这条黑色的分割线,使得所有的猫和狗都能被正确的分类。

与线性回归和softmax的不同点:

  • vs 线性回归:输出的都是一个数,但线性回归输出的是实数,而感知机输出的是离散的分类。
  • vs softmax: softmax是一个多分类(如果有n个分类,softmax就会输出n个元素),而感知机只输出一个元素。

感知机存在的问题: 它只能产生线性分割面,对于XOR(异或)函数,无法拟合(一条线不论怎么分割,都无法将绿色和红色分类正确)。

2. 多层感知机(MLP)

对于上面单层感知机的问题,一个改进思想是:一层函数如果做不了,就用多层函数来做,而多层就带来了网络,用不同层解决不同的问题,多层配合来解决更复杂的问题。

可以使用蓝线对所有数据进行x轴方向的正负分类,再使用黄线对所有数据进行y轴方向的正负分类,最后再将两次分类结果进行xor运算就能得到结果。

多层感知机使用隐藏层和激活函数来得到非线性模型。

在softmax基础上多了隐藏层。可选超参:

  • 隐藏层数
  • 每个隐藏层的宽度,通常选择2的若干次冥作为层的宽度

这两个参数的选择取决于输入和输出的复杂度

对复杂的输入,输入维度一般比较高,输出一般会比较少,有两种处理办法:

  1. 做单隐藏层,把模型做平,层的大小设大一点
  2. 做多隐藏层,把模型做深,层的大小可以设小一点,每层的维度逐步减少(如果每层维度都高,则会导致模型太大)

复杂输入到简单输出本质上是一个信息压缩的过程,多层逐步压缩能避免一次压缩太大导致信息损失太严重,例如:128->64->32->16->8
也可以先expand,从128->256->64->32->16->8

3. 激活函数

作用:在神经网络中引入非线性,可以理解为一个开关,当输入信号超过一定阀值时,神经元会被激活并产生输出,而未超过阀值时神经元将会被抑制。

在没有激活函数的情况下,神经网络只能表示线性映射,无法处理复杂的非线性关系。激活函数的作用就是线性结果映射到一个非线性的输出,以帮助神经网络更好的适应输入数据,提高非线性拟合能力。

举例:一个邮件过滤模型中的神经元,负责对输入邮件的特征(长度、关键词等)进行加权求和,但这个结果只是一个连续的数值我们交

激活函数不能是线性函数,否则会变成单层感知机,依然会存在线性分割面无法处理XOR的问题。

激活函数主要作用于隐藏层。

激活函数的几种选择:

  1. sigmoid: 对于任意输入x,都能投影到0~1区间内。

  2. tanh(x): 将输入投影到[-1,1]区间内

  1. ReLU: 就是一个Max函数(常用),特点是计算很快,相比前面基于指数运算的sigmoid和tanh函数都快很多(一次指数运算要100个时钟周期)

对ReLU函数求导,小于等于0时都是0,大于0时都是1,最终结果就是一个二分类。

4. 代码实现

4.1 初始化参数

我们将实现一个具有单隐藏层的多层感知机, 这个隐藏层包含128个隐藏单元。

对于每一层我们都要记录一个权重矩阵和一个偏置向量,并指定requires_grad=True来记录参数梯度。

import torch
from torch import nn
from d2l import torch as d2l

num_inputs, num_outputs, num_hiddens = 784, 10, 128

W1 = nn.Parameter(torch.randn(
    num_inputs, num_hiddens, requires_grad=True) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(
    num_hiddens, num_outputs, requires_grad=True) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))

params = [W1, b1, W2, b2]

通常,我们选择2的若干次幂作为层的宽度。 因为内存在硬件中的分配和寻址方式,这么做往往可以在计算上更高效。

4.2 加载数据集

这里继续使用Fashion-MNIST图像分类数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

4.3 激活函数

Relu函数的实现比较简单,就是一个max函数的调用, 它将输入的负值部分截断为0,保留正值部分不变。

def relu(X):
    a = torch.zeros_like(X)
    return torch.max(X, a)
  • torch.zeros_like(X): 创建了一个与X具有相同形状的全零张量a。
  • torch.max(X, a): 对于输入X中的每个元素,如果它是正值,则该元素保留不变;如果它是负值,则将其替换为0。

4.4 模型

def net(X):
    X = X.reshape((-1, num_inputs))    
    H = relu(X@W1 + b1)  # 隐藏层,这里“@”代表矩阵乘法
    return (H@W2 + b2)   # 输出层
  1. 使用reshape将输入的二维图像转换为一个长度为num_inputs=784的向量;
  2. 用ReLu函数对隐藏层的线性输出进行激活,得到输出张量H;
  3. 最后,由张量H和权重矩阵W2进行矩阵乘法操作,将偏置向量b2加到结果上,得到预测输出结果。

4.5 损失函数

这里直接使用pytorch中内置的交叉熵损失函数。

loss = nn.CrossEntropyLoss(reduction='none')

4.6 训练

多层感知机的训练过程与softmax的训练过程完全相同,可以直接调用之前定义过的train_ch3函数。

# 将迭代周期数设置为10,并将学习率设置为0.1.
num_epochs, lr = 10, 0.1
updater = torch.optim.SGD(params, lr=lr)
train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

训练过程中的模型损失和精度的收敛变化:

epoch: 1, loss: 1.1021366075515746, test_acc: 0.7544
epoch: 2, loss: 0.6142196039199829, test_acc: 0.8004
epoch: 3, loss: 0.5257990721384684, test_acc: 0.8061
epoch: 4, loss: 0.4842481053034465, test_acc: 0.7988
epoch: 5, loss: 0.4575055497487386, test_acc: 0.8266
epoch: 6, loss: 0.4389862974802653, test_acc: 0.8382
epoch: 7, loss: 0.42252545185089113, test_acc: 0.8443
epoch: 8, loss: 0.40933472124735515, test_acc: 0.8458
epoch: 9, loss: 0.3975078603744507, test_acc: 0.8467
epoch: 10, loss: 0.38488629398345947, test_acc: 0.8527

基于之前softmax模型上定义的预测函数,在测试数据集上使用这个模型做验证:

predict_ch3(net, test_iter)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/612050.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Chronos:学习时间序列的大语言模型(代码解析)

前言 《Chronos: Learning the Language of Time Series》原文地址,Github开源代码地址Chronos:学习时间序列的大语言模型(论文解读)CSDN地址GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习…

【漫画版】指挥官的排序战术:快速排序算法解密

作者介绍:10年大厂数据\经营分析经验,现任字节跳动数据部门负责人。 会一些的技术:数据分析、算法、SQL、大数据相关、python,欢迎探讨交流 欢迎加入社区:码上找工作 作者专栏每日更新: LeetCode解锁1000题…

使用Python在PowerPoint演示文稿之间复制样式(复制幻灯片母版)

在专业演示文稿设计与制作领域,多场演示间保持一致性至关重要。在PowerPoint演示文稿之间复制幻灯片母版成为了一项关键技巧,用以维持统一的视觉风格,确保品牌形象的一致性,并提升观众的参与度。这一做法不仅能节省宝贵的时间&…

OC foudation框架(下)的学习

OCfoudation框架(下) 前面学习了有关OCfoudation框架的部分内容,我们现在对于后面的内容继续学习。 文章目录 OCfoudation框架(下)数组(NSArray和NSMutableArray)对集合元素整体调用方法排序使用…

SwinIR: Image Restoration Using Swin Transformer

ICCV2021 workshophttps://github.com/JingyunLiang/SwinIR 问题引入 将swim transformer使用到图像恢复任务当中,因为卷积存在不能建模长距离依赖以及使用相同的卷积核来恢复不同的图像区域;并不是首个将transformer引入图像恢复中的方法,…

简单的Python HTML 输出

1、问题背景 一名初学者在尝试将 Python 脚本输出到网页上时遇到了一些问题。他当前使用 Python 和 HTML 进行开发,并且遇到了以下问题: 担心自己的代码过于复杂,尤其是 WebOutput() 函数。希望通过 JavaScript 使用 HTML 模板文件更新数据。…

Java多线程与并发编程

1.多线程基础 1.1 线程相关概念 程序(program):是为完成特定任务、用某种语言编写的一组指令的集合。简单的说:就是我们写的代码 进程: 1. 进程是指运行中的程序,比如我们使用QQ,就启动了一个进程,操作系统就会为该进程…

常见扩频系统的基础概念和模型

扩频系统是一种通信技术,它通过将信号的频谱扩展到一定程度来实现传输,这种系统的设计和实现涉及到多种不同的方法和技术。 扩频系统的主要特点和好处包括: 抗干扰能力强:由于信号被扩展到较宽的频带上,单位带宽内的功…

数据收集-分化轨迹推断

数据收集-分化轨迹推断 1参考内容 2参考内容 3参考内容 4参考内容 5参考内容 6:methods and datasets review参考内容 1 参考 Ranek, J.S., Stanley, N. & Purvis, J.E. Integrating temporal single-cell gene expression modalities for trajectory inferen…

【p7】正规式转正规文法

需要注意的是,有时候需要自己构造一个非终结符,非终结符推导到空,然后套用上面的公式即可

十大排序算法(java实现)

注:本篇仅用来自己学习,大量内容来自菜鸟教程(地址:1.0 十大经典排序算法 | 菜鸟教程) 排序算法可以分为内部排序和外部排序,内部排序是数据记录在内存中进行排序,而外部排序是因排序的数据很大…

SpringCloud面试题

SpringCloud常见组件有哪些 注册中心组件:Eureka、Nacos 负载均衡组件:Ribbon 远程调用组件:OpenFeign 网关组件:Zuul、Gateway 服务保护组件:Hystrix、Sentinel 服务配置管理组件:SpringCloudConfig、Nac…

OpenCompass大模型评估

作业链接: Tutorial/opencompass/homework.md at camp2 InternLM/Tutorial GitHub 项目链接: GitHub - open-compass/opencompass: OpenCompass is an LLM evaluation platform, supporting a wide range of models (Llama3, Mistral, InternLM2,GPT-…

Docker快速搭建NAS服务——FileBrowser

Docker快速搭建NAS服务——FileBrowser 文章目录 前言FileBrowser的搭建docker-compose文件编写运行及访问 总结 前言 本文主要讲解如何使用docker在本地快速搭建NAS服务,这里主要写如下两种: FileBrowser1:是一个开源的Web文件管理器&…

【吊打面试官系列】Java高并发篇 - 为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用?

大家好,我是锋哥。今天分享关于 【为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用?】面试题,希望对大家有帮助; 为什么 wait(), notify()和 notifyAll ()必须在同步方法或者同步块中被调用?…

这3种深拷贝实现,你都知道吗?

目录: 1、JSON.parse 2、structuredClone 3、cloneDeep

【竞技宝jjb.lol】MSI:换线战术或将成为BLG命门

北京时间2024年5月10日,英雄联盟2024MSI季中赛继续进行,昨日迎来胜败分组赛首轮BLG对阵PSG。本以为这场比赛没有任何悬念,BLG将会非常轻松地击败PSG,没想到最终PSG两度扳平比分,BLG决胜局抗住压力才艰难取胜。虽然赢下了比赛,但BLG低迷的状态还是在比赛结束后遭到网友们的热议。…

超全MySQL锁机制介绍

前言 MySQL作为关系型数据库管理系统中的佼佼者,为了保证数据的一致性和完整性,在并发控制方面采用了锁机制。锁机制是数据库管理系统用于控制对共享资源的访问,避免多个事务同时修改同一数据造成的数据不一致问题。了解MySQL的锁机制对于数…

【组合博弈】介绍

本文为学习笔记,详细内容参考"Lessons in Play,Michael H. Albert Richard J. Nowakowski David Wolfe" 文章目录 组合博弈介绍(Combinatorial Games)DOMINEERING游戏组合游戏选手介绍Options博弈树(game tree) 组合博弈介绍(Combi…

*****水上飞机:继承,虚函数,虚继承

一题目 请设计以下航行器、飞机、船、水上飞机等 4 个类。 CRAFT 为航行器类,是公共基类,提供航行器的基本特性。包括: 一个保护数据成员:speed(速度)。 三个公有成员函数:构造函数(初始化速度)、析构函数和 Show 函数…