简单的神经网络

一、softmax的基本概念

我们之前学过sigmoid、relu、tanh等等激活函数,今天我们来看一下softmax。

先简单回顾一些其他激活函数:

  1. Sigmoid激活函数:Sigmoid函数(也称为Logistic函数)是一种常见的激活函数,它将输入映射到0到1之间。它常用于二分类问题中,特别是在输出层以概率形式表示结果时。Sigmoid函数的优点是输出值限定在0到1之间,相当于对每个神经元的输出进行了归一化处理。
  2. Tanh激活函数:Tanh函数(双曲正切函数)将输入映射到-1到1之间。与Sigmoid函数相比,Tanh函数的中心点在零值附近,这意味着它的输出是以0为中心的。这种特性可以在某些情况下提供更好的性能。
  3. ReLU激活函数:ReLU(Rectified Linear Unit)函数是当前非常流行的一个激活函数,其表达式为f(x)=max(0, x)。ReLU函数的优点是计算简单,能够在正向传播过程中加速计算。此外,ReLU函数在正值区间内梯度为常数,有助于缓解梯度消失问题。但它的缺点是在负值区间内梯度为零,这可能导致某些神经元永远不会被激活,即“死亡ReLU”问题。

Softmax函数是一种在机器学习中广泛使用的函数,尤其是在处理多分类问题时。它的主要作用是将一组未归一化的分数转换成一个概率分布。Softmax函数的一个重要性质是其输出的总和等于1,这符合概率分布的定义。这意味着它可以将一组原始分数转换为概率空间,使得每个类别都有一个明确的概率值。

  • 二分类问题选择sigmoid激活函数

  • 多分类问题选择softmax激活函数

二、交叉熵损失函数

交叉熵损失函数的公式可以分为二分类和多分类两种情况。对于二分类问题,假设我们只考虑正类(标签为1)和负类(标签为0)在多分类问题中,交叉熵损失函数可以扩展为−∑𝑖=1𝐾𝑦𝑖⋅log⁡(𝑝𝑖)−∑i=1K​yi​⋅log(pi​),其中𝐾K是类别的总数,( y_i )是样本属于第𝑖i个类别的真实概率(通常用one-hot编码表示),而𝑝𝑖pi​是模型预测该样本属于第( i )个类别的概率。

import torch
from torch import nn

# 确定随机数种子
torch.manual_seed(7)
# 自定义数据集
X = torch.rand((7, 2, 2))
target = torch.randint(0, 2, (7,))

定义网络结构

  • 一层全连接层 + Softmax层
  • x1𝑥1,x2𝑥2,x3𝑥3,x4𝑥4为 X
  • o1𝑜1,o2𝑜2,o3𝑜3为 target
class LinearNet(nn.Module):
    def __init__(self):
        super(LinearNet, self).__init__()
        # 定义一层全连接层
        self.dense = nn.Linear(4, 3)
        # 定义Softmax
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        y = self.dense(x.view((-1, 4)))
        y = self.softmax(y)
        return y

net = LinearNet()
  •  nn.Softmax(dim=1)用于计算输入张量在指定维度上的softmax激活。dim=1表示沿着第二个维度(即列)进行softmax操作。

定义损失函数和优化函数

  • torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
  • 衡量模型输出与真实标签的差异,在分类时相当有用。
  • 结合了nn.LogSoftmax()和nn.NLLLoss()两个函数,进行交叉熵计算。
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # 随机梯度下降法

训练模型

for epoch in range(70):
    train_l = 0.0
    y_hat = net(X)
    l = loss(y_hat, target).sum()

    # 梯度清零
    optimizer.zero_grad()
    # 自动求导梯度
    l.backward()
    # 利用优化函数调整所有权重参数
    optimizer.step()

    train_l += l
    print('epoch %d, loss %.4f' % (epoch + 1, train_l))

三、自动微分模块

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False)  :自动求取梯度

  • grad_tensors:多梯度权重
  • create_graph:创建导数计算图,用于高阶求导
  • retain_graph:保存计算图
  • tensors:用于求导的张量,如 loss
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward(retain_graph=True)

 注意点:

  1. 梯度不自动清零
  2. 依赖于叶子节点的节点,requires_grad默认为True
  3. 叶子节点不可执行in-place

