Pytorch 复习总结 3

Pytorch 复习总结,仅供笔者使用,参考教材:

  • 《动手学深度学习》
  • Stanford University: Practical Machine Learning

本文主要内容为:Pytorch 多层感知机。

本文先介绍了多层感知机的用法,再就训练过程中经常出现的过拟合现象提出解决办法。


Pytorch 语法汇总:

  • Pytorch 张量的常见运算、线性代数、高等数学、概率论 部分 见 Pytorch 复习总结1;
  • Pytorch 线性神经网络 部分 见 Pytorch 复习总结2;
  • Pytorch 多层感知机 部分 见 Pytorch 复习总结3;
  • Pytorch 深度学习计算 部分 见 Pytorch 复习总结4;
  • Pytorch 卷积神经网络 部分 见 Pytorch 复习总结5;
  • Pytorch 现代卷积神经网络 部分 见 Pytorch 复习总结6;

目录

  • 一. 多层感知机
    • 1. 读取数据集
    • 2. 神经网络模型
    • 3. 激活函数
    • 4. 损失函数
    • 5. 优化器
    • 6. 训练
  • 二. 过拟合的缓解
    • 1. 权重衰减
    • 2. Dropout

一. 多层感知机

虽然线性模型易于实现和理解、计算成本低、泛化能力强,但是对于一些非线性问题,可能会违反线性模型的单调性。为此,多层感知器引入了隐藏层来克服线性模型的限制,并且加入激活函数以增强网络非线性建模能力。

1. 读取数据集

同 Pytorch 复习总结 2 中 Softmax 回归的数据读取,继续使用 Fashion-MNIST 图像分类数据集:

import torch
import torchvision
from torch.utils import data
from torchvision import transforms

