用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经忘了模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

代码

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

#2. Design Model
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。
        # torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        y_pred = self.linear(x)#调用linear对象,输入x进行预测
        return y_pred

model = LinearModel()#实例化LinearModel()

# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

#4. Training Cycle
for epoch in range(100):
    y_pred = model(x_data)#Forward:预测
    loss = criterion(y_pred, y_data)#Forward:计算loss
    print(epoch, loss)
    optimizer.zero_grad()#梯度清零
    loss.backward()#backward:计算梯度
    optimizer.step()#通过step()函数进行参数更新

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/324203.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DNS主从服务器配置

主从服务器配置: (1)完全区域传送:复制整个区域文件 #主DNS服务器的配置【主dns服务器的ip地址为192.168.168.129】 #编辑DNS系统配置信息(我这里写的增加的信息,源文件里面有很多内容) [root…

(超详细)4-YOLOV5改进-添加ShuffleAttention注意力机制

1、在yolov5/models下面新建一个ShuffleAttention.py文件,在里面放入下面的代码 代码如下: import numpy as np import torch from torch import nn from torch.nn import init from torch.nn.parameter import Parameterclass ShuffleAttention(nn.…

今天吃什么小游戏(基于Flask框架搭建的简单应用程序,用于随机选择午餐选项。代码分为两部分:Python部分和HTML模板部分)

今天吃什么 一个简单有趣的外卖点饭网站,不知道吃什么的时候,都可以用它自动决定你要吃的,包括各种烧烤、火锅、螺蛳粉、刀削面、小笼包、麦当劳等午餐全部都在内。点击开始它会随意调出不同的午餐,点击停止就会挑选一个你准备要吃…

小红书家居博主报价?怎么和博主合作?

小红书上各式各样的家居博主层出不穷,这些博主不仅为粉丝提供了家居装修的灵感,更为品牌带来了巨大的商业价值。 在当下家居市场竞争激烈的环境中,品牌与家居博主合作已成为了营销策略中的重要一环。博主们庞大的粉丝群体、丰富的内容产出以…

腾讯云服务器多少钱?2024年腾讯云服务器报价明细表

腾讯云服务器租用价格表:轻量应用服务器2核2G3M价格62元一年、2核2G4M价格118元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、轻量4核8G12M服务器446元一年、646元15个月,云服务器CVM S5实例2核2G配置280.8元一年…

区间预测 | Matlab实现LSSVM-ABKDE的最小二乘支持向量机结合自适应带宽核密度估计多变量回归区间预测

区间预测 | Matlab实现LSSVM-ABKDE的最小二乘支持向量机结合自适应带宽核密度估计多变量回归区间预测 目录 区间预测 | Matlab实现LSSVM-ABKDE的最小二乘支持向量机结合自适应带宽核密度估计多变量回归区间预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现…

【Spring 篇】SpringMVC的数据响应:编织美妙的返回乐章

在Web开发的舞台上,数据响应就如同一场美妙的音乐演奏,而SpringMVC作为这场音乐的指挥者,如何优雅地将数据传递给前端,引发了无尽的思考和探索。本篇博客将带你走进SpringMVC的数据响应世界,解开其中的奥秘&#xff0c…

在windows11系统上利用docker搭建ubuntu记录

我的windows11系统上,之前已经安装好了window版本的docker,没有安装的小伙伴需要去安装一下。 下面直接记录安装linux的步骤: 一、创建linux容器 1、拉取镜像 docker pull ubuntu 2、查看镜像 docker images 3、创建容器 docker run --…

Kafka消费流程

Kafka消费流程 消息是如何被消费者消费掉的。其中最核心的有以下内容。 1、多线程安全问题 2、群组协调 3、分区再均衡 1.多线程安全问题 当多个线程访问某个类时,这个类始终都能表现出正确的行为,那么就称这个类是线程安全的。 对于线程安全&…

VTK开发调试环境下载(VTK开发环境一步到位直接开发,无需自己配置编译 VS2017+Qt5.12.10+VTK)

一、无与伦比的优势 直接下载代码就可以调试的VTK代码仓库。 二、资源制作原理 这个资源根据VTK源码 编译出动态库文件 pdb lib dll 文件( x64 debug ) 并将这两者同时放在一个代码仓库里,下载就能用。 三、使用方法(vtk-so…

如何结合主从复制,不停服情况下解决分库分表

首先我们要知道主从复制和分库分表两个概念,在此基础上可以将问题分为几个阶段来执行,参考了公众号 双写读老 双写双读 写新读新

为什么单片机上的程序不怎么使用malloc,而PC上经常使用?

为什么单片机上的程序不怎么使用malloc,而PC上经常使用? 在开始前我有一些资料,是我根据网友给的问题精心整理了一份「单片机的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿…

【新】Unity Meta Quest MR 开发(一):Passthrough 透视配置

文章目录 📕教程说明📕配置透视的串流调试功能📕第一步:设置 OVRManager📕第二步:添加 OVRPassthroughLayer 脚本📕第三步:在场景中添加虚拟物体📕第四步:设置…

2024年腾讯云服务器配置价格表(机型/磁盘/宽带/CPU)

腾讯云服务器租用价格表:轻量应用服务器2核2G3M价格62元一年、2核2G4M价格118元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、轻量4核8G12M服务器446元一年、646元15个月,云服务器CVM S5实例2核2G配置280.8元一年…

vue配置qiankun及打包上线

项目结构 基座:vue3 子应用A:vue3 子应用B: react 子应用C:vue3vite 项目目录: 配置基座 首先下载qiankun yarn add qiankun # 或者 npm i qiankun -S 所有子应用也要安装,vue-vite项目安装 cnpm ins…

【shell】读取表格文件的数据

碎碎念 shell在处理复杂问题的时候不具备优势,如果业务环境能够使用python的话用python又简单又好用,但是很多云平台的现场可能需要shell脚本文件(还好是要求bash) 但是现在有一个业务场景就是运维人员会把参数写在excel表格中 …

不同光照下HUD抬头显示器光干扰试验用太阳光模拟器

HUD干扰太阳光模拟器是机载光电系统测试中常见的问题之一。在机载光电系统测试中,太阳光模拟器是一种重要的测试设备,它可以模拟不同光照条件下的机载光电系统性能,为系统优化和调试提供数据支持。然而,当太阳光模拟器与HUD交叉作…

Python之列表中常见的方法

1.创建一个列表 list1 [1, 2, 3, 4] list2 list("1234") print(list1, list2) print(list1 list2) # 以上创建的两个列表是等价的,都是[1, 2, 3, 4] 2.添加新元素 # 末尾追加 a [1, 2, 3, 4, 5] a.append(6) print(a)# 指定位置的前面插入一个元素 a.insert(2, 1…

Java泛型的继承和通配符

泛型的继承和通配符 继承 两个容器所容纳的类类型是有子类父类的关系的 但是容器之间没有 反证法&#xff1a; 假设做法成立 ArrayList<Object> list1 null;ArrayList<String> list2 - new ArrayList<>();list1list2 list 指向list2的容器实例 list1.add&…

积极参与建设“一带一路”,川宁生物与微构工场达成战略合作

2024年1月12日&#xff0c;北京微构工场生物技术有限公司&#xff08;以下简称“微构工场”&#xff09;与伊犁川宁生物技术股份有限公司&#xff08;“川宁生物”&#xff09;宣布签订战略合作协议&#xff0c;双方将共同出资设立合资公司&#xff0c;加速生物制造产业化落地&…