【漫话机器学习系列】088.常见的输出层激活函数(Common Output Layer Activation Functions)

在神经网络中,输出层(Output Layer) 的激活函数(Activation Function)直接决定了模型的输出形式,并影响损失函数的选择及训练效果。不同的任务类型(如分类或回归)需要使用不同的激活函数,以确保输出结果符合问题要求。

本文将详细介绍 二元分类(Binary Classification)、多分类(Multi-class Classification)和回归(Regression)任务中常用的输出层激活函数,并提供相应的代码示例。


1. 输出层激活函数概述

根据任务类型,输出层常用的激活函数如下:

任务类型常见激活函数输出值范围
二元分类sigmoid(0,1)
多分类softmax(0,1) 且所有类别概率之和为 1
回归无激活函数(线性输出)(-∞, +∞)

接下来,我们分别介绍这些激活函数的作用、数学公式、特点及代码实现。


2. 二元分类:Sigmoid 函数

2.1 介绍

Sigmoid 函数常用于二元分类问题,其数学公式如下:

\sigma(x) = \frac{1}{1 + e^{-x}}

  • 该函数将输出值映射到 (0,1) 之间,适用于概率预测。
  • \sigma(x) > 0.5,通常预测为正类(1),否则为负类(0)

2.2 代码示例

PyTorch 实现二元分类
import torch
import torch.nn as nn
import torch.optim as optim


# 定义二元分类模型
class BinaryClassifier(nn.Module):  # 二元分类模型
    def __init__(self):  # 初始化模型
        super(BinaryClassifier, self).__init__()  # 继承父类
        self.fc = nn.Linear(2, 1)  # 2个输入特征,1个输出(概率)

    def forward(self, x):  # 前向传播
        return torch.sigmoid(self.fc(x))  # Sigmoid 激活函数


# 定义损失函数和优化器
model = BinaryClassifier()  # 创建模型
criterion = nn.BCELoss()  # 二元交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.1)  # 优化器

# 生成示例数据
inputs = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]], dtype=torch.float32)  # 输入特征
labels = torch.tensor([[1.0], [0.0], [1.0]], dtype=torch.float32)  # 标签

# 训练模型
for epoch in range(100):  # 训练100个周期
    optimizer.zero_grad()  # 清空梯度
    outputs = model(inputs)  # 前向传播
    loss = criterion(outputs, labels)  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    if epoch % 10 == 0:  # 每10个周期输出一次损失
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')  # 输出损失

运行结果

Epoch [0/100], Loss: 1.1413
Epoch [10/100], Loss: 0.5956
Epoch [20/100], Loss: 0.5529
Epoch [30/100], Loss: 0.5320
Epoch [40/100], Loss: 0.5160
Epoch [50/100], Loss: 0.5033
Epoch [60/100], Loss: 0.4927
Epoch [70/100], Loss: 0.4835
Epoch [80/100], Loss: 0.4753
Epoch [90/100], Loss: 0.4679

 

2.3 适用场景

  • 垃圾邮件检测(垃圾邮件 vs. 非垃圾邮件)
  • 疾病预测(有病 vs. 无病)
  • 金融欺诈检测(欺诈 vs. 正常交易)

3. 多分类:Softmax 函数

3.1 介绍

Softmax 函数用于多分类任务,其数学公式如下:

\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

  • 特点
    • 归一化所有输出值,使其总和等于 1(形成概率分布)。
    • 适用于独占类别(Mutually Exclusive Classes)的分类任务。

3.2 代码示例

PyTorch 实现多分类
import torch
import torch.nn as nn
import torch.optim as optim


class MultiClassClassifier(nn.Module):  # 多分类模型
    def __init__(self, input_dim, num_classes):  # 输入维度和类别数
        super(MultiClassClassifier, self).__init__()  # 继承父类初始化方法
        self.fc = nn.Linear(input_dim, num_classes)  # 线性层
        self.fc.weight.data.normal_(0, 0.01)  # 初始化权重

    def forward(self, x):  # 前向传播
        return torch.softmax(self.fc(x), dim=1)  # Softmax 归一化


