深度学习之pytorch实现线性回归

度学习之pytorch实现线性回归

  • pytorch用到的函数
    • torch.nn.Linearn()函数
    • torch.nn.MSELoss()函数
    • torch.optim.SGD()
  • 代码实现
  • 结果分析

pytorch用到的函数

torch.nn.Linearn()函数

torch.nn.Linear(in_features, # 输入的神经元个数
           out_features, # 输出神经元个数
           bias=True # 是否包含偏置
           )

在这里插入图片描述

作用j进行线性变换
Linear(1, 1) : 表示一维输入,一维输出

torch.nn.MSELoss()函数

在这里插入图片描述

torch.optim.SGD()

优化器对象
在这里插入图片描述

代码实现

import torch

x_data = torch.tensor([[1.0], [2.0], [3.0]])  # 将x_data设置为tensor类型数据
y_data = torch.tensor([[2.0], [4.0], [6.0]])


class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()  # 继承父类
        self.linear = torch.nn.Linear(1, 1)
        # 用torch.nn.Linear来构造对象  (y = w * x + b)
        
    def forward(self, x):
        y_pred = self.linear(x) #调用之前的构造的对象(调用构造函数),计算 y = w * x + b
        return y_pred


model = LinearModel()

criterion = torch.nn.MSELoss(size_average=False)  # 定义损失函数,不求平均损失(为False)

#优化器对象
# #model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
# #类似权重的更新
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 定义梯度优化器为随机梯度下降

for epoch in range(10000):  # 训练过程
    y_pred = model(x_data)  # 向前传播,求y_pred
    loss = criterion(y_pred, y_data)  # 根据y_pred和y_data求损失
    print(epoch, loss)
    
    # 记住在backward之前要先梯度归零
    
    optimizer.zero_grad()  # 将优化器数值清零
    loss.backward()  # 反向传播,计算梯度
    optimizer.step()  # 根据梯度更新参数


#打印权重和b
print("w = ", model.linear.weight.item())
print("b = ", model.linear.bias.item())


#检测模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred = ', y_test.data)  # 测试

结果分析

9961 tensor(4.0927e-12, grad_fn=)
9962 tensor(4.0927e-12, grad_fn=)
9963 tensor(4.0927e-12, grad_fn=)
9964 tensor(4.0927e-12, grad_fn=)
9965 tensor(4.0927e-12, grad_fn=)
9966 tensor(4.0927e-12, grad_fn=)
9967 tensor(4.0927e-12, grad_fn=)
9968 tensor(4.0927e-12, grad_fn=)
9969 tensor(4.0927e-12, grad_fn=)
9970 tensor(4.0927e-12, grad_fn=)
9971 tensor(4.0927e-12, grad_fn=)
9972 tensor(4.0927e-12, grad_fn=)
9973 tensor(4.0927e-12, grad_fn=)
9974 tensor(4.0927e-12, grad_fn=)
9975 tensor(4.0927e-12, grad_fn=)
9976 tensor(4.0927e-12, grad_fn=)
9977 tensor(4.0927e-12, grad_fn=)
9978 tensor(4.0927e-12, grad_fn=)
9979 tensor(4.0927e-12, grad_fn=)
9980 tensor(4.0927e-12, grad_fn=)
9981 tensor(4.0927e-12, grad_fn=)
9982 tensor(4.0927e-12, grad_fn=)
9983 tensor(4.0927e-12, grad_fn=)
9984 tensor(4.0927e-12, grad_fn=)
9985 tensor(4.0927e-12, grad_fn=)
9986 tensor(4.0927e-12, grad_fn=)
9987 tensor(4.0927e-12, grad_fn=)
9988 tensor(4.0927e-12, grad_fn=)
9989 tensor(4.0927e-12, grad_fn=)
9990 tensor(4.0927e-12, grad_fn=)
9991 tensor(4.0927e-12, grad_fn=)
9992 tensor(4.0927e-12, grad_fn=)
9993 tensor(4.0927e-12, grad_fn=)
9994 tensor(4.0927e-12, grad_fn=)
9995 tensor(4.0927e-12, grad_fn=)
9996 tensor(4.0927e-12, grad_fn=)
9997 tensor(4.0927e-12, grad_fn=)
9998 tensor(4.0927e-12, grad_fn=)
9999 tensor(4.0927e-12, grad_fn=)