def load_data_fashion_mnist(batch_size, resize=None):
    """下载Fashion-MNIST数据集并将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="./data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="./data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True),
            data.DataLoader(mnist_test, batch_size, shuffle=False))

batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

2. 神经网络模型

先将输入的图像展平,然后使用 2 个全连接层进行处理,中间的全连接层需要使用激活函数激活,最后一层全连接层作为输出:

from torch import nn
net = nn.Sequential(nn.Flatten(),
                    nn.Linear(784, 256),
                    nn.ReLU(),
                    nn.Linear(256, 10)
)

仍然使用 init_weights() 函数按正态分布初始化所有全连接层的权重:

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

3. 激活函数

上一节使用了 ReLU 函数进行激活,在实际应用中,还可以使用 sigmoid、tanh 等函数激活。ReLU、sigmoid、tanh 函数的梯度可视化如下:

import torch
from matplotlib import pyplot as plt

x = torch.arange(-8.0, 8.0, 0.1, requires_grad=True)
# y = torch.relu(x)
# y = torch.sigmoid(x)
y = torch.tanh(x)
y.backward(torch.ones_like(x), retain_graph=True)
plt.figure(figsize=(5, 2.5))
plt.plot(x.detach(), x.grad)
plt.show()

4. 损失函数

同 Softmax 回归:

loss = nn.CrossEntropyLoss(reduction='none')

5. 优化器

同 Softmax 回归:

trainer = torch.optim.SGD(net.parameters(), lr=0.1)

6. 训练

同 Softmax 回归,可以将训练过程封装成函数:

def accuracy(y_hat, y):
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def train_net(net, train_iter, test_iter, loss, num_epochs, trainer):
    for epoch in range(num_epochs):     # 迭代训练轮次
        net.train()                     # 将模型设置为训练模式

        train_loss_sum = 0.0            # 训练损失总和
        train_acc_sum = 0.0             # 训练准确度总和
        sample_num = 0                  # 样本数

        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y)
            trainer.zero_grad()
            l.mean().backward()
            trainer.step()

            train_loss_sum += l.sum()
            train_acc_sum += accuracy(y_hat, y)
            sample_num += y.numel()

        train_loss = train_loss_sum / sample_num
        train_acc = train_acc_sum / sample_num

        net.eval()                      # 将模型设置为评估模式
        test_acc_sum = 0.0
        test_sample_num = 0
        for X, y in test_iter:
            test_acc_sum += accuracy(net(X), y)
            test_sample_num += y.numel()
        test_acc = test_acc_sum / test_sample_num

        print(f'epoch {epoch + 1}, '
            f'train loss {train_loss:.4f}, train acc {train_acc:.4f}, '
            f'test acc {test_acc:.4f}')
    
num_epochs = 10
train_net(net, train_iter, test_iter, loss, num_epochs, trainer)

二. 过拟合的缓解

当模型过于复杂、训练数据太少、迭代轮数太多时,就会出现过拟合现象。解决过拟合的方法有很多:

  • 增加数据量:增加训练数据可以帮助模型更好地学习数据的真实规律,减少过拟合的发生;
  • 简化模型:降低模型的复杂度,可以通过减少模型的参数数量、使用正则化等方法来实现;
  • 交叉验证:使用交叉验证来评估模型的泛化能力,选择最优的模型;
  • 提前停止:即 Dropout,在训练过程中监控模型在验证集上的表现,当验证集误差不再下降甚至开始上升时,及时停止训练,防止模型过拟合;
  • 集成学习:使用集成学习方法(如随机森林、梯度提升树等)降低模型的方差,提高泛化能力。

下面介绍几种常用的正则化方法。

1. 权重衰减

权重衰减 (Weight Decay) 通过向损失函数中添加一个惩罚项来减小模型复杂度,以防止过拟合。惩罚项也叫 正则项,通常是权重的平方和(即 L2 范数)或权重的绝对值和(即 L1 范数)乘以一个正则化系数。

以线性回归的损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 为例,使用优化器训练时,在损失函数 L ( w , b ) L(\mathbf{w}, b) L(w,b) 上添加 L2 范数如下:
L ( w , b ) + λ 2 ∥ w ∥ 2 = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b)+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ =\frac{1}{n} \sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right)^2+\frac{\lambda}{2}\|\mathbf{w}\|^2\\ L(w,b)+2λw2=n1i=1n21(wx(i)+by(i))2+2λw2

损失函数中没有添加偏置 b b b 的惩罚项,因为一般情况下,网络输出层的偏置项不需要正则化。代入 w \mathbf{w} w 的参数更新表达式为:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) \mathbf{w} \leftarrow(1-\eta \lambda) \mathbf{w}-\frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)}\left(\mathbf{w}^{\top} \mathbf{x}^{(i)}+b-y^{(i)}\right) w(1ηλ)wBηiBx(i)(wx(i)+by(i))

要想对模型进行权重衰减,只需要在实例化优化器时通过 weight_decay 指定权重衰减参数。默认情况下,PyTorch 同时衰减权重和偏移:

trainer = torch.optim.SGD(net.parameters(), lr=lr)

如果想要只衰减权重,需要指定参数:

params_to_optimize = [
    {"params": net[0].weight, 'weight_decay': wd},
    {"params":net[0].bias}
]
trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)

2. Dropout

Dropout 通过在训练过程中随机地将网络 内部 的一部分神经元的输出设置为零,即以一定的概率 “丢弃” 这些神经元。这样可以防止神经元在训练过程中过于依赖其他神经元,从而降低了网络对特定神经元的依赖性,使得网络更具鲁棒性:
在这里插入图片描述

通常情况下,Dropout 只在训练过程中使用,不在推理阶段使用,因为推理时模型需要产生确定性的输出。

Dropout 需要在网络中添加 Dropout 层,一般位于激活函数后,并且给定 dropout 概率:

dropout1, dropout2 = 0.2, 0.5

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Dropout(dropout2),
        nn.Linear(256, 10)
)

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

Dropout 概率的设置技巧是靠近输入层的地方设置较低的概率,远离输入层的地方设置较高的概率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/408917.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024.2.23 模拟实现 RabbitMQ —— 实现消费消息逻辑

目录 引言 函数式接口 消费者订阅消息 实现思路 关于消息确认 引言 函数式接口 Lambda 表达式的本质是匿名函数Java 函数无法脱离类而存在,所以 Java 通过引入函数式接口以支持 Lambda 表达式 特性: 函数式接口为一个 interface 类该类中有且仅有一个…

【Python笔记-设计模式】代理模式

一、说明 代理模式是一种结构型设计模式,提供对象的替代品或其占位符。代理控制着对于原对象的访问,并允许在将请求提交给对象前后进行一些处理。 (一) 解决问题 控制对对象的访问,或在访问对象前增加额外的功能或控制访问 (二) 使用场景…

统信UOS系统窗口特效设置

原文链接:统信UOS系统设置窗口特效 在今天的技术分享中,我们将探讨如何在统信UOS系统上充分利用窗口特效来美化和提升用户界面的交互体验。统信UOS作为一款注重视觉体验和用户友好性的操作系统,提供了丰富的窗口特效设置,让用户可…

R语言入门笔记2.6

描述统计 分类数据与顺序数据的图表展示 为了下面代码便于看出颜色参数所对应的值,在这里先集中介绍, col1是黑色,2是粉红,3是绿色,4是天蓝,5是浅蓝,6是紫红,7是黄色,…

Go 利用上下文进行并发计算

关注公众号【爱发白日梦的后端】分享技术干货、读书笔记、开源项目、实战经验、高效开发工具等,您的关注将是我的更新动力! 在Go编程中,上下文(context)是一个非常重要的概念,它包含了与请求相关的信息&…

Bluejay电调固件修改自检音乐、自定义启动音乐旋律

Bluejay电调固件修改自检音乐、自定义启动音乐旋律 Bluejay电调固件基本介绍Bluejay电调固件特点修改自检音乐、启动音乐旋律准备材料修改过程 Bluejay固件旋律音乐格式开头部分音符部分 收集到的音乐代码 Bluejay电调固件基本介绍 Bluejay是一种数字电调固件,用于控…

Stable Diffusion 3 发布及其重大改进

1. 引言 就在 OpenAI 发布可以生成令人瞠目的视频的 Sora 和谷歌披露支持多达 150 万个Token上下文的 Gemini 1.5 的几天后,Stability AI 最近展示了 Stable Diffusion 3 的预览版。 闲话少说,我们快来看看吧! 2. 什么是Stable Diffusion…

运维SRE-08 网络基础与进阶

今日内容 - **定时备份案例进阶.** - **定时巡检(检查系统基础指标),写入到文件中.** - 网络(抽象) 掌握与吸收时间: 直到课程结束.(第2阶段结束) - 网络基础: 网络概述,网络结构,网络设备. - 网络核心: OSI7层模型 ※※※※※※TCP/IP 3次握手 ※※※※※※TCP/IP 4…

Django入门指南:从环境搭建到模型管理系统的完整教程

环境安装: ​ 由于我的C的Anaconda 是安装在C盘的,但是没内存了,所有我将环境转在e盘,下面的命令是创建环境到指定目录中. conda create --prefixE:\envs\dj42 python3.9进入环境中: conda activate E:\envs\dj42…

【并发】CAS原子操作

1. 定义 CAS是Compare And Swap的缩写,直译就是比较并交换。CAS是现代CPU广泛支持的一种对内存中的共享数据进行操作的一种特殊指令,这个指令会对内存中的共享数据做原子的读写操作。其作用是让CPU比较内存中某个值是否和预期的值相同,如果相…

C#与VisionPro联合开发——串口通信

串口通信 串口通信是一种常见的数据传输方式,通过串行接口(串口)将数据以串行比特流的形式进行传输。在计算机和外部设备之间,串口通信通常是通过串行通信标准(如RS-232)来实现的。串口通信可以用于连接各…

AtCoder ABC342 A-D题解

华为出的比赛&#xff1f; 好像是全站首个题解哎&#xff01; 比赛链接:ABC342 Problem A: 稍微有点含金量的签到题。 #include <bits/stdc.h> using namespace std; int main(){string S;cin>>S;for(int i0;i<s.size();i){if(count(S.begin(),S.end(),S[i…

《穿越火线:枪战王者》手游客户端技术方案: 实时同步与手感优化 转载

一、项目背景 CF手游的团队有着相当丰富的FPS游戏制作经验&#xff0c;但是移动端开发经验相对匮乏。团队面对的挑战很大&#xff0c;我们需要在手机端完美还原CF十多个游戏模式&#xff0c;上百把枪械手感。 虽然我们有实时对战FPS游戏开发经验&#xff0c;但是手游网络质量…

H5获取手机相机或相册图片两种方式-Android通过webview传递多张照片给H5

需求目的&#xff1a; 手机机通过webView展示H5网页&#xff0c;在特殊场景下&#xff0c;需要使用相机拍照或者从相册获取照片&#xff0c;上传后台。 完整流程效果&#xff1a; 如下图 一、H5界面样例代码 使用html文件格式&#xff0c;文件直接打开就可以展示布局&#…

从源码学习单例模式

单例模式 单例模式是一种设计模式&#xff0c;常用于确保一个类只有一个实例&#xff0c;并提供一个全局访问点。这意味着无论在程序的哪个地方&#xff0c;只能创建一个该类的实例&#xff0c;而不会出现多个相同实例的情况。 在单例模式中&#xff0c;常用的实现方式包括懒汉…

【C语言】linux内核ipoib模块 - ipoib_send

一、中文注释 int ipoib_send(struct net_device *dev, struct sk_buff *skb,struct ib_ah *address, u32 dqpn) {struct ipoib_dev_priv *priv ipoib_priv(dev); // 获取IPoIB设备的私有数据struct ipoib_tx_buf *tx_req; // 发送请求结构体int hlen, rc; // 分别为头部长度…

安装 WSL 报错 Error code: Wsl/WININET_E_NAME_NOT_RESOLVED 问题解决

问题描述 在执行 wsl --install 安装Windows子系统Linux WSL (Windows Subsystem for Linux) 时报错&#xff1a; 无法从“https://raw.githubusercontent.com/microsoft/WSL/master/distributions/DistributionInfo.json”中提取列表分发。无法解析服务器的名称或地址 Error…

如何在本地电脑部署HadSky论坛并发布至公网可远程访问【内网穿透】

文章目录 前言1. 网站搭建1.1 网页下载和安装1.2 网页测试1.3 cpolar的安装和注册 2. 本地网页发布2.1 Cpolar临时数据隧道2.2 Cpolar稳定隧道&#xff08;云端设置&#xff09;2.3 Cpolar稳定隧道&#xff08;本地设置&#xff09;2.4 公网访问测试 总结 前言 经过多年的基础…

2000-2022年上市公司全要素生产率测算数据合计(原始数据+计算代码+结果)(LP法+OLS法+GMM法+固定效应法)

2000-2022年上市公司全要素生产率测算数据合计&#xff08;原始数据计算代码结果&#xff09;&#xff08;LP法OLS法GMM法固定效应法&#xff09; 1、时间&#xff1a;2000-2022年 2、范围&#xff1a;上市公司 3、指标&#xff1a;证券代码、证券简称、统计截止日期、固定资…

怎么自学python,大概要多久?python多久上手?

无限时长~~~~技术不断在更新&#xff0c;你的自学不也需要一直进行吗&#xff1f; 但如果是问&#xff1a;自学多长时间可以入门&#xff1f;或者可以找到工作&#xff1f;那我可以告诉你答案。 从零基础开始自学Python&#xff0c;依照每个人理解能力的不同&#xff0c;大致…