# 定义模型、损失函数和优化器
model = MultiClassClassifier(input_dim=3, num_classes=4)  # 创建模型
criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.01)  # Adam 优化器

# 生成示例数据(假设有3个特征,4个类别)
inputs = torch.tensor([[1.0, 2.0, 3.0], [2.0, 1.0, 3.0]], dtype=torch.float32)  # 输入特征
labels = torch.tensor([1, 3])  # 类别索引(0-3)

# 训练模型
for epoch in range(100):  # 训练100个周期
    optimizer.zero_grad()  # 清空梯度
    outputs = model(inputs)  # 前向传播
    loss = criterion(outputs, labels)  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    if epoch % 10 == 0:  # 每10个周期打印一次损失
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')  # 打印损失

运行结果

Epoch [0/100], Loss: 1.3481
Epoch [10/100], Loss: 1.2146
Epoch [20/100], Loss: 1.1262
Epoch [30/100], Loss: 1.0629
Epoch [40/100], Loss: 1.0134
Epoch [50/100], Loss: 0.9731
Epoch [60/100], Loss: 0.9401
Epoch [70/100], Loss: 0.9131
Epoch [80/100], Loss: 0.8909
Epoch [90/100], Loss: 0.8727

3.3 适用场景

  • 手写数字识别(MNIST,类别 0-9)
  • 图像分类(狗、猫、鸟等)
  • 情感分析(积极、中性、消极)

4. 回归:无激活函数(线性输出)

4.1 介绍

回归任务中,目标是预测连续数值,因此输出层通常不使用激活函数,保持线性输出:

y = w \cdot x + b

  • 预测值范围为(-∞, +∞),不受激活函数限制。

4.2 代码示例

PyTorch 实现回归
import torch
import torch.nn as nn
import torch.optim as optim


class RegressionModel(nn.Module):  # 线性回归模型
    def __init__(self):  # 初始化模型
        super(RegressionModel, self).__init__()  # 继承父类
        self.fc = nn.Linear(1, 1)  # 单输入单输出

    def forward(self, x):  # 前向传播
        return self.fc(x)  # 线性输出,无激活函数


# 定义模型、损失函数和优化器
model = RegressionModel()  # 创建模型
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器

# 生成示例数据
inputs = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float32)  # 输入 x
targets = torch.tensor([[2.0], [4.0], [6.0]], dtype=torch.float32)  # 目标 y = 2x

# 训练模型
for epoch in range(100):  # 训练100个周期
    optimizer.zero_grad()  # 梯度清零
    outputs = model(inputs)  # 前向传播
    loss = criterion(outputs, targets)  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    if epoch % 10 == 0:  # 每10个周期打印一次损失
        print(f'Epoch [{epoch}/100], Loss: {loss.item():.4f}')  # 输出损失

运行结果 

Epoch [0/100], Loss: 18.9413
Epoch [10/100], Loss: 1.8053
Epoch [20/100], Loss: 0.1735
Epoch [30/100], Loss: 0.0181
Epoch [40/100], Loss: 0.0032
Epoch [50/100], Loss: 0.0017
Epoch [60/100], Loss: 0.0015
Epoch [70/100], Loss: 0.0014
Epoch [80/100], Loss: 0.0013
Epoch [90/100], Loss: 0.0013

4.3 适用场景

  • 房价预测
  • 股票价格预测
  • 销量预测

5. 结论

任务类型激活函数适用损失函数
二元分类Sigmoid二元交叉熵(BCELoss)
多分类Softmax交叉熵损失(CrossEntropyLoss)
回归无激活函数均方误差(MSELoss)

不同任务类型需要选择合适的激活函数,以确保神经网络能够正确学习和预测目标值。希望本文能帮助你更好地理解输出层激活函数的使用! 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/968573.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在实体机和wsl2中安装docker、使用GPU

