人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解

大家好,我是微学AI,今天给大家介绍一下人工智能算法工程师(中级)课程14-神经网络的优化与设计之拟合问题及优化与代码详解。在机器学习和深度学习领域,模型的训练目标是找到一组参数,使得模型能够从训练数据中学习到有用的模式,并对未知数据做出准确预测。这一过程涉及到解决两种主要的拟合问题:欠拟合(Underfitting)和过拟合(Overfitting)。

文章目录

  • 一、拟合问题概述
    • 欠拟合现象
    • 过拟合现象
    • 解决策略
  • 二、正则化方法
    • 1. L1正则化
    • 2. L2正则化
  • 三、正则化参数的更新
  • 四、Dropout
  • 五、代码实现

一、拟合问题概述

在机器学习领域,拟合问题是指通过训练数据找到最佳模型参数,使得模型在未知数据上的表现尽可能好。拟合问题主要包括欠拟合和过拟合两种现象。

欠拟合现象

定义:欠拟合指的是机器学习模型在训练集上的表现不佳,无法充分学习到数据的内在规律,导致模型的预测能力低下。这就好比一个学生在考试中,由于知识掌握不牢固,对已知题目的解答都做不好,更不用说应对新题目了。
原因分析:
模型复杂度低:如果模型太简单,如用线性模型去拟合非线性的数据分布,那么模型就无法捕捉到数据中的复杂模式,就像用直尺去测量曲线长度一样,永远无法得到准确的结果。
训练数据不足:模型需要足够的数据来学习和概括数据的特性。如果数据量太少,模型可能没有机会接触到数据的全貌,就像从一本书中只读了几页就想理解整本书的内容一样困难。
特征选择不当:如果使用的特征与目标预测无关或相关性弱,模型就难以从中学习到有效的信息,相当于在解决问题时选择了错误的工具。

过拟合现象

定义:过拟合是指模型在训练数据上表现得过于出色,以至于对训练数据中的噪声或偶然性细节也进行了学习,这导致模型在面对未见过的数据时,泛化能力下降。这就像一个学生过分依赖于记忆特定的例题,而没有真正理解背后的原理,因此在遇到稍微变化的问题时就束手无策。
原因分析:
模型复杂度过高:如果模型过于复杂,如高阶多项式回归,它可能会过度适应训练数据中的每一个细节,包括噪声和异常值,而不是学习数据的普遍规律。
训练数据包含噪声:现实世界的数据往往带有噪声,如果模型试图学习这些噪声,就会导致过拟合。这类似于试图从嘈杂的环境中听清对话,噪声会干扰对真实信息的理解。
训练数据量不足:即使模型复杂度适中,但如果训练数据量不够,模型仍然可能过拟合。这是因为数据量不足时,模型可能会把偶然出现的模式误认为是普遍规律。

解决策略

增加模型复杂度:对于欠拟合,可以通过增加模型复杂度来提升模型的学习能力,如使用更高阶的多项式或更复杂的神经网络结构。
增加训练数据量:无论是欠拟合还是过拟合,增加训练数据量都能帮助模型更好地学习数据的分布,提高泛化能力。
特征工程:优化特征选择,确保模型能够基于有意义的特征进行学习。
正则化:使用L1或L2正则化等技术来限制模型复杂度,防止过拟合。
交叉验证:通过交叉验证来评估模型的泛化能力,确保模型不仅在训练数据上表现好,也能在未见数据上给出准确预测。
早停法:在训练过程中监控验证集的性能,一旦发现验证集上的性能不再提升,就停止训练,避免过拟合。
在这里插入图片描述

二、正则化方法

为了解决过拟合问题,通常采用正则化方法对模型进行约束。常见的正则化方法有L1正则化和L2正则化。

1. L1正则化

L1正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α ∑ j = 1 n ∣ θ j ∣ J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \alpha\sum_{j=1}^{n}|\theta_j| J(θ)=2m1i=1m(hθ(x(i))y(i))2+αj=1nθj
其中,第一项为损失函数,第二项为L1正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

2. L2正则化

L2正则化的目标函数为:
J ( θ ) = 1 2 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) 2 + α 2 ∑ j = 1 n θ j 2 J(\theta) = \frac{1}{2m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})^2 + \frac{\alpha}{2}\sum_{j=1}^{n}\theta_j^2 J(θ)=2m1i=1m(hθ(x(i))y(i))2+2αj=1nθj2
其中,第一项为损失函数,第二项为L2正则化项, α \alpha α为惩罚系数, θ j \theta_j θj为模型参数。

