深度学习基础3

目录

1.过拟合与欠拟合

1.1 过拟合

1.2 欠拟合

1.2 解决欠拟合

1.2.1 L2正则化

1.2.2 L1正则化

1.2.3 Dropout

1.2.4 简化模型

1.2.5 数据增强

1.2.6 早停

1.2.7 模型集成

1.2.8 交叉验证

2.批量标准化

2.1 实现过程

2.1.1 计算均值和方差

2.1.2 标准化

2.1.3 缩放和平移

2.1.4 标准化公式

2.2 训练和推理阶段

2.3 BatchNorm


1.过拟合与欠拟合

在训练深层神经网络时,由于模型参数较多,在数据量不足时很容易过拟合。而正则化技术主要就是用于防止过拟合,提升模型的泛化能力(对新数据表现良好)和鲁棒性(对异常数据表现良好)。

1.1 过拟合

过拟合是指模型对训练数据拟合能力很强、表现很好,但在测试数据上表现较差。

过拟合常见原因有:

  • 数据量不足:当训练数据较少时,模型可能会过度学习数据中的噪声和细节。

  • 模型太复杂:如果模型很复杂,也会过度学习训练数据中的细节和噪声。

  • 正则化强度不足:如果正则化强度不足,可能会导致模型过度学习训练数据中的细节和噪声。

1.2 欠拟合

欠拟合是由于模型学习能力不足,无法充分捕捉数据中的复杂关系。

1.2 解决欠拟合

欠拟合的解决思路比较直接:

  • 增加模型复杂度:引入更多的参数、增加神经网络的层数或节点数量,使模型能够捕捉到数据中的复杂模式。

  • 增加特征:通过特征工程添加更多有意义的特征,使模型能够更好地理解数据。

  • 减少正则化强度:适当减小 L1、L2 正则化强度,允许模型有更多自由度来拟合数据。

  • 训练更长时间:如果是因为训练不足导致的欠拟合,可以增加训练的轮数或时间.

1.2.1 L2正则化

L2 正则化通过在损失函数中添加权重参数的平方和来实现,目标是惩罚过大的参数值。

数学表示:

设损失函数为 L(\theta),其中 \theta 表示权重参数,加入L2正则化后的损失函数表示为:

其中:

  • L(\theta) 是原始损失函数(比如均方误差、交叉熵等)。

  • \lambda 是正则化强度,控制正则化的力度。

  • \theta_i 是模型的第 i 个权重参数。

  • \frac{1}{2} \sum_{i} \theta_i^2 是所有权重参数的平方和,称为 L2 正则化项。

L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。

梯度更新:

在 L2 正则化下,梯度更新时,不仅要考虑原始损失函数的梯度,还要考虑正则化项的影响。更新公式:

其中:

  • \eta 是学习率。

  • \nabla L(\theta_t) 是损失函数关于参数 \theta_t 的梯度。

  • \lambda \theta_t 是 L2 正则化项的梯度,对应的是参数值本身的衰减。

很明显,参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。

API:

optimizer = optim.SGD(model.parameters(), lr, weight_decay)

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 输入层
        self.linear1 = nn.Linear(2,2)
        self.linear1.weight.data=torch.tensor([[0.15,0.20],
                                               [0.25,0.30]])
        self.linear1.bias.data = torch.tensor([0.35], dtype=torch.float32)

        # 输出层
        self.linear2 = nn.Linear(2,2)
        self.linear2.weight.data=torch.tensor([[0.40,0.45],
                                               [0.50,0.55]])
        self.linear2.bias.data = torch.tensor([0.60], dtype=torch.float32)
        self.activation = nn.Sigmoid()


    def forward(self,input):
        x = self.linear1(input)
        x = self.activation(x)
        x = self.linear2(x)
        output = self.activation(x)
        return output
    
        
def backward():
    model = Net()

    optimizer = optim.SGD(model.parameters(),lr=0.1,weight_decay = 0.01)

    for epoch in range(100):
          
        input = torch.tensor([0.05,0.10])
        true  = torch.tensor([0.01,0.99])
        predict = model.forward(input)
        mse = nn.MSELoss()

        loss = mse(predict,true)
        print(f"{epoch}",loss)
        
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
    
backward()

1.2.2 L1正则化

L1 正则化通过在损失函数中添加权重参数的绝对值之和来约束模型的复杂度。

数学表示:

设模型的原始损失函数为 L(\theta),其中 \theta 表示模型权重参数,则加入 L1 正则化后的损失函数表示为:

