机器学习——多层感知机

感知机

缺点:只能处理线性问题,感知机无法解决异或问题
在这里插入图片描述
在这里偏置就像线性模型的常数项,加入偏置模型的表达能力增强,而激活函数就像示性函数,可以模拟神经元的兴奋和抑制,当大于等于0就输出1。
在这里插入图片描述

多层感知机MLP

前馈网络的层数是指权重的层数
多个单层感知机按前馈结构,前馈结构就是层只与相邻层连接,不跨越连接,就是多层感知机

激活函数
逻辑斯蒂函数
双曲正切函数
线性整流单元ReLU

一般让所有的隐含层的激活函数相同,输出层的激活函数需根据任务的需求选择,二分类可以选择逻辑斯蒂回归,多分类用softmax函数
MLP相比单层感知机的表达能力提升,关键在于非线性激活函数
可以证明任意一个R上的连续函数都可以由MLP来拟合,而对其非线性的激活函数的形式要求很少,也称作普适逼近定理
非线性对提升模型的表达能力很重要,其实因为非线性变换相当于提升了数据的维度,维度提升的好处就在于低维数据不可分的问题可以在高维中可分

import torch # PyTorch库
import torch.nn as nn # PyTorch中与神经网络相关的工具
from torch.nn.init import normal_ # 正态分布初始化

torch_activation_dict = {
    'identity': lambda x: x,
    'sigmoid': torch.sigmoid,
    'tanh': torch.tanh,
    'relu': torch.relu
}

# 定义MLP类,基于PyTorch的自定义模块通常都继承nn.Module
# 继承后,只需要实现forward函数,进行前向传播
# 反向传播与梯度计算均由PyTorch自动完成
class MLP_torch(nn.Module):

    def __init__(
        self, 
        layer_sizes, # 包含每层大小的list
        use_bias=True, 
        activation='relu',
        out_activation='identity'
    ):
        super().__init__() # 初始化父类
        self.activation = torch_activation_dict[activation]
        self.out_activation = torch_activation_dict[out_activation]
        self.layers = nn.ModuleList() # ModuleList以列表方式存储PyTorch模块
        num_in = layer_sizes[0]
        for num_out in layer_sizes[1:]:
            # 创建全连接层
            self.layers.append(nn.Linear(num_in, num_out, bias=use_bias))
            # 正态分布初始化,采用与前面手动实现时相同的方式
            normal_(self.layers[-1].weight, std=1.0)
            # 偏置项为全0
            self.layers[-1].bias.data.fill_(0.0)
            num_in = num_out

    def forward(self, x):
        # 前向传播
        # PyTorch可以自行处理batch_size等维度问题
        # 我们只需要让输入依次通过每一层即可
        for i in range(len(self.layers) - 1):
            x = self.layers[i](x)
            x = self.activation(x)
        # 输出层
        x = self.layers[-1](x)
        x = self.out_activation(x)
        return x
#%%
# 设置超参数
num_epochs = 1000
learning_rate = 0.1
batch_size = 128
eps = 1e-7
torch.manual_seed(0)

# 初始化MLP模型
mlp = MLP_torch(layer_sizes=[2, 4, 1], use_bias=True, 
    out_activation='sigmoid')

# 定义SGD优化器
opt = torch.optim.SGD(mlp.parameters(), lr=learning_rate)