三、正则化参数的更新

在优化目标函数时,我们需要对正则化参数进行更新。以下为L2正则化的参数更新公式:
θ j : = θ j − α ( 1 m ∑ i = 1 m ( h θ ( x ( i ) ) − y ( i ) ) x j ( i ) + λ θ j ) \theta_j := \theta_j - \alpha\left(\frac{1}{m}\sum_{i=1}^{m}(h_{\theta}(x^{(i)}) - y^{(i)})x_j^{(i)} + \lambda\theta_j\right) θj:=θjα(m1i=1m(hθ(x(i))y(i))xj(i)+λθj)
其中, λ = α m \lambda = \frac{\alpha}{m} λ=mα为正则化参数。
在这里插入图片描述

四、Dropout

Dropout是一种有效的正则化方法,通过在训练过程中随机丢弃部分神经元,来减少模型对特定训练样本的依赖。以下是Dropout的实现步骤:
(1)在训练过程中,按照一定概率随机丢弃神经元;
(2)在测试过程中,将所有神经元的输出乘以概率因子。

五、代码实现

以下是基于PyTorch的拟合问题及优化代码实现:

import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class LinearRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
    def forward(self, x):
        return self.linear(x)
# 生成数据
x = torch.randn(100, 1)
y = 3 * x + 2 + torch.randn(100, 1)
# 实例化模型
model = LinearRegression(1, 1)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.01)  # L2正则化
# 训练模型
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()
    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 测试模型
model.eval()
with torch.no_grad():
    predicted = model(x).detach().numpy()
    print(f'预测值:{predicted}')

通过本文的介绍,相信大家对拟合问题及优化方法有了更深入的了解。在实际应用中,可根据数据特点选择合适的正则化方法,以提高模型的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/799918.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

mysql命令练习

创建数据表grade: CREATE TABLE grade( id INT NOT NULL, sex CHAR(1), firstname VARCHAR(20) NOT NULL, lastname VARCHAR(20) NOT NULL, english FLOAT, math FLOAT, chinese FLOAT ); 向数据表grade中插…

基于springboot与vue的旅游推荐系统与门票售卖

💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

图像边缘检测中Sobel算子的原理,并附OpenCV和Matlab的示例代码

Sobel算子是一种用于图像边缘检测的离散微分算子。它结合了图像的平滑处理和微分计算,旨在强调图像中强度变化显著的区域,即边缘。Sobel算子在图像处理中被广泛使用,特别是在计算机视觉和图像分析领域。 Sobel算子的原理 Sobel算子主要用于计…

zookeeper+kafka的消息队列

zookeeperKafka 两个都是消息队列的工具 消息队列 出现原因:生产者产生的消息与消费者处理消息的效率相差很大。为了避免出现数据丢失而设立的中间件。 在消息的生产者与消费之间设置一个系统,负责缓存生产者与消费者之间的消息的缓存。将消息排序。 优…

SpringMVC注解全解析:构建高效Web应用的终极指南 (上)

SpringMVC 是一个强大的 Web 框架,广泛应用于 Java Web 开发中。它通过注解简化了配置,增强了代码的可读性。本文将全面解析 SpringMVC 中常用的注解及其用法,帮助你构建高效的 Web 应用。 一. MVC介绍 MVC 是 Model View Controller 的缩写…

鸿蒙语言基础类库:【@system.bluetooth (蓝牙)】

蓝牙 说明: 开发前请熟悉鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 从API Version 7 开始,该接口不再维护,推荐使用新接口[ohos.bluetooth]。本模块首批接口从API version…

反序列化漏洞详细介绍

