机器学习 —— 深入剖析线性回归模型

一、线性回归模型简介

线性回归是机器学习中最为基础的模型之一,主要用于解决回归问题,即预测一个连续的数值。其核心思想是构建线性方程,描述自变量(特征)和因变量(目标值)之间的关系。简单来说,若有一个自变量 x x x 和一个因变量 y y y,简单线性回归模型可表示为: y = θ 0 + θ 1 x y = \theta_0 + \theta_1x y=θ0+θ1x,其中 θ 0 \theta_0 θ0 是截距, θ 1 \theta_1 θ1 是斜率,也被称为回归系数。通过这条直线,我们尝试让模型预测值尽可能接近真实值。

(一)多元线性回归

在实际应用中,数据往往具有多个特征,这就需要多元线性回归模型。假设我们有 n n n 个自变量 x 1 , x 2 , ⋯   , x n x_1, x_2, \cdots, x_n x1,x2,,xn,多元线性回归模型的表达式为: y = θ 0 + θ 1 x 1 + θ 2 x 2 + ⋯ + θ n x n y = \theta_0 + \theta_1x_1 + \theta_2x_2 + \cdots + \theta_nx_n y=θ0+θ1x1+θ2x2++θnxn。从几何角度理解,简单线性回归是在二维平面上找一条最佳拟合直线;而多元线性回归则是在更高维度空间中寻找一个超平面,使得所有数据点到这个超平面的距离之和最小。

例如,在预测房价时,房屋价格可能受到面积、房龄、房间数量、周边配套设施等多个因素影响,多元线性回归模型能够综合考虑这些因素,从而做出更准确的预测。

(二)岭回归

岭回归是一种改进的线性回归算法,也被称为 Tikhonov 正则化。在普通线性回归中,当特征数量较多且存在多重共线性(即某些特征之间存在较强的线性关系)时,计算正规方程中的 ( X T X ) − 1 (X^TX)^{-1} (XTX)1 可能会出现问题,导致模型不稳定,对训练数据的微小变化非常敏感,泛化能力差。

岭回归通过在损失函数中添加一个 L2 正则化项来解决这个问题。其损失函数变为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ ∑ j = 1 n θ j 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda\sum_{j = 1}^{n}\theta_j^2 J(θ)=i=1m(y(i)y^(i))2+λj=1nθj2,其中 λ \lambda λ 是正则化参数,用来控制正则化的强度。当 λ \lambda λ 越大时,对回归系数的约束越强,使得回归系数更倾向于收缩到 0,从而防止过拟合;当 λ \lambda λ 为 0 时,岭回归就退化为普通的线性回归。

岭回归的优势在于,它不仅能在一定程度上解决多重共线性问题,还能提高模型的泛化能力,使得模型在面对新数据时表现更加稳定。

(三)Lasso 回归

Lasso 回归,即 Least Absolute Shrinkage and Selection Operator,同样是一种用于线性回归的正则化方法。与岭回归不同,Lasso 回归在损失函数中添加的是 L1 正则化项,其损失函数为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ ∑ j = 1 n ∣ θ j ∣ J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda\sum_{j = 1}^{n}|\theta_j| J(θ)=i=1m(y(i)y^(i))2+λj=1nθj

L1 正则化的特点是它能够产生稀疏解,即可以自动筛选出对目标值影响较大的特征,将一些不重要的特征对应的系数直接压缩为 0,从而达到特征选择的目的。例如在基因数据分析中,数据维度极高,特征众多,Lasso 回归可以帮助我们从大量的基因特征中筛选出真正与疾病相关的基因,简化模型的同时提高解释性。

(四)弹性网络回归

弹性网络回归结合了岭回归和 Lasso 回归的优点,在损失函数中同时使用 L1 和 L2 正则化项,其损失函数表达式为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 + λ 1 ∑ j = 1 n ∣ θ j ∣ + λ 2 ∑ j = 1 n θ j 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 + \lambda_1\sum_{j = 1}^{n}|\theta_j| + \lambda_2\sum_{j = 1}^{n}\theta_j^2 J(θ)=i=1m(y(i)y^(i))2+λ1j=1nθj+λ2j=1nθj2 ,其中 λ 1 \lambda_1 λ1 λ 2 \lambda_2 λ2 分别是 L1 和 L2 正则化项的系数。