其中:

  • L(\theta) 是原始损失函数。

  • \lambda 是正则化强度,控制正则化的力度。

  • |\theta_i| 是模型第i 个参数的绝对值。

  • \sum_{i} |\theta_i| 是所有权重参数的绝对值之和,这个项即为 L1 正则化项。

梯度更新:

在 L1 正则化下,梯度更新公式:

其中:

  • \eta 是学习率。

  • \nabla L(\theta_t) 是损失函数关于参数 \theta_t 的梯度。

  • \text{sign}(\theta_t) 是参数 \theta_t 的符号函数,表示当 \theta_t 为正时取值为 1,为负时取值为 -1,等于 0 时为 0。

因为 L1 正则化依赖于参数的绝对值,其梯度更新时不是简单的线性缩小,而是通过符号函数来直接调整参数的方向。

L1与L2对比:

  • L1 正则化 更适合用于产生稀疏模型,会让部分权重完全为零,适合做特征选择。

  • L2 正则化 更适合平滑模型的参数,避免过大参数,但不会使权重变为零,适合处理高维特征较为密集的场景。

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(5,1)

loss_fun = torch.nn.MSELoss()

optimizer = optim.SGD(model.parameters(),lr = 0.01)
for epoch in range(100):
    train_x = torch.tensor([[0.40,0.45,0.84,0.54,0.2],
                            [0.50,0.55,0.92,0.34,0.6]])

    predict = model(train_x)
    target = torch.tensor([[0.5],[0.8]])

    # L1正则化项并将其加入到总损失中
    l1_lambda = 0.001
    l1_norm = sum(p.abs().sum() for p in model.parameters())
    loss  = loss_fun(predict,target) + l1_lambda*l1_norm

    print(loss)
    if model.weight.grad is not None:
        model.weight.grad.zero_()
    loss.backward()
    optimizer.step()

1.2.3 Dropout

Dropout 是一种在训练过程中随机丢弃部分神经元的技术。它通过减少神经元之间的依赖来防止模型过于复杂,从而避免过拟合。

nn.Dropout(p)

 参数:p:每一个神经元被丢弃的概率

x = torch.randint(0, 10, (5, 6), dtype=torch.float)

# 每一个神经元有p的概率被丢弃
dropout = nn.Dropout(p=0.5)

x = dropout(x)
print(x)
print(x.shape)
print(x==0)
# p不一定等于 死亡神经元占总神经元的比例
print(sum(sum(x==0))/(x.shape[0]*x.shape[1]))

1.2.4 简化模型

  • 减少网络层数和参数: 通过减少网络的层数、每层的神经元数量或减少卷积层的滤波器数量,可以降低模型的复杂度,减少过拟合的风险。

  • 使用更简单的模型: 对于复杂问题,使用更简单的模型或较小的网络架构可以减少参数数量,从而降低过拟合的可能性。

1.2.5 数据增强

通过对训练数据进行各种变换(如旋转、裁剪、翻转、缩放等),可以增加数据的多样性,提高模型的泛化能力。

1.2.6 早停

一种在训练过程中监控模型在验证集上的表现,并在验证误差不再改善时停止训练的技术。这样可避免训练过度,防止模型过拟合。

1.2.7 模型集成

通过将多个不同模型的预测结果进行集成,可以减少单个模型过拟合的风险。常见的集成方法包括投票法、平均法和堆叠法。

1.2.8 交叉验证

使用交叉验证技术可以帮助评估模型的泛化能力,并调整模型超参数,以防止模型在训练数据上过拟合。

2.批量标准化

2.1 实现过程

批量标准化的基本思路是在每一层的输入上执行标准化操作,并学习两个可训练的参数:缩放因子 \lambda 偏移量 \beta。

2.1.1 计算均值和方差

对于给定的神经网络层,假设输入数据为 \mathbf{x} = {x_1, x_2, \ldots, x_m},其中 m是批次大小。首先计算该批次数据的均值和方差。

2.1.2 标准化

使用计算得到的均值和方差对数据进行标准化,使得每个特征的均值为0,方差为1。

2.1.3 缩放和平移

标准化后的数据通常会通过可训练的参数进行缩放和平移,以恢复模型的表达能力。

 

2.1.4 标准化公式

其中,\gamma 和 \beta 是在训练过程中学习到的参数。

  • λ 和 β 是可学习的参数,它相当于对标准化后的值做了一个线性变换,λ 为系数,β 为偏置;

  • \epsilon通常指为 1e-5,避免分母为 0;

  • \mu_B 表示变量的均值;

  • \sigma_B^2 表示变量的方差;

