【机器学习】P8 过拟合与欠拟合、正则化与正则化后的损失函数和梯度下降

过拟合与欠拟合、正则化与正则化后的损失函数和梯度下降

  • 过拟合与欠拟合
    • 过拟合与欠拟合直观理解
    • 线性回归中 过拟合与欠拟合
    • 逻辑回归中 过拟合与欠拟合
  • 过拟合与欠拟合的解决办法
    • 过拟合解决方案
    • 欠拟合解决方案
  • 包含正则化的损失函数
    • 正则化线性回归损失函数
    • 正则化逻辑回归损失函数
  • 包含正则化的梯度下降
    • 正则化线性回归梯度下降
    • 正则化逻辑回归梯度下降
  • Reference

过拟合与欠拟合

过拟合与欠拟合直观理解

过拟合与欠拟合,就像你早餐三明治一样,太热了下不了嘴,太凉了吃了肚子疼,只有正正好好才刚刚好;下图 Fig.1 为周志华老师《机器学习》一书中我认为是对过拟合和欠拟合最简单的理解方式,请查阅:

图片来源于《西瓜书》

Fig. 1. 过拟合与欠拟合
  • 过拟合指的是模型在训练数据上表现良好,但在新的数据上表现较差。过拟合通常是因为模型太复杂,导致它过度适应了训练数据的噪声和细节,而忽略了数据中的一般趋势和规律。

  • 欠拟合指的是模型无法在训练数据上得到很好的拟合,更不用说在新的数据上进行预测。欠拟合通常是因为模型过于简单,无法捕捉到数据的复杂性和变化。


线性回归中 过拟合与欠拟合

在这里插入图片描述

Fig. 2. 线性回归 欠拟合-正好-过拟合

欠拟合: 在线性回归中,欠拟合通常指模型无法很好地拟合训练数据的情况,通常是由于模型过于简单( w 1 x + b w_1x+b w1x+b),无法捕捉数据的本质规律所导致。

良好的泛化能力: 我们创建模型的目标是能够使用该模型对新的例子正确预测结果。一个能做到这一点的模型被称为具有良好的泛化能力。

过拟合: 在线性回归中,过拟合通常指模型在训练集上表现良好,但是在测试集或新数据上表现较差的情况。这通常是由于模型的复杂度过高( w 1 x + w 2 x 2 + w 3 x 3 + w 4 x 4 + b w_1x+w_2x^2+w_3x^3+w_4x^4+b w1x+w2x2+w3x3+w4x4+b),或者训练数据量不足导致的。


逻辑回归中 过拟合与欠拟合

在这里插入图片描述

Fig. 3. 逻辑回归 欠拟合-正好-过拟合

欠拟合: 逻辑回归模型的欠拟合问题通常是由于模型太简单,无法捕捉数据中的非线性关系和复杂模式,从而导致模型在训练集和测试集上都表现不佳

过拟合: 逻辑回归模型的过拟合问题和线性回归模型类似,通常是由于模型过于复杂,拟合了过多的训练数据的噪声和异常值,而导致模型在测试集上表现不佳


过拟合与欠拟合的解决办法

过拟合解决方案

  1. 增加训练数据量: 通过增加训练数据量,可以使模型更好地学习数据的本质规律,减少对噪声和随机变化的过拟合。
  2. 减少模型的复杂度: 通过减少模型的复杂度,例如减少特征数量或者降低多项式阶数,可以使模型更加简单,从而减少过拟合的可能性。
  3. 正则化: 通过在损失函数中添加正则项,可以使模型更加平滑,从而减少过拟合的可能性。常见的正则化方法包括L1正则化和L2正则化。正则化内容将在本博文下述部分单开内容。
  4. 交叉验证: 通过将数据集划分为训练集和验证集(Validation Set),在验证集上评估模型的泛化性能,可以帮助选择最优模型和避免过拟合,关于验证集的部分内容会在后期阐述。验证集的主要内容就是通过模拟训练集一样提前验证,从而得知模型是不是过拟合了。

欠拟合解决方案

  1. 增加模型的复杂度: 与过拟合相反,欠拟合需要通过增加模型的复杂度,例如增加多项式阶数或者添加更多的特征,从而使模型更加灵活,更好地拟合数据中的规律。
  2. 增加训练数据量: 数据量过少,极端情况比如就一个,模型必然欠拟合。通过增加训练数据量,可以使模型更好地学习数据的本质规律,从而减少欠拟合的可能性。
  3. 减少正则化强度: 如果采用了正则化方法,可以尝试减少正则化强度,使模型更加灵活。正则化内容见下述。
  4. 换一个模型试试: 模型的选择在机器学习中是非常重要的,选择一个错误的模型,再调试也是有上限的。所以如果上述的办法都没解决欠拟合的问题,不如回头看看是不是一开始对于模型的选择上就是欠考虑的。