这种方法既可以像 Lasso 回归一样进行特征选择,又能像岭回归一样处理多重共线性问题。在一些复杂的数据场景中,比如图像识别中,数据既存在大量冗余特征,又有特征间的相关性,弹性网络回归能够发挥其综合优势,平衡模型的复杂度和性能。

二、线性回归模型的原理

线性回归模型的目标是找到一组最优的回归系数 θ = [ θ 0 , θ 1 , ⋯   , θ n ] \theta = [\theta_0, \theta_1, \cdots, \theta_n] θ=[θ0,θ1,,θn],使得模型预测值与真实值之间的误差最小。通常,我们使用最小二乘法来衡量这种误差。最小二乘法的目标函数(也称为损失函数)为: J ( θ ) = ∑ i = 1 m ( y ( i ) − y ^ ( i ) ) 2 J(\theta) = \sum_{i = 1}^{m}(y^{(i)} - \hat{y}^{(i)})^2 J(θ)=i=1m(y(i)y^(i))2,其中 m m m 是样本数量, y ( i ) y^{(i)} y(i) 是第 i i i 个样本的真实值, y ^ ( i ) \hat{y}^{(i)} y^(i) 是第 i i i 个样本的预测值, y ^ ( i ) = θ 0 + θ 1 x 1 ( i ) + θ 2 x 2 ( i ) + ⋯ + θ n x n ( i ) \hat{y}^{(i)} = \theta_0 + \theta_1x_1^{(i)} + \theta_2x_2^{(i)} + \cdots + \theta_nx_n^{(i)} y^(i)=θ0+θ1x1(i)+θ2x2(i)++θnxn(i)

为了找到使损失函数最小的 θ \theta θ,我们可以对 J ( θ ) J(\theta) J(θ) 求关于 θ \theta θ 的导数,并令导数为零,从而得到正规方程: θ = ( X T X ) − 1 X T y \theta = (X^TX)^{-1}X^Ty θ=(XTX)1XTy,其中 X X X 是特征矩阵,每一行代表一个样本,每一列代表一个特征, y y y 是目标值向量。但正如前面提到的,当 X T X X^TX XTX 接近奇异矩阵(即不可逆)时,求解正规方程会出现问题,这也是岭回归、Lasso 回归和弹性网络回归等方法出现的原因之一。

三、线性回归模型的优化方法

除了使用正规方程求解回归系数外,我们还可以使用梯度下降法来优化损失函数。梯度下降法是一种迭代的优化算法,它通过不断地沿着损失函数的负梯度方向更新回归系数,来逐步减小损失函数的值。

具体来说,对于损失函数 J ( θ ) J(\theta) J(θ),其梯度为: ∇ J ( θ ) = 2 m X T ( X θ − y ) \nabla J(\theta) = \frac{2}{m}X^T(X\theta - y) J(θ)=m2XT(y)。在每次迭代中,我们按照以下公式更新回归系数: θ = θ − α ∇ J ( θ ) \theta = \theta - \alpha\nabla J(\theta) θ=θαJ(θ),其中 α \alpha α 是学习率,它控制着每次更新的步长。学习率的选择非常关键,如果学习率过大,可能会导致模型无法收敛,甚至发散;如果学习率过小,模型收敛速度会非常慢,需要更多的迭代次数。

四、Python 代码实现

下面我们使用 Python 来实现一个简单的线性回归模型,包括普通线性回归、多元线性回归、岭回归、Lasso 回归和弹性网络回归,并对比它们的效果。首先,我们需要导入必要的库,如numpymatplotlibsklearn中的相关模块。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
from sklearn.preprocessing import PolynomialFeatures
from sklearn.model_selection import GridSearchCV

# 生成一些随机数据
np.random.seed(0)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)

# 普通线性回归
lin_reg = LinearRegression()
lin_reg.fit(X, y)
y_lin_pred = lin_reg.predict(X)