# 训练过程
losses = []
test_losses = []
test_accs = []
for epoch in range(num_epochs):
    st = 0
    loss = []
    while True:
        ed = min(st + batch_size, len(x_train))
        if st >= ed:
            break
        # 取出batch,转为张量
        x = torch.tensor(x_train[st: ed], 
            dtype=torch.float32)
        y = torch.tensor(y_train[st: ed], 
            dtype=torch.float32).reshape(-1, 1)
        # 计算MLP的预测
        # 调用模型时,PyTorch会自动调用模型的forward方法
        # y_pred的维度为(batch_size, layer_sizes[-1])
        y_pred = mlp(x)
        # 计算交叉熵损失
        train_loss = torch.mean(-y * torch.log(y_pred + eps) \
            - (1 - y) * torch.log(1 - y_pred + eps))
        # 清空梯度
        opt.zero_grad()
        # 反向传播
        train_loss.backward()
        # 更新参数
        opt.step()

        # 记录累加损失,需要先将损失从张量转为numpy格式
        loss.append(train_loss.detach().numpy())
        st += batch_size

    losses.append(np.mean(loss))
    # 计算测试集上的交叉熵
    # 在不需要梯度的部分,可以用torch.inference_mode()加速计算
    with torch.inference_mode():
        x = torch.tensor(x_test, dtype=torch.float32)
        y = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)
        y_pred = mlp(x)
        test_loss = torch.sum(-y * torch.log(y_pred + eps) \
            - (1 - y) * torch.log(1 - y_pred + eps)) / len(x_test)
        test_acc = torch.sum(torch.round(y_pred) == y) / len(x_test)
        test_losses.append(test_loss.detach().numpy())
        test_accs.append(test_acc.detach().numpy())

print('测试精度:', test_accs[-1])
# 将损失变化进行可视化
plt.figure(figsize=(16, 6))
plt.subplot(121)
plt.plot(losses, color='blue', label='train loss')
plt.plot(test_losses, color='red', ls='--', label='test loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Cross-Entropy Loss')
plt.legend()

plt.subplot(122)
plt.plot(test_accs, color='red')
plt.ylim(top=1.0)
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.title('Test Accuracy')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/688004.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从军事角度理解“战略与战术”

战略与战术,均源于军事术语。 战略(Strategy),源自希腊语词汇“strategos(将军)”和“strategia(军事指挥部,即将军的办公室和技能)”。指的是指挥全局性作战规划的谋略…

网络基础_02

1.ARP协议 地址解析协议(Address Resolution Protocol) 已知对方的三层ip地址,需要二层mac地址 当一台设备(请求方)需要知道某个 IP 地址对应的 MAC 地址时,会使用 ARP封装一个数据帧。这台设备的网络层以…

[职场] 为什么不能加薪? #学习方法#知识分享#微信

为什么不能加薪? 不能加薪的根本原因,终于被我找到了! 朋友们!职场这个地方是个很神奇的世界,有些规则并不是你想象的那样。我们都希望能在这个世界里施展自己的才华,获得升职加薪的荣耀。然而&#xff0c…

Blender 学习笔记(二)游标与原点

1. 游标 游标是界面中的红色圆圈: 1.1 移动游标 我们可以通过点击工具栏中的游标按钮,来移动游标,或者通过快捷键 shift右键 移动。若想要重置复游标位置,可以用 shiftc 恢复,或则通过 shifts 点击 游标->世界原…

Python中的Paramiko与FTP文件夹及文件检测技巧

哈喽,大家好,我是木头左! Python代码的魅力与实用价值 在当今数字化时代,编程已成为一种不可或缺的技能。Python作为一种简洁、易读且功能强大的编程语言,受到了全球开发者的喜爱。它不仅适用于初学者入门&#xff0c…

Java 数据库连接(JDBC)的使用,包括连接数据库、执行SQL语句等

一、简介 Java Database Connectivity(JDBC)是Java应用程序与关系数据库进行交互的一种API。它提供了一组用于访问和操作数据库的标准接口,使开发人员能够使用Java代码执行数据库操作,如查询、插入、更新和删除等。 二、JDBC架构…

C++ AVL树 详细讲解

目录 一、AVL树的概念 二、AVL树的实现 1.AVL树节点的定义 2.AVL树的插入 3.AVL树的旋转 4.AVL树的验证 三、AVL树的性能 四、完结撒❀ 一、AVL树的概念 二叉搜索树虽可以缩短查找的效率,但 如果数据有序或接近有序二叉搜索树将退化为单支树,查 …

SQL进阶day11——窗口函数

目录 1专用窗口函数 1.1 每类试卷得分前3名 1.2第二快/慢用时之差大于试卷时长一半的试卷 1.3连续两次作答试卷的最大时间窗 1.4近三个月未完成试卷数为0的用户完成情况 1.5未完成率较高的50%用户近三个月答卷情况 2聚合窗口函数 2.1 对试卷得分做min-max归一化 2.2每份…