神经网络全连接层: 每个神经元都与前一层的所有神经元相连接。全连接层通常用于网络的最后几层,它将之前层(如卷积层和池化层)提取的特征进行整合,以映射到样本标记空间,即最终的分类或回归结果。

关于loss.backward()方法:

主要作用就是计算损失函数对模型参数的梯度,loss.backward()实现了反向传播算法,它通过链式法则计算每个模型参数相对于最终损失的梯度。这个过程从输出层开始,向后传递到输入层,逐层计算梯度。

过程:得到每个参数相对于损失函数的梯度,这些梯度信息会存储在对应张量的.grad属性中。loss.backward本身不负责更细权重,但它为权重更新提供了梯度值,方便配合optimizer.step()来更新参数。

前向传播过程中,数据从输入层流向输出层,并生成预测结果;而在反向传播过程中,误差(即预测值与真实值之间的差距,也就是损失函数的值)会从输出层向输入层传播,逐层计算出每个参数相对于损失函数的梯度。这些梯度指示了如何调整每一层中的权重和偏置,以最小化损失函数。

  • 损失函数衡量了当前模型预测与真实情况之间的不一致程度,而梯度则提供了损失函数减少最快的方向。

建立一个简单的全连接层:

import torch
import torch.nn as nn

# 定义一个简单的全连接层模型
class SimpleFC(nn.Module):
    def __init__(self, input_size, output_size):
        super(SimpleFC, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):  
        return self.fc(x)

# 创建输入数据和目标输出
input_data = torch.tensor([[1.0, 2.0, 3.0]])
target_output = torch.tensor([[4.0, 5.0]])

# 实例化模型、损失函数和优化器
model = SimpleFC(input_size=3, output_size=2)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 前向传播
output = model(input_data)

# 计算损失
loss = criterion(output, target_output)

# 反向传播
loss.backward()

# 更新参数
optimizer.step()

当调用loss.backward()时,PyTorch会自动计算损失值关于模型参数的梯度,并将这些梯度存储在模型参数的.grad属性中。然后优化器(torch.optim.SGD)可以使用这些梯度来更新模型参数,以最小化损失函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/612644.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

EPAI手绘建模APP动画、场景、手势操作

(15) 动画 图 299 动画控制器 ① 打开动画控制器。播放动画过程中,切换场景观察视角时,自动停止播放。动画编辑参见常用工具栏-更多-动画动画编辑器部分。 ② 关闭动画控制器。 ③ 设置动画参数:设置动画总帧数;这只帧率&#x…

从RAID 0到RAID 10:全面解析RAID技术与应用

🐇明明跟你说过:个人主页 🏅个人专栏:《Linux :从菜鸟到飞鸟的逆袭》🏅 🔖行路有良友,便是天堂🔖 目录 一、前言 1、磁盘阵列简介 2、磁盘阵列诞生背景 3、硬件RA…

Spring Boot集成activiti快速入门Demo

1.什么事activiti? Activiti是一个工作流引擎,可以将业务系统中复杂的业务流程抽取出来,使用专门的建模语言BPMN2.0进行定义,业务流程按照预先定义的流程进行执行,实现了系统的流程流activiti进行管理,减少业务系统由于流程变更进行系统升级改造的工作量,从而提高系…

与队列和栈相关的【OJ题】

✨✨✨专栏:数据结构 🧑‍🎓个人主页:SWsunlight 目录 一、用队列实现栈: 1、2个队列的关联起来怎么由先进先出转变为先进后出:(核心) 2、认识各个函数干嘛用的: …

pgbackrest 备份工具使用 postgresql

为啥我会使用pgbackrest进行备份?因为postgresql没有自带的差异备份工具。。。而我们在生产环境上,一般都需要用到差异备份或者增量备份。我们的备份策略基本是,1天1次完整备份,1个小时1次差异备份。如果只需要完整备份&#xff0…

【Mac】Indesign 2023 Mac(ID2023) v18.5中文版安装教程

软件介绍 Adobe InDesign是一款由Adobe Systems开发的桌面排版软件,旨在用于创建、编辑和格式化印刷和数字出版物,如书籍、杂志、报纸、传单等。以下是一些关于Adobe InDesign的主要特点和功能: 1.强大的排版工具:InDesign提供了…

Linux的命令(第二篇)

昨天学习到了第17个命令到 rm 命令(作用删除目录和文件),今天继续往下里面了解其他命令以及格式、选项: (17)wc命令(此wc非wc) 作用:统计行数、单词数、字符分数。 格…