# 多元线性回归(添加一个多项式特征)
poly_features = PolynomialFeatures(degree=3, include_bias=False)  # 修改多项式次数为3
X_poly = poly_features.fit_transform(X)
lin_reg_2 = LinearRegression()
lin_reg_2.fit(X_poly, y)
y_poly_pred = lin_reg_2.predict(X_poly)

# 岭回归
ridge_reg = Ridge(alpha=0.1)
ridge_reg.fit(X, y)
y_ridge_pred = ridge_reg.predict(X)

# Lasso回归
lasso_reg = Lasso(alpha=0.1)
lasso_reg.fit(X, y)
y_lasso_pred = lasso_reg.predict(X)

# 弹性网络回归
elastic_net_reg = ElasticNet(alpha=0.1, l1_ratio=0.5)
elastic_net_reg.fit(X, y)
y_elastic_pred = elastic_net_reg.predict(X)

# 使用网格搜索优化岭回归和Lasso回归的超参数
ridge_grid = GridSearchCV(Ridge(), param_grid={'alpha': [0.01, 0.1, 1, 10, 100]})
ridge_grid.fit(X, y)
best_ridge = ridge_grid.best_estimator_
y_ridge_best_pred = best_ridge.predict(X)

lasso_grid = GridSearchCV(Lasso(), param_grid={'alpha': [0.01, 0.1, 1, 10, 100]})
lasso_grid.fit(X, y)
best_lasso = lasso_grid.best_estimator_
y_lasso_best_pred = best_lasso.predict(X)

# 绘制数据和拟合直线
plt.figure(figsize=(15, 8))