m3u8视频怎么打开?教你三招!

m3u8 是一种文本文件格式,用于创建媒体播放列表,现在大部分的视频流媒体都是m3u8格式。当我们从网上下载下来m3u8文件的时候会发现,它本身不是一段视频,而是一个索引纯文本文件。想要正常打开播放m3u8视频其实也很简单&#xff0c…

解码 LangChain | LangChain + GPTCache =兼具低成本与高性能的 LLM

LangChain 联合创始人 Harrison Chase 提到,多跳问题会给语义检索带来挑战,并提出可以试用 AI 代理工具解决。不过,频繁调用 LLM 会导致出现使用成本高昂的问题。 对此,Zilliz 软件工程师 Filip Haltmayer 指出,将 GP…

塑造财务规划团队的未来角色

随着企业不断改革,其财务规划团队的格局也在不断变化,领先行业的专业人士已经开始利用更创新的财务知识和思维来驾驭现代化财务规划角色的复杂性。财务团队需要不同的职能角色和技能组合来支持其发展,多学科团队和跨职能协作带来的挑战和机遇…

C语言详解(动态内存管理)1

Hi~!这里是奋斗的小羊,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 💥💥个人主页:奋斗的小羊 💥💥所属专栏:C语言 🚀本系列文章为个人学习…

MacBook Pro上高cpu上不断重启运行的efilogin-helper

高 cpu 运行这个不知道干什么的进程,让风扇疯狂输出,让人甚是烦躁,苹果社区里的回答比较抽象,要么换设备,要么重装。 尝试过找到这个文件,删了部分内容,无果。。。 stack overflow 有个回答&a…

stm32驱动直流电机实现启动/加速/减速/倒车/停车等功能

不积跬步,无以至千里;不积小流,无以成江海。 大家好,我是闲鹤,公众号 xxh_zone,十多年开发、架构经验,先后在华为、迅雷服役过,也在高校从事教学3年;目前已创业了7年多&a…

UML交互图-协作图

概述 协作图和序列图都表示出了对象间的交互作用,但是它们侧重点不同。序列图清楚地表示了交互作用中的时间顺序,但没有明确表示对象间的关系。协作图则清楚地表示了对象间的关系,但时间顺序必须从顺序号获得。序列图常常用于表示方案&#…

使用紫铜管制作半波天线的折合振子

一、概述 半波天线是一种简单而有效的天线类型,其长度约为工作波长的一半。它具有较好的辐射特性和较高的增益,广泛应用于业余无线电、电视接收等领域。使用紫铜管制作折合振子,不仅可以提高天线的机械强度,还能增强其导电性能。 …

k8s——Pod容器中的存储方式及PV、PVC

一、Pod容器中的存储方式 需要存储方式前提:容器磁盘上的文件的生命周期是短暂的,这就使得在容器中运行重要应用时会出现一些问题。 首先,当容器崩溃时,kubelet 会重启它,但是容器中的文件将丢失——容器以干净的状态&…

如何让tracert命令的显示信息显示*星号

tracert命令如果在中间某一个节点超时,只会在显示信息中标识此节点信息超时“ * * * ”,不影响整个tracert命令操作。 如上图所示,在DeviceA上执行tracert 10.1.2.2命令,缺省情况下,DeviceA上的显示信息为:…

Xamarin.Android实现通知推送功能(1)

目录 1、背景说明1.1 开发环境1.2 实现效果1.2.1 推送的界面1.2.2 推送的设置1.2.3 推送的功能实现1.2.3.1、Activity的设置【重要】1.2.3.2、代码的实现 2、源码下载3、总结4、参考资料 1、背景说明 在App开发中,通知(或消息)的推送&#x…

k8s小型实验模拟

(1)Kubernetes 区域可采用 Kubeadm 方式进行安装。(5分) (2)要求在 Kubernetes 环境中,通过yaml文件的方式,创建2个Nginx Pod分别放置在两个不同的节点上,Pod使用hostPat…