w = 1.9999985694885254
b = 2.979139480885351e-06
y_pred = tensor([8.0000])

因为轮数过多,这里展示后面几轮
模型的准确性,跟轮数的多少有关系 ,如果轮数为100,最后测试结果的y_pred肯定不为8.00,这里轮数为10000,预测结果跟实际结果基本一样

这里是轮数为100,结果是 7点多,有一定误差
0 tensor(101.4680, grad_fn=)
1 tensor(45.8508, grad_fn=)
2 tensor(21.0819, grad_fn=)
3 tensor(10.0458, grad_fn=)
4 tensor(5.1234, grad_fn=)
5 tensor(2.9227, grad_fn=)
6 tensor(1.9338, grad_fn=)
7 tensor(1.4844, grad_fn=)
8 tensor(1.2754, grad_fn=)
9 tensor(1.1736, grad_fn=)
10 tensor(1.1195, grad_fn=)
11 tensor(1.0869, grad_fn=)
12 tensor(1.0639, grad_fn=)
13 tensor(1.0453, grad_fn=)
14 tensor(1.0288, grad_fn=)
15 tensor(1.0134, grad_fn=)
16 tensor(0.9985, grad_fn=)
17 tensor(0.9841, grad_fn=)
18 tensor(0.9699, grad_fn=)
19 tensor(0.9559, grad_fn=)
20 tensor(0.9421, grad_fn=)
21 tensor(0.9286, grad_fn=)
22 tensor(0.9153, grad_fn=)
23 tensor(0.9021, grad_fn=)
24 tensor(0.8891, grad_fn=)
25 tensor(0.8764, grad_fn=)
26 tensor(0.8638, grad_fn=)
27 tensor(0.8513, grad_fn=)
28 tensor(0.8391, grad_fn=)
29 tensor(0.8271, grad_fn=)
30 tensor(0.8152, grad_fn=)
31 tensor(0.8034, grad_fn=)
32 tensor(0.7919, grad_fn=)
33 tensor(0.7805, grad_fn=)
34 tensor(0.7693, grad_fn=)
35 tensor(0.7582, grad_fn=)
36 tensor(0.7474, grad_fn=)
37 tensor(0.7366, grad_fn=)
38 tensor(0.7260, grad_fn=)
39 tensor(0.7156, grad_fn=)
40 tensor(0.7053, grad_fn=)
41 tensor(0.6952, grad_fn=)
42 tensor(0.6852, grad_fn=)
43 tensor(0.6753, grad_fn=)
44 tensor(0.6656, grad_fn=)
45 tensor(0.6561, grad_fn=)
46 tensor(0.6466, grad_fn=)
47 tensor(0.6373, grad_fn=)
48 tensor(0.6282, grad_fn=)
49 tensor(0.6192, grad_fn=)
50 tensor(0.6103, grad_fn=)
51 tensor(0.6015, grad_fn=)
52 tensor(0.5928, grad_fn=)
53 tensor(0.5843, grad_fn=)
54 tensor(0.5759, grad_fn=)
55 tensor(0.5676, grad_fn=)
56 tensor(0.5595, grad_fn=)
57 tensor(0.5514, grad_fn=)
58 tensor(0.5435, grad_fn=)
59 tensor(0.5357, grad_fn=)
60 tensor(0.5280, grad_fn=)
61 tensor(0.5204, grad_fn=)
62 tensor(0.5129, grad_fn=)
63 tensor(0.5056, grad_fn=)
64 tensor(0.4983, grad_fn=)
65 tensor(0.4911, grad_fn=)
66 tensor(0.4841, grad_fn=)
67 tensor(0.4771, grad_fn=)
68 tensor(0.4703, grad_fn=)
69 tensor(0.4635, grad_fn=)
70 tensor(0.4569, grad_fn=)
71 tensor(0.4503, grad_fn=)
72 tensor(0.4438, grad_fn=)
73 tensor(0.4374, grad_fn=)
74 tensor(0.4311, grad_fn=)
75 tensor(0.4250, grad_fn=)
76 tensor(0.4188, grad_fn=)
77 tensor(0.4128, grad_fn=)
78 tensor(0.4069, grad_fn=)
79 tensor(0.4010, grad_fn=)
80 tensor(0.3953, grad_fn=)
81 tensor(0.3896, grad_fn=)
82 tensor(0.3840, grad_fn=)
83 tensor(0.3785, grad_fn=)
84 tensor(0.3730, grad_fn=)
85 tensor(0.3677, grad_fn=)
86 tensor(0.3624, grad_fn=)
87 tensor(0.3572, grad_fn=)
88 tensor(0.3521, grad_fn=)
89 tensor(0.3470, grad_fn=)
90 tensor(0.3420, grad_fn=)
91 tensor(0.3371, grad_fn=)
92 tensor(0.3322, grad_fn=)
93 tensor(0.3275, grad_fn=)
94 tensor(0.3228, grad_fn=)
95 tensor(0.3181, grad_fn=)
96 tensor(0.3136, grad_fn=)
97 tensor(0.3091, grad_fn=)
98 tensor(0.3046, grad_fn=)
99 tensor(0.3002, grad_fn=)
w = 1.6352288722991943
b = 0.8292105793952942
y_pred = tensor([7.3701])