包含正则化的损失函数

正则化线性回归损失函数

正则化线性回归是一种在线性回归模型中添加正则化项来控制模型复杂度和避免过拟合的方法。常见的正则化方法包括L1正则化(Lasso)和L2正则化(Ridge),L1正则化和L2正则化的主要区别在于它们所惩罚的模型参数不同:

  • L1正则化 会倾向于使某些参数为0,从而实现特征选择的效果;
  • L2正则化 则会让所有参数都尽量小,但不会使它们变成0。

而因为我们在线性回归中,使用正则化时经常不会得知具体哪些参数需要调整,所以更多用的是L2正则化,即将所有的参数都调整,从而保留了所有特征的信息。

下面内容将具体阐述如何正则化:

未使用正则化的损失函数:
J ( w ⃗ , b ) = 1 2 m ∑ i = 0 m − 1 ( f w ⃗ , b ( x ⃗ ( i ) ) − y ( i ) ) 2 J(\vec{w},b)=\frac 1 {2m} \sum ^{m-1} _{i=0}(f_{\vec{w},b}(\vec{x}^{(i)})-y^{(i)})^2 J(w ,b)=2m1i=0m1(fw ,b(x (i))y(i))2

使用正则化后的损失函数:
J ( w ⃗ , b ) = 1 2 m ∑ i = 0 m − 1 ( f w ⃗ , b ( x ⃗ ( i ) ) − y ( i ) ) 2 + λ 2 m ∑ j = 0 n − 1 w j 2 J(\vec{w},b)=\frac 1 {2m} \sum ^{m-1} _{i=0}(f_{\vec{w},b}(\vec{x}^{(i)})-y^{(i)})^2+\frac {\lambda} {2m} \sum ^{n-1} _{j=0} w_j^2 J(w ,b)=2m1i=0m1(fw ,b(x (i))y(i))2+2mλj=0n1wj2

对比正则化前与正则化后的损失函数,发现正则化后损失函数多出一个部分: λ 2 m ∑ j = 0 n − 1 w j 2 \frac {\lambda} {2m} \sum ^{n-1} _{j=0} w_j^2 2mλj=0n1wj2,该部分称为正则化项。通过该项,增大损失函数值,从而在梯度下降(最小化损失函数)时,对模型参数的取值施加限制,模型的参数会更倾向于取值较小的范围,从而降低过拟合风险。

代码实现:

def compute_cost_linear_reg(X, y, w, b, lambda_ = 1):

    m  = X.shape[0]
    n  = len(w)
    cost = 0.
    for i in range(m):
        f_wb_i = np.dot(X[i], w) + b
        cost = cost + (f_wb_i - y[i])**2
    cost = cost / (2 * m)						# 不包含正则化项的损失函数
 
    reg_cost = 0
    for j in range(n):
        reg_cost += (w[j]**2)					# 正则化项
    reg_cost = (lambda_/(2*m)) * reg_cost
    
    total_cost = cost + reg_cost
    return total_cost							# 返回的改损失函数包含正则化项

正则化逻辑回归损失函数

正则化逻辑回归同样是通过添加正则化项来控制模型复杂度和避免过拟合的方法。同样常见的也是L1和L2,甚至对于损失函数的更改都是一样的:

未使用正则化的损失函数:
J ( w ⃗ , b ) = 1 m ∑ i = 0 m − 1 [ ( − y ( i ) l o g ( f w ⃗ , b ( x ⃗ ( i ) ) ) ) − ( 1 − y ( i ) ) l o g ( 1 − f w ⃗ , b ( x ⃗ ( i ) ) ) ] J(\vec{w},b)=\frac 1 m \sum ^{m-1} _{i=0} [(-y^{(i)}log(f_{\vec{w},b}(\vec{x}^{(i)})))-(1-y^{(i)})log(1-f_{\vec{w},b}(\vec{x}^{(i)}))] J(w ,b)=m1i=0m1[(y(i)log(fw ,b(x (i))))(1y(i))log(1fw ,b(x (i)))]