JavaScript使用 BigInt

在 JavaScript 中,最大的安全整数是 2 的 53 次方减 1,即 Number.MAX_SAFE_INTEGER,其值为 9007199254740991。这是因为 JavaScript 中使用双精度浮点数表示数字,双精度浮点数的符号位占 1 位,指数位占 11 位&#xff…

探索计算之美:HTML CSS 计算器案例

本次案例是通过HTML和CSS,我们可以为计算器赋予独特的外观和功能; 在这个计算器中,你将会发现: 简洁清晰的界面设计,使用户能够轻松输入和查看计算结果。利用HTML构建的结构,确保页面具有良好的可访问性和…

gitee 简易使用 上传文件

Wiki - Gitee.com 官方教程 1.gitee 注册帐号 2.下载git 安装 http://git-scm.com/downloads 3. 桌面 鼠标右键 或是开始菜单 open git bash here 输入(复制 ,粘贴) 运行完成后 刷新网页 下方加号即可以添加文件 上传文件 下载 教程…

前端崽的java study笔记

文章目录 basic1、sprint boot概述2、sprint boot入门3、yml 配置信息书写和获取 basic 1、sprint boot概述 sprint boot特性: 起步依赖(maven坐标):解决配置繁琐的问题,只需要引入sprint boot起步依赖的坐标就行 自动…

【敦煌网注册/登录安全分析报告】

敦煌网注册/登录安全分析报告 前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大…

基于STM32移植lvgl(V8.2)(SPI接口的LCD)

目录 概述 1 认识LVGL 1.1 LVGL官网 1.2 LVGL库文件下载 2 认识SPI接口型LCD 2.1 PIN引脚定义 2.2 MCU IO与LCD PIN对应关系 3 实现LCD驱动 3.1 使用STM32Cube配置Project 3.2 STM32Cube生成工程 4 移植LVGL 4.1 准备移植文件 4.2 添加lvgl库文件到项目 4.2.1 src下…

工作中使用Optional过滤出符合条件的数据

工作中使用Optional获取非空对象的属性 实体类Optional对非空对象的处理满足过滤条件返回的值不满足条件返回的值 实体类 package po;import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor;import java.io.Serializable;Data AllArgsConst…

stm32开发三、GPIO

部分引脚可容忍5V,容忍5V的意思是:可以在这个端口输入5V的电压,也认为是高电平 但是对于输出而言,最大就只能输出3.3V,因为供电就只有3.3V 具体哪些端口能容忍5V,可以参考一下STM32的引脚定义 不带FT的,就只…

MobileNet 网络详解

一、了解 网络亮点: 1、DW网络,大大减少运算量和参数数量 2、增加超参数:控制卷积层卷积核个数的超参数 ,控制图像输入大小的超参数 ,这两个超参数是人为设定的,不是机器学习到的。 二、DW卷积&#xff…

湖仓一体 - Apache Arrow的那些事

湖仓一体 - Apache Arrow的那些事 Arrow是高性能列式内存格式标准。它的优势:高效计算:所有列存的通用优势,CPU缓存友好、SIMD向量化计算友好等;零序列化/反序列化:arrow的任何数据结构都是一段连续的内存,…

深入学习指针3

目录 前言 1.二级指针 2.指针数组 3.指针数组模拟二维数组 前言 Hello,小伙伴们我又来了,上期我们讲到了数组名的理解,指针与数组的关系等知识,那今天我们就继续深入到学习指针域数组的练联系,如果喜欢作者菌生产的内容还望不…

### 【数据结构】线性表--顺序表(二)

文章目录 1、什么是线性表2、线性表的基本操作3、顺序表3.1、顺序表的定义3.2、顺序表的实现方式:静态分配3.3、顺序表的实现方式:动态分配3.4、顺序表的特点3.5、顺序表的初始化与插入操作3.6、顺序表的删除与查询 1、什么是线性表 ​ 线性表是具有相同…

MyBatis——使用MyBatis完成CRUD

CRUD&#xff1a;Create Retrieve Update Delete 1、insert <insert id"insertCar">insert into t_car(id,car_num,brand,guide_price,produce_time,car_type)values(null,1003,五菱宏光,30.0,2020-09-18,燃油车); </insert> 这样写显然是写死的&#…