Process finished with exit code 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/394382.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android 发布蒲公英平台自动更新

蒲公英官网:https://www.pgyer.com/ 首先弄明白蒲公英平台的SDK更新机制:蒲公英 - 文档中心 - SDK 自动更新机制 (pgyer.com) 下面直接开始代码操作 1.添加蒲公英maven库 maven { url "https://raw.githubusercontent.com/Pgyer/mvn_repo_pgyer…

Matlab论文插图绘制模板第136期—极坐标气泡图

在之前的文章中,分享了Matlab笛卡尔坐标系的气泡图的绘制模板: 进一步,再来分享一下极坐标气泡图。 先来看一下成品效果: 特别提示:本期内容『数据代码』已上传资源群中,加群的朋友请自行下载。有需要的朋…

基于微信小程序的校园跑腿系统的研究与实现,附源码

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

threehit漏洞复现以及防御

说白了跟sql-liql靶场二次注入一样,也是一个转义函数而这次是,入库的时候不转义,出库的时候会转义导致这个漏洞出现 开始测试: 这是我注册完test刚登录的情况 找注入点更新数据的update,很容易找到在age段 这次我注册…

12.QT文件对话框 文件的弹窗选择-QFileDialog

目录 前言: 技能: 内容: 1. 界面 2.信号槽 3.其他函数 参考: 前言: 通过按钮实现文件弹窗选择以及关联的操作 效果图就和平时用电脑弹出的选文件对话框一样 技能: QString filename QFileDialog::ge…

(九)【Jmeter】线程(Threads(Users))之bzm-Free-Form Arrivals Thread Group

简述 操作路径如下: 作用:支持自由形式的用户到达模式,具有更高的灵活性,与Arrivals Thread Group类似,不过是通过设置起始值、终止值和持续时间来达到压测目的。配置:通过图形界面或脚本定义用户到达曲线。使用场景:模拟复杂的用户到达模式,满足特定业务需求。优点:…

第三百五十三回

文章目录 1. 概念介绍2. 使用方法2.1 获取所有时区2.2 转换时区时间 3. 示例代码4. 内容总结 我们在上一章回中介绍了"分享一些好的Flutter站点"相关的内容,本章回中将介绍timezone包.闲话休提,让我们一起Talk Flutter吧。 1. 概念介绍 我们在…

open3d k-means 聚类

k-means 聚类 一、算法原理1、介绍2、算法步骤 二、代码1、机器学习生成kmeans聚类2、点云学习生成聚类 三、结果1、原点云2、机器学习生成kmeans聚类3、点云学习生成聚类 四、相关链接 一、算法原理 1、介绍 K-means聚类算法是一种无监督学习算法,主要用于数据聚…