2.2 训练和推理阶段

  • 训练阶段: 在训练过程中,均值和方差是基于当前批次的数据计算得到的。

  • 推理阶段: 在推理阶段,批量标准化使用的是训练过程中计算得到的全局均值和方差,而不是当前批次的数据。这些全局均值和方差通常会被保存在模型中,用于推理时的标准化过程。

2.3 BatchNorm

数据在经过 BN 层之后,无论数据以前的分布是什么,都会被归一化成均值为 β,标准差为 γ 的分布。

注意:BN 层不会改变输入数据的维度,只改变输入数据的的分布. 在实际使用过程中,BN 常常和卷积神经网络结合使用,卷积层的输出结果后接 BN 层。

API:

torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

 参数:

  • 由于每次使用的 mini batch 的数据集,所以 BN 使用移动加权平均来近似计算均值和方差,而 momentum 参数则调节移动加权平均值的计算;

  • affine = False 表示 \lambda=1,β=0,反之,则表示 γ 和 β 要进行学习;

  • BatchNorm2d 适用于输入的数据为 4D,输入数据的形状 [N,C,H,W]

    其中:N 表示批次,C 代表通道数,H 代表高度,W 代表宽度

由于每次输入到网络中的时小批量的样本,我们使用指数加权平均来近似表示整体的样本的均值和方差,其更新公式如下:

running_mean = momentum * running_mean + (1.0 – momentum) * batch_mean
running_var = momentum * running_var + (1.0 – momentum) * batch_var

batch_mean 和 batch_var 表示当前批次的均值和方差。而 running_mean 和 running_var 是近似的整体的均值和方差的表示。当我们进行评估时,可以使用该均值和方差对输入数据进行归一化。

x = torch.randint(0,10,(4,3,4,5)).float()
# 批量标准化
bn = nn.BatchNorm2d(num_features = x.shape[1],eps =1e-8,affine =True,momentum=0.9)
print(bn(x))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/925287.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Scala习题

姓名,语文,数学,英语 张伟,87,92,88 李娜,90,85,95 王强,78,90,82 赵敏,92,88,91 孙涛&#xff0c…

【赵渝强老师】PostgreSQL的数据库

PostgreSQL的逻辑存储结构主要是指数据库中的各种数据库对象,包括:数据库集群、数据库、表、索引、视图等等。所有数据库对象都有各自的对象标识符oid(object identifiers),它是一个无符号的四字节整数,相关对象的oid都…

(C语言) 8大翻译阶段

(C语言) 8大翻译阶段 文章目录 (C语言) 8大翻译阶段⭐前言🗃️8大阶段🗂️1. 字符映射🗂️2. 行分割🗂️3. 标记化🗂️4. 预处理🗂️5. 字符集映射🗂️6. 字符串拼接🗂️7. 翻译&…

安全基线检查

一、安全基线检测基础知识 安全基线的定义 安全基线检查的内容 安全基线检查的操作 二、MySQL的安全基线检查 版本加固 弱口令 不存在匿名账户 合理设置权限 合理设置文件权限 日志审核 运行账号 可信ip地址控制 连接数限制 更严格的基线要求 1、禁止远程连接数据库 2、修改…

玩转 uni-app 静态资源 static 目录的条件编译

一. 前言 老生常谈,了解 uni-app 的开发都知道,uni-app 可以同时支持编译到多个平台,如小程序、H5、移动端 App 等。它的多端编译能力是 uni-app 的一大特点,让开发者可以使用同一套代码基于 Vue.js 的语法编写程序,然…

[2024年3月10日]第15届蓝桥杯青少组stema选拔赛C++中高级(第二子卷、编程题(2))

方法一&#xff08;string&#xff09;&#xff1a; #include <iostream> #include <string> using namespace std;// 检查是否为回文数 bool isPalindrome(int n) {string str to_string(n);int left 0, right str.size() - 1;while (left < right) {if (s…

快速排序hoare版本和挖坑法(代码注释版)

hoare版本 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h>// 交换函数 void Swap(int* p1, int* p2) {int tmp *p1;*p1 *p2;*p2 tmp; }// 打印数组 void _printf(int* a, int n) {for (int i 0; i < n; i) {printf("%d ", a[i]);}printf("…

C5.【C++ Cont】getchar,putchar和scanf

目录 1.回顾C语言文章24.【C语言】getcha和putchar的使用 2.C中和C语言不同的地方 3.关键点 4.scanf 5.练习1 题目描述 输入描述: 输出描述: 输入 输出 6.练习2 题目描述 输入格式 输出格式 输入输出样例 说明/提示 1.回顾C语言文章24.【C语言】getcha和putchar…

深入理解 AI 产品的核心价值——《AI产品经理手册》

