【机器学习】多元线性回归

在实际应用中,许多问题都包含多个特征(输入变量),而不仅仅是单个输入变量。多元线性回归是线性回归的扩展,它能够处理多个输入特征并建立它们与目标变量的线性关系。本教程将系统性推演多元线性回归,包括向量化处理、特征放缩、梯度下降的收敛性和学习率选择等,并使用numpy实现。最后,我们会通过sklearn快速实现多元线性回归模型。

多元线性回归模型简介

多元线性回归的模型公式为:
y = X ⋅ w + b y = X \cdot w + b y=Xw+b
其中:

  • ( y ) 是预测值(输出),
  • ( X ) 是输入特征矩阵(每一行表示一个样本,每一列表示一个特征),
  • ( w ) 是权重向量,
  • ( b ) 是偏置项(通常看作一个常数项)。

模型的目标是找到最优的权重向量 ( w ) 和偏置 ( b ),使得预测值与真实值的差异最小。

向量化处理

在机器学习中,向量化是一种通过矩阵运算来加速模型训练的方式。我们将模型的多个样本和特征表示为矩阵形式,这样能够利用线性代数库(如 numpy)中的优化操作来加速计算。

多元线性回归的预测可以用向量化表示为:

Y pred = X ⋅ w + b Y_{\text{pred}} = X \cdot w + b Ypred=Xw+b
其中:

  • ( X ) 是 ( n \times m ) 的矩阵,表示 ( n ) 个样本的 ( m ) 个特征,
  • ( w ) 是 ( m \times 1 ) 的权重向量,
  • ( b ) 是常数偏置项。

损失函数

我们依然使用**均方误差(MSE, Mean Squared Error)**作为损失函数,用来衡量模型预测值与真实值之间的差异。其公式为:

M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y_i})^2 MSE=n1i=1n(yiyi^)2

特征放缩(Feature Scaling)

在多元线性回归中,特征的尺度对模型训练的影响较大。如果某些特征的值范围过大,会导致它们在梯度下降中主导权重更新,导致收敛速度变慢甚至无法收敛。因此,通常会对特征进行标准化或归一化。

  1. 归一化(Normalization):将特征缩放到[0, 1]区间,公式如下:

X ′ = X − X min X max − X min X' = \frac{X - X_{\text{min}}}{X_{\text{max}} - X_{\text{min}}} X=XmaxXminXXmin

  1. 标准化(Standardization):将特征的均值归零,标准差为1,公式如下:

X ′ = X − μ σ X' = \frac{X - \mu}{\sigma} X=σXμ

梯度下降的收敛性

在使用梯度下降优化模型参数时,梯度下降的收敛性取决于学习率的选择和损失函数的性质。如果学习率过大,梯度下降可能会在更新中超出最优值;学习率过小,收敛速度会非常慢。通常我们需要通过实验选择合适的学习率。

代码实现:多元线性回归模型

接下来,我们使用numpy从头实现多元线性回归模型。

数据准备

我们生成一个包含多个特征的数据集。

import numpy as np
import matplotlib.pyplot as plt

# 生成多元数据
np.random.seed(42)
X = 2 * np.random.rand(100, 3)  # 生成 100 个样本,3 个特征
true_w = np.array([3, 4, 5])
y = X.dot(true_w) + 6 + np.random.randn(100)  # y = 3x1 + 4x2 + 5x3 + 6 + 噪声

# 查看数据维度
print("X shape:", X.shape)
print("y shape:", y.shape)

损失函数

实现均方误差损失函数:

def mse_loss(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

梯度计算

实现梯度计算函数:

def compute_gradients(X, y, w, b):
    n = len(y)
    y_pred = X.dot(w) + b
    dw = (2/n) * X.T.dot(y_pred - y)
    db = (2/n) * np.sum(y_pred - y)
    return dw, db

梯度下降

我们定义梯度下降函数,更新权重和偏置:

def gradient_descent(X, y, w, b, learning_rate, iterations):
    for i in range(iterations):
        dw, db = compute_gradients(X, y, w, b)
        w -= learning_rate * dw
        b -= learning_rate * db
        
        if i % 100 == 0:
            y_pred = X.dot(w) + b
            loss = mse_loss(y, y_pred)
            print(f"Iteration {i}: Loss = {loss}")
            
    return w, b

特征放缩

我们可以通过 StandardScaler 对特征进行标准化。

def standardize(X):
    mean = np.mean(X, axis=0)
    std = np.std(X, axis=0)
    X_scaled = (X - mean) / std
    return X_scaled

模型训练

初始化参数并训练模型:

# 初始化参数
w = np.random.randn(3)
b = np.random.randn(1)

# 特征标准化
X_scaled = standardize(X)

# 超参数设置
learning_rate = 0.01
iterations = 1000

# 训练模型
w_trained, b_trained = gradient_descent(X_scaled, y, w, b, learning_rate, iterations)
print(f"Trained weights: {w_trained}, Trained bias: {b_trained}")

可视化模型

对于多元回归,权重无法直接用图像展示,但可以展示损失值的收敛曲线:

# 绘制损失曲线
losses = []
for i in range(1000):
    dw, db = compute_gradients(X_scaled, y, w, b)
    w -= learning_rate * dw
    b -= learning_rate * db
    y_pred = X_scaled.dot(w) + b
    loss = mse_loss(y, y_pred)
    losses.append(loss)

plt.plot(losses)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.title("Loss Curve")
plt.show()

使用 sklearn 实现多元线性回归

最后,我们使用sklearn快速实现多元线性回归。

from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import StandardScaler

# 特征标准化
scaler = StandardScaler()
X_scaled_sklearn = scaler.fit_transform(X)

# 训练模型
lin_reg = LinearRegression()
lin_reg.fit(X_scaled_sklearn, y)

# 输出权重和偏置
print(f"Sklearn Trained weights: {lin_reg.coef_}, Sklearn Trained bias: {lin_reg.intercept_}")

总结

在本教程中,我们深入推演了多元线性回归的基本原理,从向量化、特征放缩、梯度下降收敛性到学习率选择,并使用numpy实现了完整的多元线性回归模型。通过sklearn的实现,我们验证了结果并加速了训练流程。希望这篇教程能帮助你进一步理解多元线性回归模型的核心概念。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/902748.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android Junit 单元测试 | 依赖配置和编译报错解决

问题 为什么在依赖中添加了testImplement在build APK的时候还是会报错?是因为没有识别到test文件夹是test源代码路径吗? 最常见的配置有: implementation - 所有源代码集(包括test源代码集)中都有该依赖库.testImplementation - 依赖关系仅在test源代码…

前端代码分享--爱心

给对象写的&#xff0c;顺便源码给大家分享一下 就是简单的htmlcssjs&#xff0c;不复杂 xin1.html <!DOCTYPE html> <html lang"zh-CN"> <head> <meta charset"UTF-8"> <title>写你自己的</title> <lin…

buildroot制作自己的软件包(可以理解为应用程序)

以helloworld为例记录使用步骤 一&#xff1a;书写自己的源程序以及Makefile helloworld.c #include <stdio.h>int main(int argc, char **argv) {printf("hello world\r\n");return 0; }Makefile all: helloworldhelloworld: helloworld.o$(CC) -o hellow…

关于嵌入式学习的一些短浅经验

一、写在前面 感谢在 10.23&#xff0c;各位大佬对我进行的模拟面试&#xff0c;我也发现了我对知识的不熟练的部分&#xff0c;比如 IPC 方法和线程同步方法的知识。模拟面试第四期-已经拿到大厂 OFFER 的研究生大佬-LINUX 卷到飞起_哔哩哔哩_bilibili 然后&#xff0c;沈阳…

OpenRTP 传输增加OpenRTPServer

开源地址 最近增加了OpenRTPServer&#xff0c; 已经修改完成一版放在了目录下&#xff0c;window和linux下编译都成功了&#xff0c;不过由于修改代码CMakefile 需要修改&#xff0c;先放放 OpenRTP开源地址 vlc得纠错传输方式 我发现我代码写错以后&#xff0c;vlc 依然能…

重要:民族共同体精品课格式说明

铸牢中华民族共同体意识精品课以微课形式呈现&#xff0c;包括微课 视频、教学设计讲义、课件等。 微课视频 微课视频应采用“教师讲解多媒体大屏”的形式&#xff0c;适当呈现授课教师画面&#xff0c;增强教学的交互性和画面的可视性。单个微课视频时长&#xff1a;高校专题…

【已解决】cannot import name ‘Literal‘ from ‘typing‘

问题描述 在用vscode进行debug的时候&#xff0c;报错cannot import name Literal from typing 解决方法 方法一&#xff1a;升级Python版本到3.8以上 我的python版本是3.7&#xff0c;但由于环境都配好了&#xff0c;升级太麻烦&#xff0c;没采用该方法 方法二&#xff1…

C++和Java该如何进行选择?

曾经的自己与许多C程序员都有着一样的盲目自信&#xff1a;认为掌握了C&#xff0c;在去学习Java上手会容易很多。 到底是谁给了你这种勇气和自信&#xff1f; 很多人经常会说&#xff0c;Java这种通过虚拟机运行的语言&#xff0c;虚拟机本身就是C开发的&#xff0c;根本就没…

Java 多线程(九)—— JUC 常见组件 与 线程安全的集合类

Callable 与 FutureTask Callable 接口和 Runnable 接口是并列关系&#xff0c;都是用来给线程提供任务的&#xff0c;只不过 Callable 接口的任务可以带有返回值。 但是 Callable 接口创建的任务不能直接传入 Thread 里面&#xff0c;这也是为了 解耦合&#xff0c;我们可以使…

pdf合并,这4款好用软件分分钟解决问题!

PDF作为一种跨平台、不易被篡改的文档格式&#xff0c;广泛应用于工作、学习和日常生活中。然而&#xff0c;当面对多个PDF文件需要合并成一个时&#xff0c;繁琐的手动操作往往让人头疼不已。别担心&#xff0c;今天就给大家安利4款超实用的PDF合并软件&#xff0c;它们不仅操…

c++二级指针

如果要通过函数改变一个指针的值&#xff0c;要往函数中传入指针的指针 如果要通过函数改变一个变量的值&#xff0c;那就要往函数中传入这个变量的地址 改变a的值和b的值 #include <iostream>using namespace std;void swap(int* a, int* b) {int temp *a;*a *b;*b …

pyvideotrans 最佳AI翻译软件

文章目录 体验视频翻译配音工具主要用途和功能预打包版本(仅win10/win11可用&#xff0c;MacOS/Linux系统使用源码部署)MacOS源码部署Linux 源码部署Window10/11 源码部署源码部署问题说明使用教程和文档语音识别模型:视频教程(第三方)软件预览截图相关联项目致谢 体验 不错&a…

根据Redis漏洞通知的整改修复过程

一、收到通知&#xff1a; 二、查看本校“宝山商城&#xff08;教学&#xff09;”已安装的Redis版本号 对照影响范围的版本号&#xff0c;在其内&#xff0c;所以需要升级Redis版本。 三、升级centos中的Redis版本 在Cent0S系统中&#xff0c;如果我们需要升级Redis版本&…

C++,STL 048(24.10.25)

内容 set容器对内置数据类型、自定义数据类型指定排序规则。 运行代码 &#xff08;1&#xff09;内置数据类型 #include <iostream> #include <set>using namespace std;// set容器默认排序规则为升序&#xff08;从小到大&#xff09;&#xff0c;可以通过仿函…

若依框架篇-若依集成 X-File-Storage 框架(实现图片上传阿里云 OSS 服务器)、EasyExcel 框架(实现 Excel 数据批量导入功能)

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 实现使用 Excel 文件批量导入 1.1 导入功能的前端具体实现 1.2 导入功能的后端具体实现 1.3 使用 EasyExcel 框架实现 Excel 读、写功能 1.4 将 Easy Excel 集成到…

基于SSM+微信小程序考试的管理系统(考试1)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于SSM微信小程序考试的管理系统实现了管理员及用户。 1、管理员功能有个人中心&#xff0c;用户管理&#xff0c;考试资料管理&#xff0c;用户交流管理&#xff0c;试卷管理&#xff…

新能源汽车充电设施在储充电站的应用

0引言 全球能源和环境问题促使新能源汽车受到关注&#xff0c;但其推广受充电设施和能源供应限制。光伏站、储能站和电动汽车充放电站作为可再生能源利用和储存方式&#xff0c;具有巨大潜力。本研究旨在探索新能源汽车充电设施与这些站点的融合模式&#xff0c;以支持新能源汽…

【ROS GitHub使用】

提示&#xff1a;环境配置为Ubuntu20.04&ROS Noetic 文章目录 前言一、创建工作空间目录二、尝试从GitHub上下载一个源码包&#xff0c;对它进行编译&#xff0c;运行这个源码包1.打开script文件夹&#xff0c;右键文件夹空白区域&#xff0c;选择在中端中打开&#xff1b;…

OceanBase 安全体系解析之身份鉴别

本文作者&#xff1a;金长龙爱可生测试工程师&#xff0c;负责 DMP 产品的测试工作。 本文以MySQL为参照&#xff0c;详细阐述了OceanBase 在MySQL模式下的安全体系中&#xff0c;身份鉴别的能力&#xff0c;涵盖了身份鉴别机制、用户名的构成规则、密码的复杂度&#xff0c;以…

ctfshow(66->70)--RCE/命令执行漏洞--禁用命令执行函数

Web66 源代码&#xff1a; if(isset($_POST[c])){$c $_POST[c];eval($c); }else{highlight_file(__FILE__); }代码审计&#xff1a; POST传参c&#xff0c;eval进行代码执行。 思路&#xff1a; 由于题目过滤了命令执行函数&#xff0c;所以使用其他方法进行RCE。 先使用c…