扩展语音识别系统:增强功能与多语言支持

一、引言 在之前的博客中,我们成功构建了一个基于LibriSpeech数据集的英文语音识别系统。现在,我们将对系统进行扩展,增加一些增强功能,并尝试支持多语言识别。 二、增加增强功能 语音合成 --除了语音识别,我们还可以…

SpringMVC的执行流程

过去的开发中,视图阶段(老旧JSP等) 1.首先用户发送请求到前端控制器DispatcherServlet(这是一个调度中心) 2.前端控制器DispatcherServlet收到请求后调用处理器映射器HandlerMapping 3.处理器映射器HandlerMapping找到具体的处理器,可查找xml配置或注…

简单理解VQGAN

简单理解VQGAN TL; DR:与 VQVAE 类似,隐层压缩表征自回归生成的两阶段图像生成方法。增加感知损失和对抗损失,提高压缩表征模型解码出图片的清晰度。还可以通过编码并预置条件表征,实现条件生成。 隐层压缩表征自回归生成&#…

迁移SVN和GIT的云端数据

在新服务器搭建GIT仓库 教程很多,大致的流程是: 1. 新建linux用户密码专用于git操作 2. 新建git库的存放文件夹并在此初始化git 3. 配置git库所在目录权限 *只需要有一个库和有一个用户,与在windows上建库是一样的。不需要搭建类似gitla…

深入解析Android AIDL:实现跨进程通信的利器

深入解析Android AIDL:实现跨进程通信的利器 1. 介绍Android AIDL Android Interface Definition Language (AIDL) 是一种Android系统中的跨进程通信机制。AIDL允许一个应用程序的组件与另一个应用程序的组件通信,并在两者之间传输数据。 AIDL的主要作…

云手机受欢迎背后的原因及未来展望

随着办公模式的演变,云手机的热潮迅速兴起。在各种办公领域,云手机正展现出卓越的实际应用效果。近年来,跨境电商行业迎来了蓬勃发展,其与国内电商的差异不仅体现在整体环境上,更在具体的操作层面呈现出独特之处。海外…

短链接系统测试报告

目录 项目背景 项目功能 自动化测试 总结 项目背景 随着互联网的发展,链接(URL)变得越来越长且复杂,这不仅影响用户体验,还可能由于字符限制导致在某些平台或应用中无法完整显示。为了解决这一问题,我…

上位机图像处理和嵌入式模块部署(boost库的使用)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 作为c程序员来说,除了qt之外,另外值得学的开发库就是boost。boost本身包含的内容非常多,基本我们常用的功能都已…

ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整

文章目录 ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整获取音频长度pydub获取音频长度获取时长精确到秒格式设定 mutagen获取音频长度 调整音量视频音量调整注意事项 ChatGPT实战100例 - (17) 用ChatGPT实现音频长度测量和音量调整 老王媳妇说上次那个pip挺好…

分布式学习笔记

1. CAP理论 Consistency(一致性):用户访问分布式系统中的任意节点,得到的数据必须一致。 Availability(可用性):用户访问集群中的任意健康节点,必须得到相应,而不是超时…

VSCODE上使用python_Django

接上篇 https://blog.csdn.net/weixin_44741835/article/details/136135996?csdn_share_tail%7B%22type%22%3A%22blog%22%2C%22rType%22%3A%22article%22%2C%22rId%22%3A%22136135996%22%2C%22source%22%3A%22weixin_44741835%22%7D VSCODE官网: Editing Python …

汽车网络安全--关于供应商网络安全能力维度的思考

目录 1.关于CSMS的理解 2.OEM如何评审供应商 2.1 质量评审 2.2 网络安全能力评审 3.小结 1.关于CSMS的理解 最近在和朋友们交流汽车网络安全趋势时,讨论最多的是供应商如何向OEM证明其网络安全能力。 这是很重要的一环,因为随着汽车网络安全相关强…