使用正则化后的损失函数:
J ( w ⃗ , b ) = 1 m ∑ i = 0 m − 1 [ ( − y ( i ) l o g ( f w ⃗ , b ( x ⃗ ( i ) ) ) ) − ( 1 − y ( i ) ) l o g ( 1 − f w ⃗ , b ( x ⃗ ( i ) ) ) ] + λ 2 m ∑ j = 0 n − 1 w j 2 J(\vec{w},b)=\frac 1 m \sum ^{m-1} _{i=0} [(-y^{(i)}log(f_{\vec{w},b}(\vec{x}^{(i)})))-(1-y^{(i)})log(1-f_{\vec{w},b}(\vec{x}^{(i)}))]+\frac {\lambda} {2m}\sum ^{n-1} _{j=0} w_j^2 J(w ,b)=m1i=0m1[(y(i)log(fw ,b(x (i))))(1y(i))log(1fw ,b(x (i)))]+2mλj=0n1wj2

同样,增加了正则化项: λ 2 m ∑ j = 0 n − 1 w j 2 \frac {\lambda} {2m} \sum ^{n-1} _{j=0} w_j^2 2mλj=0n1wj2

代码实现:

def compute_cost_logistic_reg(X, y, w, b, lambda_ = 1):

    m,n  = X.shape
    cost = 0.
    for i in range(m):
        z_i = np.dot(X[i], w) + b
        f_wb_i = sigmoid(z_i)
        cost +=  -y[i]*np.log(f_wb_i) - (1-y[i])*np.log(1-f_wb_i)
    cost = cost/m								# 标准逻辑回归损失函数部分

    reg_cost = 0								# 添加正则化损失函数部分
    for j in range(n):
        reg_cost += (w[j]**2)
    reg_cost = (lambda_/(2*m)) * reg_cost
    
    total_cost = cost + reg_cost
    return total_cost							# 返回损失函数包含常规逻辑回归损失函数部分以及正则化损失函数部分

包含正则化的梯度下降

由于正则化对于线性回归与逻辑回归的损失函数操作相同,所以我们可以将其一起要讨论:
repeat until convergence:    {        w j = w j − α ∂ J ( w , b ) ∂ w j    for j := 0..n-1            b = b − α ∂ J ( w , b ) ∂ b } \begin{align*} &\text{repeat until convergence:} \; \lbrace \\ & \; \; \;w_j = w_j - \alpha \frac{\partial J(\mathbf{w},b)}{\partial w_j} \; & \text{for j := 0..n-1} \\ & \; \; \; \; \;b = b - \alpha \frac{\partial J(\mathbf{w},b)}{\partial b} \\ &\rbrace \end{align*} repeat until convergence:{wj=wjαwjJ(w,b)b=bαbJ(w,b)}for j := 0..n-1
有:
∂ J ( w , b ) ∂ w j = 1 m ∑ i = 0 m − 1 ( f w , b ( x ( i ) ) − y ( i ) ) x j ( i ) + λ m w j ∂ J ( w , b ) ∂ b = 1 m ∑ i = 0 m − 1 ( f w , b ( x ( i ) ) − y ( i ) ) \begin{align*} \frac{\partial J(\mathbf{w},b)}{\partial w_j} &= \frac{1}{m} \sum\limits_{i = 0}^{m-1} (f_{\mathbf{w},b}(\mathbf{x}^{(i)}) - y^{(i)})x_{j}^{(i)} + \frac{\lambda}{m} w_j\\ \frac{\partial J(\mathbf{w},b)}{\partial b} &= \frac{1}{m} \sum\limits_{i = 0}^{m-1} (f_{\mathbf{w},b}(\mathbf{x}^{(i)}) - y^{(i)}) \end{align*} wjJ(w,b)bJ(w,b)=m1i=0m1(fw,b(x(i))y(i))xj(i)+mλwj=m1i=0m1(fw,b(x(i))y(i))

说明:为什么只对 w 而不对 b 做正则化操作:
在线性回归正则化中,通常只对权重 w 进行正则化而不对偏置 b 进行正则化,这是因为偏置 b 的影响通常不如权重 w 显著。偏置b是一个常数,它可以看作是对预测值的一个基础偏移。因为它对预测值的影响相对较小,所以即使它的值很大,对模型的影响也会相对较小。

正则化线性回归梯度下降

def compute_gradient_linear_reg(X, y, w, b, lambda_): 

    m,n = X.shape
    dj_dw = np.zeros((n,))
    dj_db = 0.

    for i in range(m):                             
        err = (np.dot(X[i], w) + b) - y[i]
        dj_dw = dj_dw + err * X[i]
        # for j in range(n):
            # dj_dw[j] = dj_dw[j] + err * X[i, j]
        dj_db = dj_db + err
    dj_dw = dj_dw / m
    dj_db = dj_db / m
    
    for j in range(n):
        dj_dw[j] = dj_dw[j] + (lambda_/m) * w[j]		# 加入了正则化损失函数带来的影响

    return dj_db, dj_dw