正常使用docker和gpu,直接命令行安装dcoker和,nvidia-container-toolkit。区别在于,后者在于安装驱动已经cuda加速时存在系统上的差异。 1、安装gpu驱动 在实体机中,安装cuda加速包,我们直接安装 driver 和 cuda 即可…

麒麟v10 server版安装ollama跑Deepseek

麒麟v10 server版安装ollama跑Deepseek 1. 环境 2. 安装docker yum install docker 发现源只有18.x版本,启动ollama,发现调用CPU,没调用GPU docker19.x以上才 会调用GPU, 可以添加centos8的原,安装docker-ce3.启动ollama&#…

LabVIEW用户界面(UI)和用户体验(UX)设计

作为一名 LabVIEW 开发者,满足功能需求、保障使用便捷与灵活只是基础要求。在如今这个用户体验至上的时代,为 LabVIEW 应用程序设计直观且具有美学感的界面,同样是不容忽视的关键任务。一个优秀的界面设计,不仅能提升用户对程序的…

如何使用Java语言在Idea和Android中分别建立服务端和客户端实现局域网聊天

手把手教你用Java语言在Idea和Android中分别建立服务端和客户端实现局域网聊天 目录 文章目录 手把手教你用**Java**语言在**Idea**和**Android**中分别建立**服务端**和**客户端**实现局域网聊天**目录**[toc]**基本实现****问题分析****服务端**Idea:结构预览Server类代码解…

【实战篇】DeepSeek + ElevenLabs:让人工智能“开口说话”,打造你的专属语音助手!

最近,AI语音合成技术真是火得不行,各种“开口脆”的AI声音层出不穷,听得我直呼“这也太像真人了吧!” 作为一个科技爱好者,我当然不能错过这股潮流,这不,最近就沉迷于用 DeepSeek 和 ElevenLabs 这两款神器,捣鼓各种人声音频,简直停不下来! 先来科普一下这两位“主角…

vscode/cursor+godot C#中使用socketIO

在 Visual Studio Code(VS Code)中安装 NuGet 包(例如SocketIOClient),你可以通过以下几种方法: 方法 1:使用dotnet cli 打开终端:在 VS Code 中按下Ctrl 或者通过菜单View -> Terminal打开终端。 导…

ETL的使用(sqoop):数据导入,导出

ETL ETL: 是数据抽取(Extract)、数据转换(Transform)和数据加载(Load)的整个过程 常用的ETL工具 sqoop 1.Apache Sqoop 是 Apache 软件基金会旗下的一个开源项目,旨在帮助用户高效地在 Hado…

安卓手游内存call综合工具/内部call/安卓注入call/数据分析(类人猿学院)

进程分析注入综合工具总界面 模块分析函数分析遍历 函数分析 so汇编分析 汇编call植入器,支持模拟器x86 x64 和手机arm64指令全平台 防ce搜索数据功能 全国首套发布,阿凡老师学院最好的安卓内存逆向老师,几乎行业最强的,有兴趣可以…

Transformer 详解:了解 GPT、BERT 和 T5 背后的模型

目录 什么是 Transformer? Transformer如何工作? Transformer 为何有用? 常见问题解答:机器学习中的 Transformer 在技​​术领域,突破通常来自于修复损坏的东西。制造第一架飞机的人研究过鸟类。莱特兄弟观察了秃鹫如何在气流中保持平衡,意识到稳定性比动力更重要。…

在C++的DLL文件中定义的结构体,在DLL外可以使用吗,如何使用?

在C的DLL文件中定义的结构体,在DLL外可以使用,但需要正确处理类型匹配和内存管理。‌ 在DLL外使用DLL中定义的结构体的方法: 一、 ‌使用extern "C"声明‌: 在DLL的导出声明中使用extern "C"可以确保函数和…

PDF另存为图片的一个方法

