PyTorch 基础篇(2):线性回归(Linear Regression)

  
  
  1. # 包
  2. import torch
  3. import torch.nn as nn
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  
  
  1. # 超参数设置
  2. input_size = 1
  3. output_size = 1
  4. num_epochs = 60
  5. learning_rate = 0.001
  6.  
  7. # Toy dataset
  8. # 玩具资料:小数据集
  9. x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
  10. [9.779], [6.182], [7.59], [2.167], [7.042],
  11. [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
  12.  
  13. y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
  14. [3.366], [2.596], [2.53], [1.221], [2.827],
  15. [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
  16.  
  17. # 线性回归模型
  18. model = nn.Linear(input_size, output_size)
  19.  
  20. # 损失函数和优化器
  21. criterion = nn.MSELoss()
  22. optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
  
  
  1. # 训练模型
  2. for epoch in range(num_epochs):
  3. # 将Numpy数组转换为torch张量
  4. inputs = torch.from_numpy(x_train)
  5. targets = torch.from_numpy(y_train)
  6.  
  7. # 前向传播
  8. outputs = model(inputs)
  9. loss = criterion(outputs, targets)
  10. # 反向传播和优化
  11. optimizer.zero_grad()
  12. loss.backward()
  13. optimizer.step()
  14. if (epoch 1) % 5 == 0:
  15. print (‘Epoch [{}/{}], Loss: {:.4f}’.format(epoch 1, num_epochs, loss.item()))
  
  
  1. Epoch [5/60], Loss: 7.7737
  2. Epoch [10/60], Loss: 3.2548
  3. Epoch [15/60], Loss: 1.4241
  4. Epoch [20/60], Loss: 0.6824
  5. Epoch [25/60], Loss: 0.3820
  6. Epoch [30/60], Loss: 0.2602
  7. Epoch [35/60], Loss: 0.2109
  8. Epoch [40/60], Loss: 0.1909
  9. Epoch [45/60], Loss: 0.1828
  10. Epoch [50/60], Loss: 0.1795
  11. Epoch [55/60], Loss: 0.1781
  12. Epoch [60/60], Loss: 0.1776
  
  
  1. # 绘制图形
  2. # torch.from_numpy(x_train)将X_train转换为Tensor
  3. # model()根据输入和模型,得到输出
  4. # detach().numpy()预测结结果转换为numpy数组
  5. predicted = model(torch.from_numpy(x_train)).detach().numpy()
  6. plt.plot(x_train, y_train, ‘ro’, label=‘Original data’)
  7. plt.plot(x_train, predicted, label=‘Fitted line’)
  8. plt.legend()
  9. plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/224871.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringSecurity6 | 默认用户生成(下)

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Java从入门到精通 ✨特色专栏&#xf…

MTU TCP-MSS(转载)

MTU MTU 最大传输单元(Maximum Transmission Unit,MTU)用来通知对方所能接受数据服务单元的最大尺寸,说明发送方能够接受的有效载荷大小。 是包或帧的最大长度,一般以字节记。如果MTU过大,在碰到路由器时…

kyuubi整合flink yarn application model

目录 概述配置flink 配置kyuubi 配置kyuubi-defaults.confkyuubi-env.shhive 验证启动kyuubibeeline 连接使用hive catalogsql测试 结束 概述 flink 版本 1.17.1、kyuubi 1.8.0、hive 3.1.3、paimon 0.5 整合过程中,需要注意对应的版本。 注意以上版本 姊妹篇 k…

tgf - 一个开箱即用的golang游戏服务器框架

tgf框架 tgf框架是使用golang开发的一套游戏分布式框架.属于开箱即用的项目框架,目前适用于中小型团队,独立开发者,快速开发使用.框架提供了一整套开发工具,并且定义了模块开发规范.开发者只需要关注业务逻辑即可,无需关心用户并发和节点状态等复杂情况. 使用介绍 创建业务逻辑…

JavaScript面向对象编程的奥秘揭秘:掌握核心概念与设计模式

​🌈个人主页:前端青山 🔥系列专栏:JavaScript篇 🔖人终将被年少不可得之物困其一生 依旧青山,本期给大家带来JavaScript篇专栏内容:JavaScript-面向对象 目录 什么是面向对象? 类与对象的主要区别 创建…

python+pytest接口自动化(9)-cookie绕过登录(保持登录状态)

在编写接口自动化测试用例或其他脚本的过程中,经常会遇到需要绕过用户名/密码或验证码登录,去请求接口的情况,一是因为有时验证码会比较复杂,比如有些图形验证码,难以通过接口的方式去处理;再者&#xff0c…

气膜厂家怎样确保产品质量和售后服务?

气膜厂家作为一家专业生产气膜产品的企业,确保产品质量和提供良好的售后服务是我们的责任和使命。为了确保产品质量和售后服务的可靠性,我们采取了以下措施。 起初,我们严格按照国家标准和相关行业规范进行生产。气膜产品的质量是产品能否长…

编织魔法世界——计算机科学的奇幻之旅

文章目录 每日一句正能量前言为什么当初选择计算机行业计算机对自己人生道路的影响后记 每日一句正能量 人生就像赛跑,不在乎你是否第一个到达尽头,而在乎你有没有跑完全程。 前言 计算机是一个神奇的领域,它可以让人们创造出炫酷的虚拟世界…

Linux常用命令——as命令

在线Linux命令查询工具 as 汇编语言编译器 补充说明 as命令GNU组织推出的一款汇编语言编译器,它支持多种不同类型的处理器。 语法 as(选项)(参数)选项 -ac:忽略失败条件; -ad:忽略调试指令; -ah:包括…

nVisual能为数据中心解决什么问题?

nVisual通过可视化的管理方式,使数据中心管理者能够有效且高效地管理数据中心的资产、线缆、容量、变更;使数据中心管理者能够获得如下问题的答案,以便能够快速做出更好、更明智的决策: 1.资产管理 我们有什么&#x…

VMware Linux(Centos)虚拟机扩容根目录磁盘空间

给VMWare虚拟机根目录扩容,简单有效!_迷倒万千少女的Csir的博客-CSDN博客 https://blog.csdn.net/m0_64206944/article/details/131453844?spm1001.2014.3001.5506 上述链接融合参考下面文章 VMware Linux(Centos)虚拟机扩容根目录磁盘空间 centosli…

Redis quicklist源码+listpack源码(6.0+以上版本)

ziplist设计上的问题,每一次增删改都需要计算前面元素的空间和长度(prevlen),这种设计缺陷非常明显,一旦其中一个entry发生修改,以这个entry后面开始,全部需要重新计算prevlen,因此诞…

台灯哪个品牌比较好?适合考研党的台灯推荐

眼睛作为人体非常重要的器官之一,它承担着接受和感知光线的功能。然而,长时间暴露在强光下或者不适当的光线环境下可能会对眼睛健康造成一定的影响。许多学生党以及上班族可能深有体会,在日常读写以及长时间面对电子产品中,很容易…

数字文化大观:TikTok影响下的全球文娱

在数字时代的大潮中,社交媒体平台正成为全球文娱产业的重要引擎之一。而TikTok,作为一款以短视频为特色的社交应用,正深刻地改变着全球文娱的面貌。 本文将深入研究TikTok对全球文娱的影响,探讨数字文化在这一平台的催化下如何迅…

超大规模集成电路设计----CMOS组合逻辑门(六)

本文仅供学习,不作任何商业用途,严禁转载。绝大部分资料来自----数字集成电路——电路、系统与设计(第二版)及中国科学院段成华教授PPT 超大规模集成电路设计----CMOS组合逻辑门(六) 6.1 静态CMOS设计6.1.1 互补CMOS6.1.1.1 互补…

本项目基于Spring boot的AMQP模块,整合流行的开源消息队列中间件rabbitMQ,实现一个向rabbitMQ

在业务逻辑的异步处理,系统解耦,分布式通信以及控制高并发的场景下,消息队列有着广泛的应用。本项目基于Spring的AMQP模块,整合流行的开源消息队列中间件rabbitMQ,实现一个向rabbitMQ添加和读取消息的功能。并比较了两种模式&…

【头歌系统数据库实验】实验2 MySQL软件操作及建库建表建数据

目录 第1关:创建数据库 第2关:创建供应商表S,并插入数据 第3关:创建零件表P,并插入数据 第4关:创建工程项目表J,并插入数据 第5关:创建供应情况表SPJ,并插入数据 …

dtaidistance笔记:dtw_ndim (高维时间序列之间的DTW)

1 数据 第一个维度是sequence的index,每一行是多个元素(表示这一时刻的record) from dtaidistance.dtw_ndim import *s1 np.array([[0, 0],[0, 1],[2, 1],[0, 1],[0, 0]], dtypenp.double) s2 np.array([[0, 0],[2, 1],[0, 1],[0, .5],[0…

Elasticsearch SQL插件调研与问题整理

在最新的es8.11版本中,开始有了es|ql语言。非常接近sql,但是还是不太一样。而在之前的版本中,sql能力很弱,并且属于白金版本的内容。也就是说需要氪金才能体验,才能使用。 我是es研发工程师。负责公司内部的es集群的日…

Netty线程模型

Netty线程模型 Netty中两个线程池, 分别是BossGroup和WorkGroup, 线程模型如下图所示: 模型解释: Netty 抽象出两组线程池BossGroup和WorkerGroup,BossGroup专门负责接收客户端的连接, WorkerGroup专门负责网络的读写BossGroup和WorkerGr…