正则化逻辑回归梯度下降

def compute_gradient_logistic_reg(X, y, w, b, lambda_): 

    m,n = X.shape
    dj_dw = np.zeros((n,))
    dj_db = 0.0
    
    for i in range(m):
        f_wb_i = sigmoid(np.dot(X[i],w) + b)
        err_i  = f_wb_i  - y[i]
        dj_dw = dj_dw + err_i * X[i]
        # for j in range(n):
            # dj_dw[j] = dj_dw[j] + err_i * X[i,j]
        dj_db = dj_db + err_i
    dj_dw = dj_dw/m
    dj_db = dj_db/m

    for j in range(n):
        dj_dw[j] = dj_dw[j] + (lambda_/m) * w[j]		# 加入正则化损失函数带来影响

    return dj_db, dj_dw  

Reference

[1]. 周志华. (2016). 《机器学习》(第1版). 清华大学出版社. https://book.douban.com/subject/26708119/

[2]. Ng, A. (n.d.). The problem of overfitting. Coursera. Retrieved from https://www.coursera.org/learn/machine-learning/lecture/erGPe/the-problem-of-overfitting

[3]. Ng, A. (n.d.). Optional Lab: Regularization. Coursera. Retrieved from https://www.coursera.org/learn/machine-learning/ungradedLab/36A9A/optional-lab-regularization

  • This is a lab exercise on regularization from the “Machine Learning” course by Andrew Ng on Coursera. It covers both L1 and L2 regularization for linear regression and logistic regression, and provides hands-on experience with implementing regularization in Python using NumPy.

[4]. Ng, A. (n.d.). Optional Lab: Overfitting. Coursera. Retrieved from https://www.coursera.org/learn/machine-learning/ungradedLab/3nraU/optional-lab-overfitting

  • This is a lab exercise on overfitting from the “Machine Learning” course by Andrew Ng on Coursera. It covers the concepts of overfitting, underfitting, and bias-variance tradeoff, and provides hands-on experience with identifying and addressing overfitting in linear regression and neural networks using regularization and early stopping techniques.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/5194.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java爬虫利器Jsoup的使用

对于长期使用java做编程的程序猿应该知道,java支持的爬虫框架还是有很多的,如:ebMagic、Spider、Jsoup等。今天我们就用Jsoup来实现一个小小的爬虫程序,Jsoup作为kava的HTML解析器,可以直接对某个URL地址、HTML文本内容…

焦虑真的好吗 过度的焦虑存在哪些影响

日常常见的焦虑情绪真的好吗?焦虑是我们七情中的一种正常情绪表现,我们生活当中很多因素都可能会导致我们产生焦虑的情绪表现,如一场考试、一次挑战、一个活动等等。这种焦虑情绪的产生并不是一件坏事,相反,焦虑情绪的…

ROS学习笔记(零):ROS与机器人概述

ROS学习笔记(零):ROS与机器人概述ROSROS的起源ROS的特点ROS架构设计机器人机器人的定义机器人的组成执行机构驱动系统传感系统控制系统ROS ROS的起源 ROS(Robot Operating System)是一个广泛使用的机器人操作系统&…

Python图片相册批处理器的设计与实现批量添加图片水印、批量命名等功能

课题研究使用Python语言开发一个包含批量添加图片水印、批量命名等功能的图片批处理程序,功能模块大概包含以下模块: (1)首页模块:首页是整个软件的初始页面,包含用户登录、注册、关于本软件等功能&#xf…

红日(vulnstack)5 内网渗透ATTCK实战

环境配置 链接:百度网盘 请输入提取码 提取码:l8r7 攻击机:kali2022.03 192.168.135.128(NET模式) win7 192.168.138.136 (仅主机模式) 192.168.135.150 (NET模式) win2008 192.168.138.138 (仅主机模式) web渗透 1.nmap探测目标靶机开…

Qt学习笔记之SQLITE数据库

1. SQLite数据库介绍 SQLite,是一款轻型的数据库,是遵守ACID的关系型数据库管理系统,它包含在一个相对小的C库中。它是D.RichardHipp建立的公有领域项目。它的设计目标是嵌入式的,而且已经在很多嵌入式产品中使用了它,…

SpringBoot(1)基础入门

SpringBoot基础入门SpringBoot项目创建方式Idea创建SpringBoot官网创建基于阿里云创建项目手工搭建SpringBoot启动parentstarter引导类内嵌tomcat基础配置属性配置配置文件分类yaml文件yaml数据读取整合第三方技术整合JUnit整合MyBatis整合Mybatis-Plus整合DruidSpringBoot是由…

