16.线性回归代码实现

线性回归的实操与理解

介绍

线性回归是一种广泛应用的统计方法,用于建模一个或多个自变量(特征)与因变量(目标)之间的线性关系。在机器学习和数据科学中,线性回归是许多入门者的第一个模型,它提供了对监督学习问题的基础理解。本文将介绍线性回归的基本概念,并通过Python和PyTorch库来实操线性回归模型,深入理解其训练和预测过程。

线性回归的基本概念

线性回归假设目标变量(y)是输入变量(X)的线性组合,并可以通过最小二乘法来估计模型的参数(权重w和偏置b)。数学上,线性回归模型可以表示为:

y=w1​x1​+w2​x2​+…+wn​xn​+b

或者更一般地,使用矩阵形式表示:

y=XW+b

其中,X 是特征矩阵,W 是权重向量,b 是偏置项。

实操:使用PyTorch实现线性回归

1. 导入必要的库

首先,我们需要导入PyTorch和其他必要的库。

import torch  
import torch.nn as nn  
import torch.optim as optim  
import numpy as np  
import matplotlib.pyplot as plt


2. 生成模拟数据

为了演示线性回归,我们将生成一些模拟数据。

# 设置随机种子  
torch.manual_seed(0)  
np.random.seed(0)  
  
# 生成数据  
n_samples = 100  
x = torch.randn(n_samples, 1) * 10  # 输入数据  
w_true = 2  
b_true = 1  
y = x * w_true + b_true + torch.randn(n_samples, 1) * 0.5  # 真实标签


3. 定义线性回归模型

使用PyTorch的nn.Module来定义线性回归模型。

class LinearRegressionModel(nn.Module):  
    def __init__(self, input_dim=1, output_dim=1):  
        super(LinearRegressionModel, self).__init__()  
        self.linear = nn.Linear(input_dim, output_dim)  
  
    def forward(self, x):  
        out = self.linear(x)  
        return out


4. 初始化模型和优化器

实例化模型,并定义损失函数和优化器。

# 初始化模型  
model = LinearRegressionModel()  
  
# 定义损失函数和优化器  
criterion = nn.MSELoss()  
optimizer = optim.SGD(model.parameters(), lr=0.01)


5. 训练模型

通过迭代训练数据来训练模型。

# 训练模型  
num_epochs = 1000  
for epoch in range(num_epochs):  
    # 前向传播  
    outputs = model(x)  
    loss = criterion(outputs, y)  
      
    # 反向传播和优化  
    optimizer.zero_grad()  # 清空梯度  
    loss.backward()  # 反向传播计算梯度  
    optimizer.step()  # 更新参数  
  
    if (epoch+1) % 100 == 0:  
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))


6. 评估模型

在训练完成后,我们可以评估模型的性能。但在这个简单的例子中,我们主要关注于模型是否能学习到正确的权重和偏置。

7. 可视化结果

我们可以将预测结果和真实数据可视化出来。

# 提取训练后的参数  
w, b = model.linear.weight.item(), model.linear.bias.item()  
print('w = {}, b = {}'.format(w, b))  
  
# 可视化结果  
predicted = model(x).detach().numpy()  
plt.scatter(x.numpy(), y.numpy(), color='blue', label='True data')  
plt.plot(x.numpy(), predicted, color='red', linewidth=2, label='Predicted data')  
plt.legend()  
plt.show()


总结

通过本文的实操,我们深入理解了线性回归的基本原理和其在PyTorch中的实现方式。我们生成了模拟数据,定义了线性回归模型,并使用随机梯度下降优化器来训练模型。通过可视化结果,我们可以看到模型能够很好地拟合生成的数据,并且学习到的权重和偏置与真实

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/641490.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】机器学习基础概念与初步探索

❀机器学习 📒1. 引言📒2. 机器学习概述📒3. 机器学习基础概念🎉2.1 机器学习的分类🎉2.2 数据预处理🌈数据清洗与整合🌈 特征选择和特征工程🌈数据标准化与归一化 📒4. …

Android Studio 所有历史版本下载

一、官网链接 https://developer.android.google.cn/studio/archive 操作 二、AndroidDevTools地址 https://www.androiddevtools.cn/ 参考 https://blog.csdn.net/qq_27623455/article/details/103008937

电表远传抄表是什么?

1.电表远传抄表:简述 电表远传抄表,又称为远程控制自动抄表系统,是电力行业的智能化技术运用,它通过无线或通信网络技术,完成对电表数据信息的远程收集解决。此项技术不仅提升了抄水表高效率,降低了人工偏…

RK3568平台(UART篇)uart应用编程读取模块数据

一.串口介绍 串口设备是嵌入式开发中最常用的外设之一,通过串口打印信息可以调试程序的运行,通 过串口也可以链接很多种外设,比如串口打印机,蓝牙,wifi,GPS,GPRS 等等。 数据传输方式&#xf…

C++ | Leetcode C++题解之第97题交错字符串

题目&#xff1a; 题解&#xff1a; class Solution { public:bool isInterleave(string s1, string s2, string s3) {auto f vector <int> (s2.size() 1, false);int n s1.size(), m s2.size(), t s3.size();if (n m ! t) {return false;}f[0] true;for (int i …

全同态加密生态项目盘点:FHE技术的崛起以及应用

撰文&#xff1a;Chris&#xff0c;Techub News 在当今数字化的时代&#xff0c;隐私保护已成为一个全球性的焦点话题&#xff0c;特别是在加密货币和区块链技术快速发展的背景下。虽然当前的隐私技术在保护数据安全方面多有欠缺&#xff0c;引发了广泛的关注和批评&#xff0c…