反序列化漏洞详细介绍: 反序列化漏洞是软件开发中一个严重的安全问题,尤其在使用网络通信和持久化数据的应用中更为常见。下面是对反序列化漏洞的详细介绍: 原理 序列化是将对象的状态信息转换为可以存储或传输的格式(如字节流&#xff09…

【TAROT】韦特体系塔罗牌学习(2)——魔术师 THE MAGICIAN I

韦特体系塔罗牌学习(2)——魔术师 THE MAGICIAN I 目录 韦特体系塔罗牌学习(2)——魔术师 THE MAGICIAN I牌面分析1. 基础信息2. 图片元素 正位牌意1. 关键词/句2.爱情婚姻3. 学业事业4. 人际关系5. 其他象征意 逆位牌意1. 关键词…

python数据可视化(5)——绘制饼图

课程学习来源:b站up:【蚂蚁学python】 【课程链接:【【数据可视化】Python数据图表可视化入门到实战】】 【课程资料链接:【链接】】 Python绘制饼图分析北京天气 饼图,是一个划分为几个扇形的圆形统计图表&#xff…

【网络世界】网络基础概念

目录 🌈 前言🌈 📁 什么是网络 📁 协议 📂 概念 📂 OSI参考模型 📂 TCP/IP模型 📂 TCP/IP 和 系统分层的关系 📁 网络传输的基本流程 📂 MAC地址 &#…

文件上传接口

文章目录 开发前端接口 开发前端接口 首先这个前端的文件上传组件使用了,前端组件 首先这个接口不是一般的接口,这个接口可以提取出来,之后那里使用了,就直接放到哪里 所以这是一个万能文件上传接口 写完之后选择 头像组件 在图库中添加组件 写前端组件之后,写了前端的组件…

WPF实现一个带旋转动画的菜单栏

WPF实现一个带旋转动画的菜单栏 一、创建WPF项目及文件1、创建项目2、创建文件夹及文件3、添加引用 二、代码实现2.ControlAttachProperty类 一、创建WPF项目及文件 1、创建项目 打开VS2022,创建一个WPF项目,如下所示 2、创建文件夹及文件 创建资源文件夹&…

Go: IM系统技术架构梳理

概述 整个IM系统的一般架构如下 我们这张图展示了整个IM系统的一般架构可见分为四层那最上面这一层是前端,包括哪些东西呢? 它包括两部分,第一部分是跟用户直接交互的比如说各种IOS APP, 各种安卓 APP还有各种 web APP 在浏览器里面打开的以…

区块链学习05-web3中solidity和move语言

Solidity 和 Move 语言的比较:Web3 开发中的两种选择 Solidity 和 Move 都是用于开发区块链平台智能合约的编程语言。它们具有一些相似之处,但也存在一些关键差异。 相似之处: Solidity 和 Move 都是图灵完备语言,这意味着它们可以表达计算…

Anything in Any Scene:无缝融入任何场景,实现逼真视频对象插入技术

现实世界的视频捕获虽然因其真实性而宝贵,但常常受限于长尾分布的问题,即常见场景过度呈现,而关键的罕见场景却鲜有记录。这导致了所谓的"分布外问题",在模拟复杂环境光线、几何形状或达到高度逼真效果方面存在局限。传…

CentOS配置时钟服务

一、ntp协议 1.1 基础 NTP(Network Time Protocol,网络时间协议)是用于同步计算机网络中各个设备时间的协议。 下面了解一下 ntp 的配置选项 1.) iburst 功能: 通过发送一组八个数据包来加速初始同步。 用法: server 0.pool.ntp.org i…

Python实现简单的ui界面设计(小白入门)

引言: 当我们书写一个python程序时,我们在控制台输入信息时,往往多有不便,并且为了更加美观且直观的方式输入控制命令,我们常常设计一个ui界面,这样就能方便执行相关功能。如计算器、日历等界面。 正文&a…

Docker安装RabbitMQ(带web管理端)

1.拉取带web管理的镜像 可以拉取rabbitmq对应版本的web管理端,比如:rabbitmq:3.9.11-management,也可以直接拉取带web管理端的最新版本 rabbitmq:management. docker pull rabbitmq:3.9.11-management 注意:如果docker pull ra…

Linux目录网络设置远程工具的使用

文章目录 Linux目录虚拟机⽹络配置查看⽹络信息修改⽹络配置信息 虚拟机管理操作远程⼯具的使⽤ Linux目录 Linux的⽬录结构 Linux中的常⻅⽬录 Linux常⻅的⽬录结构,不同版本的Linux⽬录结构可能略有不同 Centos7的⽂件⽬录结构 Linux根⽬录下的常⻅⽬录及作⽤ …

C语言之qsort函数

一、qsort 1.库函数qsort qsort是库函数&#xff0c;直接可以用来排序数据&#xff0c;底层使用的是快速排序。 qsort函数可以排序任意类型的数据。 2.头文件 #include<stdlib.h> 3.参数讲解 void*类型的指针是无具体类型的指针&#xff0c;这种类型的指针的不能直接解…