Lucas带你手撕机器学习——线性回归

什么是线性回归

线性回归是机器学习中的基础算法之一,用于预测一个连续的输出值。它假设输入特征与输出值之间的关系是线性关系,即目标变量是输入变量的线性组合。我们可以从代码实现的角度来学习线性回归,包括如何使用 Python 进行简单的线性回归模型构建、训练、和预测。

线性回归的直观理解

你可以把线性回归理解成“画一条线来预测未来”。假设你有一张散点图,每个点代表某个物品的重量和它的价格。你的目标是找到一条直线,能够尽可能准确地描述这些点之间的关系。

线性回归的工作原理

假设我们有一些数据点,每个点都有一个输入(如重量)和一个输出(如价格)。线性回归就是在这些点之间找到一条直线,使得这条线能够“最好”地描述这些数据点。

这条直线的公式是:

在这里插入图片描述

其中:

  • y:输出,即我们想要预测的值(例如,物品的价格)
  • x:输入特征(例如,物品的重量)
  • w:线的斜率,表示重量对价格的影响有多大
  • b:截距,表示当重量为 0 时,预测的价格是多少

线性回归的基本原理

线性回归的数学公式为:

在这里插入图片描述

其中:

  • y 是预测值(目标变量)
  • x1,x2,…,xn 是输入特征
  • w1,w2,…,wn 是特征对应的权重(回归系数)
  • b 是偏置项(截距)

如何找到“最好的”直线?

“最好的”直线是指那些经过这条直线的点尽可能接近数据点。为了衡量直线的好坏,我们需要一个方法来计算直线与数据点之间的差距。

误差的概念
  • 对于每个数据点,我们可以计算它的实际价格(真实值)和用这条直线预测出来的价格之间的差距,称为“误差”。
  • 比如说,某个物品的真实价格是 10 元,但通过直线预测出来的价格是 9 元,那么这个点的误差就是 10−9=1。
均方误差(Mean Squared Error,MSE)

为了让误差的计算更稳定,我们通常不直接使用误差,而是使用“均方误差”来衡量模型的好坏:

在这里插入图片描述

其中:

  • yi:第 i 个样本的真实值
  • yi^:第 i 个样本通过模型预测的值
  • N:样本数量

均方误差的作用就是将所有数据点的误差平方后取平均值,这样可以确保误差不会因为正负抵消。我们的目标是让这个均方误差尽可能小,意味着直线与数据点之间的差距最小。

训练模型

在实际训练过程中,我们会不断调整直线的斜率 w 和截距 b,直到找到使均方误差最小的那一组 w 和 b。这就意味着找到了“最好的”直线。

代码实现

使用 Scikit-Learn 实现****线性回归

我们可以使用 Scikit-Learn 库,它提供了非常简洁的接口来进行线性回归。下面是一个完整的示例代码:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# 生成一些模拟数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)  # 输入特征,100 个样本,1 个特征
y = 4 + 3 * X + np.random.randn(100, 1)  # 线性关系 y = 4 + 3x + 噪声

# 拆分数据为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建线性回归模型并进行训练
model = LinearRegression()
model.fit(X_train, y_train)

# 输出模型的系数和截距
print(f'权重(w): {model.coef_[0][0]}')
print(f'截距(b): {model.intercept_[0]}')

# 预测并计算均方误差
y_pred = model.predict(X_test)
mse = mean_squared_error(y_test, y_pred)
print(f'测试集上的 MSE: {mse}')