plt.subplot(2, 3, 1)
plt.plot(X, y, "b.")
plt.plot(X, y_lin_pred, "r-", linewidth=2, label='Linear Regression')
plt.title('Linear Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.subplot(2, 3, 2)
plt.plot(X, y, "b.")
X_sorted = np.sort(X, axis=0)
X_poly_sorted = poly_features.fit_transform(X_sorted)
plt.plot(X_sorted, lin_reg_2.predict(X_poly_sorted), "g-", linewidth=2, label='Polynomial Linear Regression (Degree=3)')
plt.title('Polynomial Linear Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.subplot(2, 3, 3)
plt.plot(X, y, "b.")
plt.plot(X, y_ridge_pred, "m-", linewidth=2, label='Ridge Regression (alpha=0.1)')
plt.title('Ridge Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.subplot(2, 3, 4)
plt.plot(X, y, "b.")
plt.plot(X, y_lasso_pred, "c-", linewidth=2, label='Lasso Regression (alpha=0.1)')
plt.title('Lasso Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.subplot(2, 3, 5)
plt.plot(X, y, "b.")
plt.plot(X, y_elastic_pred, "y", linewidth=2, label='Elastic Net Regression (alpha=0.1, l1_ratio=0.5)')
plt.title('Elastic Net Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.subplot(2, 3, 6)
plt.plot(X, y, "b.")
plt.plot(X, y_ridge_best_pred, "k", linewidth=2, label='Optimized Ridge Regression')
plt.plot(X, y_lasso_best_pred, "saddlebrown", linewidth=2, label='Optimized Lasso Regression') 
plt.title('Optimized Ridge and Lasso Regression')
plt.xlabel("$x_1$", fontsize=18)
plt.ylabel("$y$", rotation=0, fontsize=18)
plt.legend()

plt.tight_layout()
plt.show()

在上述代码中,我们首先生成了一些随机数据。然后分别使用LinearRegression类实现普通线性回归和多元线性回归(通过添加多项式特征实现),使用Ridge类实现岭回归,使用Lasso类实现 Lasso 回归,使用ElasticNet类实现弹性网络回归。最后绘制出数据点和各个模型的拟合直线,以便直观对比它们的效果。

五、总结与模型选用建议

不同的线性回归模型各有特点,在实际应用中需要根据具体情况选择合适的模型。

✨简单线性回归模型形式最为简单,仅包含一个自变量和一个因变量 ,适用于特征与目标值之间呈现明显线性关系,且数据特征单一的场景,比如根据时间预测某一产品的销量变化趋势。

🎈多元线性回归在简单线性回归基础上拓展到多个自变量,能处理更复杂的数据关系,像预测房价时综合考虑多个影响因素。但当数据存在多重共线性时,普通的多元线性回归可能导致模型不稳定。

🎨岭回归通过 L2 正则化项,在一定程度上缓解多重共线性问题,同时提升模型泛化能力。若数据特征众多且存在共线性,又希望保留所有特征,岭回归是不错的选择,如金融风险评估中,众多经济指标相互关联,岭回归可有效处理。

🍫Lasso 回归利用 L1 正则化产生稀疏解,自动筛选重要特征,实现特征选择,在高维数据场景优势明显,如基因数据分析,能从海量基因特征中找出关键特征。

🧆弹性网络回归结合了 L1 和 L2 正则化,兼具特征选择和处理共线性的能力,当数据既存在大量冗余特征,又有特征间相关性时,弹性网络回归能平衡模型复杂度与性能,例如图像识别领域。

在选择线性回归模型时,首先要分析数据特征,判断是否存在多重共线性、数据维度高低等。若数据简单且特征少,普通线性回归即可;若特征多且存在共线性,可考虑岭回归;若需特征选择,Lasso 回归或弹性网络回归更合适。还可以通过交叉验证等方法,比较不同模型在训练集和验证集上的性能指标,如均方误差(MSE)、决定系数(R²)等,最终选择性能最优的模型。 不断实践和尝试不同模型,才能在实际应用中发挥线性回归模型的最大价值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/966193.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

网络安全威胁框架与入侵分析模型概述

引言 “网络安全攻防的本质是人与人之间的对抗,每一次入侵背后都有一个实体(个人或组织)”。这一经典观点概括了网络攻防的深层本质。无论是APT(高级持续性威胁)攻击、零日漏洞利用,还是简单的钓鱼攻击&am…

Redis企业开发实战(三)——点评项目之优惠券秒杀

目录 一、全局唯一ID (一)概述 (二)全局ID生成器 (三)全局唯一ID生成策略 1. UUID (Universally Unique Identifier) 2. 雪花算法(Snowflake) 3. 数据库自增 4. Redis INCR/INCRBY 5.总结 (四)Redis实现全局唯一ID 1.工具类 2.测试类 3…

Verilog代码实例

Verilog语言学习! 文章目录 目录 文章目录 前言 一、基本逻辑门代码设计和仿真 1.1 反相器 1.2 与非门 1.3 四位与非门 二、组合逻辑代码设计和仿真 2.1 二选一逻辑 2.2 case语句实现多路选择逻辑 2.3 补码转换 2.4 7段数码管译码器 三、时序逻辑代码设计和仿真 3.1…

排序算法--基数排序

核心思想是按位排序(低位到高位)。适用于定长的整数或字符串,如例如:手机号、身份证号排序。按数据的每一位从低位到高位(或相反)依次排序,每次排序使用稳定的算法(如计数排序&#…

图形化界面MySQL(MySQL)(超级详细)

目录 1.官网地址 1.1在Linux直接点击NO thanks…? 1.2任何远端登录,再把jj数据库给授权 1.3建立新用户 优点和好处 示例代码(MySQL Workbench) 示例代码(phpMyAdmin) 总结 图形化界面 MySQL 工具大全及其功能…

C++ 使用CURL开源库实现Http/Https的get/post请求进行字串和文件传输

CURL开源库介绍 CURL 是一个功能强大的开源库,用于在各种平台上进行网络数据传输。它支持众多的网络协议,像 HTTP、HTTPS、FTP、SMTP 等,能让开发者方便地在程序里实现与远程服务器的通信。 CURL 可以在 Windows、Linux、macOS 等多种操作系…

mapbox进阶,添加绘图扩展插件,绘制圆形

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️MapboxDraw 绘图控件二、🍀添加绘图扩…

网络工程师 (24)数据封装与解封装

一、数据封装 数据封装是指将协议数据单元(PDU)封装在一组协议头和尾中的过程。在OSI 7层参考模型中,数据从应用层开始,逐层向下封装,直到物理层。每一层都会为其PDU添加相应的协议头和尾,以包含必要的通信…

OSPF基础(3):区域划分

OSPF的区域划分 1、区域产生背景 路由器在同一个区域中泛洪LSA。为了确保每台路由器都拥有对网络拓扑的一致认知,LSDB需要在区域内进行同步。OSPF域如果仅有一个区域,随着网络规模越来越大,OSPF路由器的数量越来越多,这将导致诸…

C++----继承

一、继承的基本概念 本质:代码复用类关系建模(是多态的基础) class Person { /*...*/ }; class Student : public Person { /*...*/ }; // public继承 派生类继承基类成员(数据方法),可以通过监视窗口检…

【DeepSeek】DeepSeek小模型蒸馏与本地部署深度解析DeepSeek小模型蒸馏与本地部署深度解析

一、引言与背景 在人工智能领域,大型语言模型(LLM)如DeepSeek以其卓越的自然语言理解和生成能力,推动了众多应用场景的发展。然而,大型模型的高昂计算和存储成本,以及潜在的数据隐私风险,限制了…

ZZNUOJ(C/C++)基础练习1081——1090(详解版)

目录 1081 : n个数求和 (多实例测试) C C 1082 : 敲7(多实例测试) C C 1083 : 数值统计(多实例测试) C C 1084 : 计算两点间的距离(多实例测试) C C 1085 : 求奇数的乘积(多实例测试…

axios 发起 post请求 json 需要传入数据格式

• 1. axios 发起 post请求 json 传入数据格式 • 2. axios get请求 1. axios 发起 post请求 json 传入数据格式 使用 axios 发起 POST 请求并以 JSON 格式传递数据是前端开发中常见的操作。 下面是一个简单的示例,展示如何使用 axios 向服务器发送包含 JSON 数…

硬盘接入电脑提示格式化?是什么原因?怎么解决?

有时候,当你将硬盘接入电脑时,看到系统弹出“使用驱动器中的光盘之前需要将其格式化”的提示,肯定会感到十分困惑和焦虑。这种情况不仅让人担心数据丢失,也可能影响正常使用。为什么硬盘会突然要求格式化?是硬盘出了问…

使用Python实现PDF与SVG相互转换

目录 使用工具 使用Python将SVG转换为PDF 使用Python将SVG添加到现有PDF中 使用Python将PDF转换为SVG 使用Python将PDF的特定页面转换为SVG SVG(可缩放矢量图形)和PDF(便携式文档格式)是两种常见且广泛使用的文件格式。SVG是…

【大数据技术】搭建完全分布式高可用大数据集群(Kafka)

搭建完全分布式高可用大数据集群(Kafka) kafka_2.13-3.9.0.tgz注:请在阅读本篇文章前,将以上资源下载下来。 写在前面 本文主要介绍搭建完全分布式高可用集群 Kafka 的详细步骤。 注意: 统一约定将软件安装包存放于虚拟机的/software目录下,软件安装至/opt目录下。 安…

【C++篇】C++11新特性总结1

目录 1,C11的发展历史 2,列表初始化 2.1C98传统的{} 2.2,C11中的{} 2.3,C11中的std::initializer_list 3,右值引用和移动语义 3.1,左值和右值 3.2,左值引用和右值引用 3.3,…

Redis --- 使用HyperLogLog实现UV(访客量)

UV 和 PV 是网站或应用数据分析中的常用指标,用于衡量用户活跃度和页面访问量。 UV (Unique Visitor 独立访客): 指的是在一定时间内访问过网站或应用的独立用户数量。通常根据用户的 IP 地址、Cookies 或用户 ID 等来唯一标识一个用户。示例&#xff1…

【机器学习案列】糖尿病风险可视化及预测

🧑 博主简介:曾任某智慧城市类企业算法总监,目前在美国市场的物流公司从事高级算法工程师一职,深耕人工智能领域,精通python数据挖掘、可视化、机器学习等,发表过AI相关的专利并多次在AI类比赛中获奖。CSDN…

单片机之基本元器件的工作原理

一、二极管 二极管的工作原理 二极管是一种由P型半导体和N型半导体结合形成的PN结器件,具有单向导电性。 1. PN结形成 P型半导体:掺入三价元素,形成空穴作为多数载流子。N型半导体:掺入五价元素,形成自由电子作为多…