【机器学习】线性回归算法:原理、公式推导、损失函数、似然函数、梯度下降

1. 概念简述

        线性回归是通过一个或多个自变量与因变量之间进行建模的回归分析,其特点为一个或多个称为回归系数的模型参数的线性组合。如下图所示,样本点为历史数据,回归曲线要能最贴切的模拟样本点的趋势,将误差降到最小


2. 线性回归方程

        线形回归方程,就是有 n 个特征,然后每个特征 Xi 都有相应的系数 Wi ,并且在所有特征值为0的情况下,目标值有一个默认值 W0 ,因此:

线性回归方程为

h(w)=w_{0} + w_{1}*x_{1}+w_{2}*x_{2}+...+w_{n}*x_{n}

整合后的公式为:

h(w)=\sum_{i}^{n}w_{i}*x_{i} = \theta ^{T}*x


3. 损失函数

        损失函数是一个贯穿整个机器学习的一个重要概念,大部分机器学习算法都有误差,我们需要通过显性的公式来描述这个误差,并将这个误差优化到最小值。假设现在真实的值y预测的值h 。

损失函数公式为:

J(\theta )=\frac{1}{2}*\sum_{i}^{n}( y^{(i)} - \theta ^{T}*x^{(i)} )^{2}

也就是所有误差和的平方。损失函数值越小,说明误差越小,这个损失函数也称最小二乘法


4. 损失函数推导过程

4.1 公式转换

首先我们有一个线性回归方程h(\theta)=\theta_{0} + \theta_{1}*x_{1}+\theta_{2}*x_{2}+...+\theta_{n}*x_{n} 

为了方便计算计算,我们将线性回归方程转换成两个矩阵相乘的形式,将原式的 \theta _{0} 后面乘一个 x_{0}

此时的 x0=1,因此将线性回归方程转变成 h(\theta)=\sum_{i}^{n}\theta_{i}*x_{i},其中 \theta _{i} 和 x_{i} 可以写成矩阵

h(\theta)=\theta_{0} + \theta_{1}*x_{1}+...+\theta_{n}*x_{n} = \left [ \theta _{0} \; \theta _{1}\; \theta _{2}\; ... \right ]*\begin{bmatrix} x _{0}\\ x _{1}\\ x _{2}\\ ...\\ \end{bmatrix}=\sum_{i}^{n}\theta_{i}*x_{i} = \theta ^{T}*x

4.2 误差公式

以上求得的只是一个预测的值,而不是真实的值,他们之间肯定会存在误差,因此会有以下公式:

y_{i} = \theta _{i}*x_{i}+\epsilon_{i}

我们需要找出真实值 y_{i} 与预测值 \theta _{i}*x_{i} 之间的最小误差 \epsilon_{i} ,使预测值和真实值的差距最小。将这个公式转换成寻找不同的 \theta _{i} 使误差达到最小。

4.3 转化为 \theta 求解

由于 \epsilon_{i} 既存在正数也存在负数,所以可以简单的把这个数据集,看作是一个服从均值 \theta ,方差\sigma ^{2} 的正态分布。

所以 \epsilon_{i} 出现的概率满足概率密度函数

p(\epsilon _{i} ) = \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(\epsilon _{i})^{2}}{2\sigma ^{2}}

把 \epsilon_{i} =y_{i}- \theta _{i}*x_{i} 代入到以上的高斯分布函数(即正态分布)中,变成以下式子: 

p(\epsilon _{i} ) = \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

到此,我们将对误差 \epsilon _{i} 的求解转换成对 \theta_{i} 的求解了。

在求解这个公式时,我们要得到的是误差 \epsilon _{i} 最小,也就是求概率 p(\epsilon _{i}) 最大的。因为误差 \epsilon _{i} 满足正态分布,因此在正太曲线中央高峰部的概率 p(\epsilon _{i}) 是最大的,此时标准差\sigma为0误差是最小的。

尽管在生活中标准差肯定是不为0的,没关系,我们只需要去找到误差值出现的概率最大的点。现在,问题就变成了怎么去找误差出现概率最大的点,只要找到,那我们就能求出\theta _{i}

4.4 似然函数求 \theta

似然函数的主要作用是,在已经知道变量 x 的情况下,调整 \theta,使概率 y 的值最大。

似然函数理解:

以抛硬币为例,正常情况硬币出现正反面的概率都是0.5,假设你在不确定这枚硬币的材质、重量分布的情况下,需要判断其是否真的是均匀分布。在这里我们假设这枚硬币有 \theta 的概率会正面朝上,有 1-\theta 的概率会反面朝上

为了获得 \theta 的值,将硬币抛10次,H为正面,T为反面,得到一个正反序列 x = HHTTHTHHHH,此次实验满足二项分布,这个序列出现的概率\theta \theta (1-\theta )(1-\theta ) \theta(1-\theta ) \theta \theta \theta \theta= \theta^{7}(1-\theta )^{3},我们根据一次简单的二项分布实验,得到了一个关于 \theta 的函数,这实际上是一个似然函数,根据不同的 \theta 值绘制一条曲线,曲线就是\theta的似然函数,y轴是这一现象出现的概率。

从图中可见,当 \theta 等于 0.7 时,该序列出现的概率是最大的,因此我们确定该硬币正面朝上的概率是0.7。

因此,回到正题,我们要求的是误差出现概率 p(\epsilon _{i}) 的最大值,那就做很多次实验,对误差出现概率累乘,得出似然函数,带入不同的 \theta ,\theta是多少时,出现的概率是最大的,即可确定\theta的值。

综上,我们得出求 \theta 的似然函数为:

L( \theta ) = \prod_{i}^{m} \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

4.5 对数似然

由于上述的累乘的方法不太方便我们去求解 \theta,我们可以转换成对数似然,将以上公式放到对数中,然后就可以转换成一个加法运算。取对数以后会改变结果值,但不会改变结果的大小顺序。我们只关心\theta等于什么的时候,似然函数有最大值,不用管最大值是多少,即,不是求极值而是求极值点。注:此处log的底数为e。

对数似然公式如下:

\log (L( \theta )) =\log \prod_{i}^{m} \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}} = \sum_{i}^{n}\log \frac{1}{\sigma\sqrt{2\pi }} exp\tfrac{-(y_{i}- \theta _{i}*x_{i})^{2}}{2\sigma ^{2}}

对以上公式化简得:

\log (L( \theta )) =n*\log \frac{1}{\sigma\sqrt{2\pi }} - \frac{1}{2\sigma ^{2}}\sum_{i}^{n} (y_{i}- \theta _{i}*x_{i})^{2}

4.6 损失函数

我们需要把上面那个式子求得最大值,然后再获取最大值时的 \theta 值。 而上式中 n*\log \frac{1}{\sigma\sqrt{2\pi }} 是一个常数项,所以我们只需要把减号后面那个式子变得最小就可以了,而减号后面那个部分,可以把常数项 \frac{1}{\sigma ^{2}} 去掉,因此我们得到最终的损失函数如下,现在只需要求损失函数的最小值。

J (\theta ) = \frac{1}{2}\sum_{i}^{n} (y_{i}- \theta _{i}*x_{i})^{2}

注:保留 \frac{1}{2} 是为了后期求偏导数。

损失函数越小,说明预测值越接近真实值,这个损失函数也叫最小二乘法。


5. 梯度下降

损失函数中 xiyi 都是给定的值,能调整的只有 \theta,如果随机的调整,数据量很大,会花费很长时间,每次调整都不清楚我调整的是高了还是低了。我们需要根据指定的路径去调节,每次调节一个,范围就减少一点,有目标有计划去调节。梯度下降相当于是去找到一条路径,让我们去调整\theta

梯度下降的通俗理解就是,把对以上损失函数最小值的求解,比喻成梯子,然后不断地下降,直到找到最低的值。

5.1 批量梯度下降(BGD)

批量梯度下降,是在每次求解过程中,把所有数据都进行考察,因此损失函数因该要在原来的损失函数的基础之上加上一个m:数据量,来求平均值

J (\theta ) = \frac{1}{2m}\sum_{i}^{m} (y_{i}- \theta _{i}*x_{i})^{2}

因为现在针对所有的数据做了一次损失函数的求解,比如我现在对100万条数据都做了损失函数的求解,数据量结果太大,除以数据量100万,求损失函数的平均值。

然后,我们需要去求一个点的方向,也就是去求它的斜率。对这个点求导数,就是它的斜率,因此我们只需要求出 J(\theta ) 的导数,就知道它要往哪个方向下降了。它的方向先对所有分支方向求导再找出它们的合方向。