运动健康路线导入,助力用户轻松导航

华为HMS Core运动健康服务支持通过REST API,以GPX文件格式写入用户路线数据,支持导入轨迹(Track)或路程(Route)类型的数据,实现用户路线数据在华为运动健康App中的展示效果。 假若与华为运动健…

​selenium+python做web端自动化测试框架与实例详解教程​

下面有详细的代码介绍,如果不是很明白的话,可以看看这套视频,在哔站学习人数超过数万人! 在华为工作了10年的大佬出的Web自动化测试教程,华为现用技术教程!_哔哩哔哩_bilibili在华为工作了10年的大佬出的W…

分享NVIDIA GTC干货_用软件引领车辆电子架构

随着软件定义功能变得更多,车辆电气/电子架构正在从分布式计算演变为集中式计算。通过将这台集中式超级计算机与人工智能融合在一起,开发模块化软件并创建数据中心基础设施。 电子架构 EEA(Electrical and Electronic Architecture) 首先介绍下EEA&am…

Ansys Zemax | 如何建模离轴抛物面镜

离轴抛物面反射镜是光学工业中一种重要的设计类型。本文演示了如何根据制造商给出的规格设计一个离轴抛物面反射镜,并演示如何使用主光线求解将像面中心与主光线路径对齐。(联系我们获取文章附件) 简介 离轴抛物面反射镜的优点是光束通过反射到达像面途中将不会受…

Winform控件开发(25)——TabControl(史上最全)

一、属性 1、Name 用于获取控件对象 2、AllowDrop 指示用户是否可以拖动数据到TabCotrol上 3、TabCotrol 3.1 Top 沿控件的底部放置选项卡 3.2 Left 沿控件的左边缘放置选项卡 3.3 Right 沿控件的右边缘放置选项卡 3.4 Bottom 沿控件的顶部放置选项卡 4、Anchor 锚定控件…

第18章_MySQL8其它新特性

第18章_MySQL8其它新特性 🏠个人主页:shark-Gao 🧑个人简介:大家好,我是shark-Gao,一个想要与大家共同进步的男人😉😉 🎉目前状况:23届毕业生,…

新一轮商业革命将至,张勇用“敏捷组织”率先交出答卷

一向拥抱变化的阿里再一次拥抱变化。2023年3月28日,阿里宣布了新的组织变革,这应该是迄今为止,阿里最重要的组织变革,其变革力度之大堪称前所未有。具体而言,阿里集团将设立云智能、淘宝天猫商业、本地生活、国际数字商…

口罩检测——环境准备(1)

文章目录前言一、工具及环境要求工具本地环境要求二、工具介绍1.labelimg2.AI Studio3.YOLO2COCO4.PaddleUtils5.paddleyolo三、库的安装总结前言 小编之前做过一期《OpenVINO-yolov5推理》,点开博客自动播放视频甚至有点吵,想过删掉,但是想到…

Day924.自动化测试 -系统重构实战

自动化测试 Hi,我是阿昌,今天学习记录的是关于自动化测试的内容。 自动化测试是一个很容易产生“争议”的话题,也经常会有一些很有意思的问题。 自动化测试不是应该由测试同学来编写吗,开发是不是没有必要学吧?之前…

Lesson 9.1 集成学习的三大关键领域、Bagging 方法的基本思想和 RandomForestRegressor 的实现

文章目录一、 集成学习的三大关键领域二、Bagging 方法的基本思想三、RandomForestRegressor 的实现在开始学习之前,先导入我们需要的库,并查看库的版本。 import numpy as np import pandas as pd import sklearn import matplotlib as mlp import sea…

【MySQL速通篇001】5000字超详细介绍MySQL部分重要知识点

🍀 写在前面 这篇5000多字博客也花了我几天的时间😂,主要是我对MySQL一部分重要知识点的理解【后面当然还会写博客补充噻,欢迎关注我哟】,当然这篇文章可能也会有不恰当的地方【毕竟也写了这么多字,错别字可…

Linux常用命令——ldconfig命令

在线Linux命令查询工具 ldconfig 动态链接库管理命令 补充说明 ldconfig命令的用途主要是在默认搜寻目录/lib和/usr/lib以及动态库配置文件/etc/ld.so.conf内所列的目录下,搜索出可共享的动态链接库(格式如lib*.so*),进而创建出动态装入程…

python框架有哪些,常用的python框架代码

Python的应用已经相当广泛了,可以做很多事情,而 Python本身就是一个应用程序,我们也可以说 Python是一个高级语言。由于 Python有很多包,所以我们不能把所有的 Python包都了解一下,也不能把所有的包都读一遍&#xff0…