Java枚举的本质

目录 1.枚举简介 1.1.规范 1.2.枚举类真实的样子 1.3.枚举类的特点 1.4.枚举可以使用的方法 1.4.1.toString()方法 1.4.2.valueOf方法 1.4.3.values方法 1.4.4.ordinal方法 1.5.枚举的用法 1.5.1.常量 1.5.2.switch 1.5.3.枚举中增加方法 1.5.4.覆盖枚举方法 1.5…

热题系列章节1

22. 括号生成 数字 n 代表生成括号的对数&#xff0c;请你设计一个函数&#xff0c;用于能够生成所有可能的并且 有效的 括号组合。 示例 1&#xff1a; 输入&#xff1a;n 3 输出&#xff1a;[“((()))”,“(()())”,“(())()”,“()(())”,“()()()”] 示例 2&#xff1a…

LeetCode/NowCoder-链表经典算法OJ练习3

孜孜不倦&#xff1a;孜孜&#xff1a;勤勉&#xff0c;不懈怠。指工作或学习勤奋不知疲倦。&#x1f493;&#x1f493;&#x1f493; 目录 说在前面 题目一&#xff1a;返回倒数第k个节点 题目二&#xff1a;链表的回文结构 题目三&#xff1a;相交链表 SUMUP结尾 说在前…

两篇文章讲透数据结构之堆(一)!

目录 1.堆的概念 2.堆的实现方式 3.堆的功能 4.堆的声明 5.堆的实现 5.1堆的初始化 5.2堆的插入 5.2.1向上调整算法 5.2.2堆的插入 5.3堆的删除 5.3.1向下调整算法 5.3.2堆的删除 5.4获取堆顶元素 5.5获取堆的元素个数 5.6判断堆是否为空 5.7打印堆 5.8建堆 …

SQL开窗函数

文章目录 概念&#xff1a;语法&#xff1a;常用的窗口函数及示例&#xff1a;求平均值&#xff1a;AVG() &#xff1a;求和&#xff1a;SUM():求排名&#xff1a;移动平均计数COUNT():求最大MXA()/小MIN()值求分区内的最大/最小值求当前行的前/后一个值 概念&#xff1a; 开窗…

算法题1:电路开关(HW)

题目描述 实验室对一个设备进行通断测试,实验员可以操控开关进行通断,有两种情况: ps,图没记下来,凭印象画了类似的 初始时,3个开关的状态均为断开;现给定实验员操控记录的数组 records ,records[i] = [time, switchId],表示在时刻 time 更改了开关 switchId 的状态…

多线程(C++11)

多线程&#xff08;C&#xff09; 文章目录 多线程&#xff08;C&#xff09;前言一、std::thread类1.线程的创建1.1构造函数1.2代码演示 2.公共成员函数2.1 get_id()2.2 join()2.3 detach()2.4 joinable()2.5 operator 3.静态函数4.类的成员函数作为子线程的任务函数 二、call…

AOP编程

AOP编程 AOP&#xff0c;面向切面编程&#xff0c;一种编程范式&#xff0c;指导开发者如何组织程序结构。 OOP&#xff0c;面向对象编程&#xff0c;一种编程思想。 AOP&#xff0c;提供了一种机制,可以将一些横切系统中多个模块的共同逻辑(如日志记录、事务管理、安全控制等…

SQL面试题练习 —— 波峰波谷

来源&#xff1a;字节今日头条 目录 1 题目2 建表语句3 题解 1 题目 有如下数据&#xff0c;记录每天每只股票的收盘价格&#xff0c;请查出每只股票的波峰和波谷的日期和价格&#xff1b; 波峰定义&#xff1a;股票价格高于前一天和后一天价格时为波峰 波谷定义&#xff1a;股…

MoE 系列论文解读:Gshard、FastMoE、Tutel、MegaBlocks 等

节前&#xff0c;我们组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、今年参加社招和校招面试的同学。 针对大模型技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备面试攻略、面试常考点等热门话题进行了深入的讨论。 总结链接…

Unity在Windows平台播放HEVC/H.265格式视频的底层原理

相关术语、概念 HEVC/H.265 HEVC&#xff08;High Efficiency Video Coding&#xff09;是一种视频压缩标准&#xff0c;也被称为H.265。它是一种高效的视频编码标准&#xff0c;可以提供比之前的标准&#xff08;如H.264&#xff09;更高的压缩率&#xff0c;同时保持较高的…

力扣HOT100 - 31. 下一个排列

解题思路&#xff1a; 数字是逐步增大的 步骤如下&#xff1a; class Solution {public void nextPermutation(int[] nums) {int i nums.length - 2;while (i > 0 && nums[i] > nums[i 1]) i--;if (i > 0) {int j nums.length - 1;while (j > 0 &&…

015_表驱动编程思想(c实现)

【背景】 数据压倒一切。如果选择了正确的数据结构并把一切组织的井井有条&#xff0c;正确的算法就不言自明。编程的核心是数据结构&#xff0c;而不是算法。 ——Rob Pike 上面是这个名人说过的话&#xff0c;那么c语言之父 丹尼斯麦卡利斯泰尔里奇 的《c程序设计》里曾经…

【Linux取经路】基于信号量和环形队列的生产消费者模型

文章目录 一、POSIX 信号量二、POSIX 信号量的接口2.1 sem_init——初始化信号量2.2 sem_destroy——销毁信号量2.3 sem_wait——等待信号量2.4 sem_post——发布信号量 三、基于环形队列的生产消费者模型3.1 单生产单消费模型3.2 多生产多消费模型3.3 基于任务的多生产多消费模…