J(\theta ) 的导数为:

\frac{\partial J (\theta)}{\partial \theta _{j}} = -\frac{1}{m}\sum_{i}^{m} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

由于导数的方向是上升的,现在我们需要梯度下降,因此在上式前面加一个负号,就得到了下降方向,而下降是在当前点的基础上下降的。

批量梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha \frac{1}{m}\sum_{i}^{m} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

新点是在原点的基础上往下走一点点,斜率表示梯度下降的方向,\alpha 表示要下降多少。由于不同点的斜率是不一样的,以此循环,找到最低点。

批量梯度下降的特点:每次向下走一点点都需要将所有的点拿来运算,如果数据量大非常耗时间。


5.2 随机梯度下降(SGD)

随机梯度下降是通过每个样本来迭代更新一次。对比批量梯度下降,迭代一次需要用到所有的样本,一次迭代不可能最优,如果迭代10次就需要遍历整个样本10次。SGD每次取一个点来计算下降方向。但是,随机梯度下降的噪音比批量梯度下降要多,使得随机梯度下降并不是每次迭代都向着整体最优化方向

随机梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

每次随机一个点计算,不需要把所有点拿来求平均值,梯度下降路径弯弯曲曲趋势不太好。


5.3 mini-batch 小批量梯度下降(MBGO)

我们从上面两个梯度下降方法中可以看出,他们各自有优缺点。小批量梯度下降法在这两种方法中取得了一个折衷,算法的训练过程比较快,而且也要保证最终参数训练的准确率。

假设现在有10万条数据,MBGO一次性拿几百几千条数据来计算,能保证大体方向上还是下降的。

小批量梯度下降法下降后的点为:

\theta_{j}{'} = \theta_{j}+\alpha \frac{1}{n}\sum_{i}^{n} (y^{j}- h_{\theta} (x^{i}))x_{j}^{i}

\alpha 用来表示学习速率,即每次下降多少。已经求出斜率了,但是往下走多少合适呢,\alpha值需要去调节,太大的话下降方向会偏离整体方向,太小会导致学习效率很慢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/152145.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于旗鱼算法优化概率神经网络PNN的分类预测 - 附代码

基于旗鱼算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于旗鱼算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于旗鱼优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神经网络的光滑…

云表|低代码软件开发“外挂”,新时代的黑科技

随着技术的日新月异,现代企业对于软件开发的需求愈加迫切,传统的软件开发方式已然无法满足快速迭代和创新的需求。在这种背景下,低代码开发平台如破茧而出,应运而生。这种平台通过提供可视化的开发工具和预构建的组件,…

NVIDIA安装

电脑显卡类型 两种方法: 选择对应的版本 产品系列下载Notebooks,这样产品才会出现Laptop的GPU(Laptop是代表笔记本)。 下载完之后双击安装,更改下载路径后,选择默认的下载即可。 卸载 如果之后要卸载…

Spring6(三):面向切面AOP

文章目录 4. 面向切面:AOP4.1 场景模拟4.1.1 声明接口4.1.2 创建实现类4.1.3 创建带日志功能的实现类4.1.4 提出问题 4.2 代理模式4.2.1 概念4.2.2 静态代理4.2.3 动态代理4.2.4 测试 4.3 AOP概念4.3.1 相关术语①横切关注点②通知(增强)③切…

chrome 浏览器个别字体模糊不清

特别是在虚拟机里,有些字体看不清,但是有些就可以,设置办法: chrome://settings/fonts 这里明显可以看到有些字体就是模糊的状态: 把这种模糊的字体换掉即可解决一部分问题。 另外,经过观察,…

Neuro-Oncology | IF:15.9 CUTTag和RNA-seq联合解析胶质母细胞瘤的耐药性

发表单位:德克萨斯大学圣安东尼奥分校 发表日期:2023年1月18日 期 刊:Neuro-Oncology(IF: 15.9) 研究技术:CUT&Tag-seq、RNA-seq、RT-qPCR(爱基百客均可提供) 2023年1月1…

如何在10亿级别用户中检查用户名是否存在?

题目 不知道大家有没有留意过,在使用一些app注册的时候,提示你用户名已经被占用了,需要更换一个,这是如何实现的呢?你可能想这不是很简单吗,去数据库里查一下有没有不就行了吗,那么假如用户数量…

【人工智能实验】A*算法求解8数码问题 golang

人工智能经典问题八数码求解 实际上是将求解转为寻找最优节点的问题,算法流程如下: 求非0元素的逆序数的和,判断是否有解将开始状态放到节点集,并设置访问标识位为true从节点集中取出h(x)g(x)最小的节点判断取出的节点的状态是不…

Redis - 订阅发布替换 Etcd 解决方案

为了减轻项目的中间件臃肿,由于我们项目本身就应用了 Redis,正好 Redis 的也具备订阅发布监听的特性,正好应对 Etcd 的功能,所以本次给大家讲解如何使用 Redis 消息订阅发布来替代 Etcd 的解决方案。接下来,我们先看 R…

linux之shell

一、是什么 Shell是一个由c语言编写的应用程序,它是用户使用 Linux 的桥梁。Shell 既是一种命令语言,又是一种程序设计语言 它连接了用户和Linux内核,让用户能够更加高效、安全、低成本地使用 Linux 内核 其本身并不是内核的一部分&#x…

Java实现自定义windows右键菜单

要添加Java应用程序到Windows桌面的右键菜单,可以按照以下步骤操作: 创建一个新的.reg文件,并在文本编辑器中打开它。 添加以下代码到.reg文件中,将名称和路径替换为您的Java应用程序的名称和路径。 Windows Registry Editor V…

虚拟化热添加技术在数据备份上的应用

虚拟化中的热添加技术主要是指:无需停止或中断虚拟机的情况下,在线添加物理资源(如硬盘、内存、CPU、网卡等)的技术。热添加技术也是相比物理机一个非常巨大的优势,其使得资源分配变得更加灵活。 虚拟化中的热添加技术…

SOP作业指导书系统如何帮助厂家实现数字化转型

SOP(Standard Operating Procedure,标准操作程序)电子作业操作手册的应用对于厂家实现数字化转型起着至关重要的作用。本文将探讨SOP电子作业操作手册如何帮助厂家实现数字化转型的重要性和优势。 首先,SOP作业指导书可以提高生产…

idea菜单栏任务栏放缩比例修改

在编辑自定义VM选项中增加 -Dide.ui.scale0.8 参数 Help -> Edit Custom VM Options

这家提供数据闭环完整链路的企业,已拿下多家头部主机厂定点

“BEV感知数据闭环”已经成为新一代自动驾驶系统的核心架构。 进入2023年,小鹏、理想、阿维塔、智己、华为问界等汽车品牌正在全力推动从高速NOA到城区NOA的升级。在这一过程当中,如何利用高效的算力支撑、完善的算法模型、大量有效的数据形成闭环&…

Ubuntu部署OpenStack踩坑指南:还要看系统版本?

正文共:1515 字 12 图,预估阅读时间:2 分钟 到目前为止,我对OpenStack还不太了解,只知道OpenStack本身是一个云管理平台(什么是OpenStack?)。那作为云管理平台,我能想到最…

解决网络编程中的EOF违反协议问题:requests库与SSL错误案例分析

1. 问题背景 近期,一个用户在使用requests库进行网络编程时遭遇到了一个不寻常的问题,涉及SSL错误,并提示错误消息为SSLError(SSLEOFError(8, uEOF occurred in violation of protocol (_ssl.c:661)),))。该用户表示已经采取了多种方法来解决…

【深度学习实验】网络优化与正则化(五):数据预处理详解——标准化、归一化、白化、去除异常值、处理缺失值

文章目录 一、实验介绍二、实验环境1. 配置虚拟环境2. 库版本介绍 三、优化算法0. 导入必要的库1. 随机梯度下降SGD算法a. PyTorch中的SGD优化器b. 使用SGD优化器的前馈神经网络 2.随机梯度下降的改进方法a. 学习率调整b. 梯度估计修正 3. 梯度估计修正:动量法Momen…

【文件包含】phpmyadmin 文件包含(CVE-2014-8959)

1.1漏洞描述 漏洞编号CVE-2014-8959漏洞类型文件包含漏洞等级高危漏洞环境Windows漏洞名称phpmyadmin 文件包含(CVE-2014-8959) 描述: phpMyAdmin是一套开源的、基于Web的MySQL数据库管理工具。其index.php中存在一处文件包含逻辑,通过二次编…