# 可视化结果
plt.scatter(X_test, y_test, color='blue', label='真实值')
plt.plot(X_test, y_pred, color='red', label='预测值', linewidth=2)
plt.xlabel('X')
plt.ylabel('y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

在这里插入图片描述

  1. 代码解释
  • 生成模拟数据: 生成了一些随机数据点 X和 y,其中 y=4 + 3X + 噪声,这样我们就有一个线性关系的示例数据。
  • 数据集拆分: 使用 train_test_split 将数据集拆分成训练集和测试集,80% 用于训练,20% 用于测试。
  • 训练模型: 使用 LinearRegression 类创建模型,并用训练集数据拟合模型。
  • 预测和评估: 使用测试集进行预测,计算预测值与真实值之间的均方误差(MSE)。
  • 结果可视化: 将真实值和预测结果在图中可视化,可以清楚地看到线性回归的拟合效果。

PyTorch 实现线性回归

为了更好地理解线性回归的原理,我们也可以使用 PyTorch 从头实现一个简单的线性回归模型:

import torch
import torch.nn as nn
import torch.optim as optim

# 生成模拟数据
torch.manual_seed(42)
X = torch.randn(100, 1) * 2
y = 4 + 3 * X + torch.randn(100, 1)

# 定义线性模型
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(1, 1)  # 输入 1 维,输出 1 维

    def forward(self, x):
        return self.linear(x)

# 创建模型、损失函数和优化器
model = LinearRegressionModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 1000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')

# 输出训练好的模型参数
[w, b] = model.parameters()
print(f'权重(w): {w.item()}')
print(f'截距(b): {b.item()}')

代码解释

  • 定义模型: 使用 nn.Module 定义了一个简单的线性模型,只包含一个线性层。
  • 定义损失函数和优化器: 选择均方误差作为损失函数(nn.MSELoss()),使用随机梯度下降(optim.SGD)优化模型。
  • 模型训练: 通过前向传播计算损失,通过反向传播计算梯度并更新模型参数。

总结

以上两种方法分别使用 Scikit-Learn 和 PyTorch 实现了线性回归模型。Scikit-Learn 的方式适合快速建模和测试,而 PyTorch 版本则更灵活,更适合理解深度学习模型的训练过程。掌握这些方法后,可以将它们应用于更复杂的模型和任务中。

感谢阅读!!我是正在澳洲深造的Lucas!!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/895092.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Thread的基本用法

创建线程 方法一 继承Thread类 继承 Thread 来创建一个线程类. class MyThread extends Thread {Overridepublic void run() {System.out.println("这里是线程运行的代码");} } 创建 MyThread 类的实例 MyThread t new MyThread(); 调用 start 方法启动线程 t…

《深空彼岸》TXT完整版下载,知轩藏书校对版!

【内容简介】:   浩瀚的宇宙中,一片星系的生灭,也不过是刹那的斑驳流光。仰望星空,总有种结局已注定的伤感,千百年后你我在哪里?家国,文明火光,地球,都不过是深空中的一…

2024 年 9 月区块链游戏研报:行业回暖,Telegram 游戏引发热潮

作者:Stella L (stellafootprint.network) 数据来源:Footprint Analytics Games Research Page 9 月份,区块链游戏代币的市场总值增长了 29.2%,达到 232 亿美元,日活跃用户(DAU)数量上升了 1…

C++进阶:AVL树实现

目录 一.AVL的概念 二.AVL的实现 2.1AVL树的结构 2.2AVL树的插入 2.2.1AVL树插入一个值的大概过程 2.2.2平衡因子更新 2.2.3插入节点及更新平衡因子的实现 2.3旋转 2.3.1旋转的原则 2.3.2右单旋 2.3.3右单旋的代码实现 2.3.4左单旋 2.3.5左单旋的代码实现 2.3.6…

腾讯云宝塔面板前后端项目发版

后端发版 1. 打开“网站”页面,找到java项目,点击状态暂停服务 2.打开“文件”页面,进入jar包目录,删除原有的jar包,上传新jar包 3. 再回到第一步中的网站页面,找到jar项目,启动项目即可 前端发…

项目一:3-8译码器的设计与实现(FPGA)

本文以Altera公司生产的Cyclone IV系列的EP4CE15F17C8为主芯片的CRD500开发板作为项目的硬件实现平台,并以Quarter 18.1和ModelSim为开发工具和仿真工具。 目录 一、3-8译码器工作原理 二、设计步骤 1、创建工程文件夹和编辑设计文件 (1)…

微信小程序上传组件封装uploadHelper2.0使用整理

一、uploadHelper2.0使用步骤说明 uploadHelper.js ---上传代码封装库 cos-wx-sdk-v5.min.js---腾讯云,对象存储封装库 第一步,下载组件代码,放置到自己的小程序项目中 第二步、 创建上传对象,执行选择图片/视频 var _this th…

【成长day】SuperPointSuperGlue(02): Superglue论文算法学习与对应源码解析

论文工作地址:https://psarlin.com/superglue/ 论文地址:https://arxiv.org/abs/1911.11763 讲解PPT:https://psarlin.com/superglue/doc/superglue_slides.pdf 论文源码:https://github.com/magicleap/SuperGluePretrainedNetwor…

WebRTC音频 03 - 实时通信框架

WebRTC音频01 - 设备管理 WebRTC音频 02 - Windows平台设备管理 WebRTC音频 03 - 实时通信框架(本文) WebRTC音频 04 - 关键类 WebRTC音频 05 - 音频采集编码 一、前言: 前面介绍了音频设备管理,并且以windows平台为例子,介绍了ADM相关的类…

【分立元件】方形贴片固定电阻器制造流程

方形贴片固定电阻器是怎么制造的呢?我们在文章【分立元件】电阻的基础知识中介绍到电阻器中的固定电阻器,其品种有贴片电阻器。 贴片电阻器如下所示&#

vuex的store应用

1.在pakage.json加一行 2.和main同级别加一个js文件 import Vue from vue import Vuex from vuexVue.use(Vuex)export default new Vuex.Store({state: {langFlag: new Date().getTime()},mutations: {setLangFlag(state) {state.langFlag new Date().getTime()}} })3.在mai…

Shiro框架——shiro的认证

基本使用 1.环境搭建 引入pom依赖 说明:Shiro获取权限相关信息可以通过数据库获取,也可以通过ini配置文件获取 这里演示从ini文件中获取。 在resources目录下创建ini文件注:这里等号左边的(如:zhangsan),就代表用户…

STM32--基于STM32F103C8T6的OV7670摄像头显示

本文介绍基于STM32F103C8T6实现的OV7670摄像头显示设计(完整资源及代码见文末链接) 一、简介 本文实现的功能:基于STM32F103C8T6实现的OV7670摄像头模组实时在2.2寸TFT彩屏上显示出来 所需硬件: STM32F103C8T6最小系统板、OV76…

学习docker第三弹------Docker镜像以及推送拉取镜像到阿里云公有仓库和私有仓库

docker目录 1 Docker镜像dockers镜像的进一步理解 2 Docker镜像commit操作实例案例内容是ubuntu安装vim 3 将本地镜像推送至阿里云4 将阿里云镜像下载到本地仓库5 后记 1 Docker镜像 镜像,是docker的三件套之一(镜像、容器、仓库)&#xff0…

大模型~合集14

我自己的原文哦~ https://blog.51cto.com/whaosoft/12286799 # Attention as an RNN Bengio等人新作:注意力可被视为RNN,新模型媲美Transformer,但超级省内 , 既能像 Transformer 一样并行训练,推理时内存需求又不随 token 数线性…

基于因果推理的强对流降水临近预报问题研究

我国地域辽阔,自然条件复杂,灾害性天气种类繁多,地区差异性大。雷雨大风、冰雹、短时强降水等强对流天气是造成经济损失、危害生命安全最严重的一类灾害性天气。由于强对流降水具有高强度、小空间尺度等特点,一直是气象预报领域的…

vue组件传值之$attrs

1.概述:$attrs用于实现当前组件的父组件,向当前组件的子组件通信(祖-》孙) 2.具体说明:$attrs是一个对象,包含所有父组件传入的标签属性。 注意:$attrs会自动排除props中声明的属性&#xff0…

矩阵系统哪家好~矩阵短视频运营~怎么矩阵OEM

一、引言 在当今的数字化时代,矩阵系统在众多领域中发挥着至关重要的作用,如视频监控、信号切换、自动化控制等。然而,如何判断一个矩阵系统是否好用成为了许多用户面临的问题。本文将从多个方面探讨矩阵系统好用与否的判断标准,希…

Python | Leetcode Python题解之第492题构造矩形

题目: 题解: class Solution:def constructRectangle(self, area: int) -> List[int]:w int(sqrt(area))while area % w:w - 1return [area // w, w]

QtCreator14调试Qt5.15出现 Launching Debugger 错误

1、问题描述 使用QtCreator14调试程序,Launching Debugger 显示红色,无法进入调试模式。 故障现象如下: 使能Debugger Log窗口,显示: 325^error,msg"Error while executing Python code." 不过&#xff…