现在&#xff0c;人们对AI 充满了兴趣和看法。这些年&#xff0c;我亲身经历了对AI 的感受和认识的此起彼伏。我还是学生时&#xff0c;就对AI 以及伴随而来的第四次工业革命感到无比激动和期待。然而&#xff0c;当我开始组织读书会&#xff0c;每月阅读有关AI 的书籍&#xf…

Spring Boot拦截器(Interceptor)详解

拦截器Interceptor 拦截器我们主要分为三个方面进行讲解&#xff1a; 介绍下什么是拦截器&#xff0c;并通过快速入门程序上手拦截器拦截器的使用细节通过拦截器Interceptor完成登录校验功能 1. 快速入门 什么是拦截器&#xff1f; 是一种动态拦截方法调用的机制&#xff…

python代码示例(读取excel文件,自动播放音频)

目录 python 操作excel 表结构 安装第三方库 代码 自动播放音频 介绍 安装第三方库 代码 python 操作excel 表结构 求出100班同学的平均分 安装第三方库 因为这里的表结构是.xlsx文件,需要使用openpyxl库 如果是.xls格式文件,需要使用xlrd库 pip install openpyxl /…

构建 LLM (大型语言模型)应用程序——从入门到精通(第七部分:开源 RAG)

通过检索增强生成 (RAG) 应用程序的视角学习大型语言模型 (LLM)。 本系列博文 简介数据准备句子转换器矢量数据库搜索与检索大语言模型开源 RAG&#xff08;本帖&#xff09;评估服务LLM高级 RAG 1. 简介 我们之前的博客文章广泛探讨了大型语言模型 (LLM)&#xff0c;涵盖了其…

2024健康大数据与智能医疗(ICHIH 2024)

大会官网&#xff1a;www.ic-ichih.net 大会时间&#xff1a;2024年12月13-15日 大会地点&#xff1a;中国珠海 收录检索&#xff1a;IEEE Xplore&#xff0c;EI Compendex&#xff0c;Scopus

从0开始学PHP面向对象内容之常用设计模式(适配器,桥接,装饰器)

二&#xff0c;结构型设计模式 上两期咱们讲了创建型设计模式&#xff0c;都有 单例模式&#xff0c;工厂模式&#xff0c;抽象工厂模式&#xff0c;建造者模式&#xff0c;原型模式五个设计模式。 这期咱们讲结构型设计模式 1、适配器模式&#xff08;Adapter&#xff09; …

原生微信小程序画表格

wxml部分&#xff1a; <view class"table__scroll__view"><view class"table__header"><view class"table__header__item" wx:for"{{TableHeadtitle}}" wx:key"index">{{item.title}}</view></…

TDengine 签约深圳综合粒子,赋能粒子研究新突破

在高能物理和粒子研究领域&#xff0c;实验装置的不断升级伴随着海量数据的产生与处理。尤其是随着大湾区综合性国家科学中心的建设步伐加快&#xff0c;深圳综合粒子设施研究院&#xff08;以下简称“研究院”&#xff09;作为承载“双区驱动”战略的重要科研机构&#xff0c;…

SpringMVC——SSM整合

SSM整合 创建工程 在pom.xml中导入坐标 <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_…

jenkins 2.346.1最后一个支持java8的版本搭建

1.jenkins下载 下载地址&#xff1a;Index of /war-stable/2.346.1 2.部署 创建目标文件夹&#xff0c;移动到指定位置 创建一个启动脚本&#xff0c;deploy.sh #!/bin/bash set -eDATE$(date %Y%m%d%H%M) # 基础路径 BASE_PATH/opt/projects/jenkins # 服务名称。同时约定部…

Apache-maven在Windows中的安装配置及Eclipse中的使用

Apache Maven 是一个自动化项目管理工具&#xff0c;用于构建&#xff0c;报告和文档的项目管理工具。以下是在不同操作系统上安装和配置 Maven 的基本步骤&#xff1a; 安装 Maven 下载 Maven: apache-maven-3.9.9下载地址&#xff0c;也可访问 Apache Maven 官方网站 下载最…

【MySQL】MySQL从入门到放弃

文章目录 声明MYSQL一,架构1.1.网络连接层数据库连接池 1.2.系统服务层1.2.1.SQL接口1.2.2.存储过程1.2.3.触发器1.2.4.解析器1.2.5.优化器1.2.6.缓存,缓冲 1.3.存储引擎层1.4.文件系统层1.4.1.日志模块1.4.2.数据模块 二,SQL 执行2.1.执行流程2.2.刷盘2.3.返回 三.库表设计3.1…