说明 有时需要把PDF的每一页另存为图片。用Devexpress可以很方便的完成这个功能。 窗体上放置一个PdfViewer。 然后循环每一页 for (int i 1; i < pdfViewer1.PageCount; i) 调用 chg_pdf_to_bmp函数获得图片并保存 chg_pdf_to_bmp中调用了PdfViewer的CreateBitmap函数…

中间件-安装Minio-集成使用(ubantu-docker)

目录 1、安装docer 2、运行以下命令拉取MinIO的Docker镜像 3、检查当前所有Docker下载的镜像 4、创建目录 5、创建Minio容器并运行 6、SDK操作 FileUploader.java 1、安装docer 参考这篇&#xff1a;Linux安装Docker 2、运行以下命令拉取MinIO的Docker镜像 docker pull…

ffmpeg学习:ubuntu下编译Android版ffmpeg-kit

文章目录 前言一. 配置环境1.1 虚拟机版本1.2 安装Android环境1.2.1 Android SDK安装1.2.2 Android NDK安装 1.3 编译前的准备工作1.3.1 libtasn1-1安装1.3.2 meson安装1.3.3 harfbuzz下载 二. 编译ffmpeg-kit三. 总结 前言 ffmpeg-kit是一款跨多个平台的&#xff0c;用于在应…

【全球人口数据集】全球人口密度数据集GPWv4

目录 数据概述数据处理方法数据下载参考GPWv4: Population Density, Revision 11 是由 NASA Socioeconomic Data and Applications Center (SEDAC) 提供的全球人口密度数据集,旨在支持社会经济和环境研究。 数据概述 Gridded Population of the World, Version 4 (GPWv4): Po…

PyTorch 中 `torch.cuda.amp` 相关警告的解决方法

在最近的写代码过程中&#xff0c;遇到了两个与 PyTorch 的混合精度训练相关的警告信息。这里随手记录一下。 警告内容 警告 1: torch.cuda.amp.autocast FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast(cuda, args...) i…

NLP面试之-激活函数

一、动机篇 1.1 为什么要有激活函数&#xff1f; 数据角度&#xff1a;由于数据是线性不可分的&#xff0c;如果采用线性化&#xff0c;那么需要复杂的线性组合去逼近问题&#xff0c;因此需要非线性变换对数据分布进行重新映射;线性模型的表达力问题&#xff1a;由于线性模型…

四、自然语言处理_08Transformer翻译任务案例

0、前言 在Seq2Seq模型的学习过程中&#xff0c;做过一个文本翻译任务案例&#xff0c;多轮训练后&#xff0c;效果还算能看 Transformer作为NLP领域的扛把子&#xff0c;对于此类任务的处理会更为强大&#xff0c;下面将以基于Transformer模型来重新处理此任务&#xff0c;看…

关于conda换镜像源,pip换源

目录 1. 查看当前下载源2. 添加镜像源2.1清华大学开源软件镜像站2.2上海交通大学开源镜像站2.3中国科学技术大学 3.删除镜像源4.删除所有镜像源&#xff0c;恢复默认5.什么是conda-forge6.pip换源 1. 查看当前下载源 conda config --show channels 如果发现多个 可以只保留1个…

因果机器学习(CausalML)前沿创新思路

结合了传统因果推断与机器学习的因果机器学习是目前AI领域的前沿研究方向&#xff0c;其核心优势在于将因果逻辑融入数据驱动模型&#xff0c;从根本上解决了传统方法的缺陷。因此&#xff0c;它也是突破传统机器学习瓶颈的关键方向&#xff0c;不仅当下热度高&#xff0c;在未…

网络防御高级02-综合实验

web页面&#xff1a; [FW]interface GigabitEthernet 0/0/0 [FW-GigabitEthernet0/0/0]service-manage all permit 需求一&#xff0c;接口配置&#xff1a; SW2: [Huawei]sysname SW2 1.创建vlan [sw2]vlan 10 [sw2]vlan 20 2.接口配置 [sw2]interface GigabitEther…