使用Pytorch实现linear_regression

使用Pytorch实现线性回归

# import necessary packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Set necessary Hyper-parameters.
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
# Define a Toy dataset.
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
                    [9.779], [6.182], [7.59], [2.167], [7.042], 
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
                    [3.366], [2.596], [2.53], [1.221], [2.827], 
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# Confirm the data shape.
print(x_train.shape, y_train.shape)
(15, 1) (15, 1)
# Linear regression model
model = nn.Linear(input_size, output_size)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, )
# Train the model
for epoch in range(num_epochs):
    # Convert numpy arrays to torch tensors
    inputs = torch.from_numpy(x_train)
    targets = torch.from_numpy(y_train)

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Set an output counter
    if (epoch+1) % 5 == 0:
        print('Epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
Epoch [5/60], loss: 7.1598
Epoch [10/60], loss: 3.0717
Epoch [15/60], loss: 1.4154
Epoch [20/60], loss: 0.7443
Epoch [25/60], loss: 0.4722
Epoch [30/60], loss: 0.3618
Epoch [35/60], loss: 0.3169
Epoch [40/60], loss: 0.2985
Epoch [45/60], loss: 0.2909
Epoch [50/60], loss: 0.2876
Epoch [55/60], loss: 0.2861
Epoch [60/60], loss: 0.2853

在这里插入图片描述

# Save the model checkpoint
torch.save(model.state_dict(), 'model_param.ckpt')
torch.save(model, 'model.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/174879.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

GB28181视频监控国标平台EasyGBS如何进行服务迁移?

视频流媒体安防监控国标GB28181平台EasyGBS视频能力丰富,部署灵活,既能作为业务平台使用,也能作为安防监控视频能力层被业务管理平台调用。国标GB28181视频EasyGBS平台可提供流媒体接入、处理、转发等服务,支持内网、公网的安防视…

直播岗位认知篇

一、直播岗位概述 直播岗位,也称为直播主播或直播运营,是指在互联网直播平台上进行直播活动的工作岗位。该岗位的主要职责是通过直播形式,向观众展示自己的才艺、分享生活、销售产品或服务,并引导观众互动和参与。直播主播需要具…

【C++】泛型编程 ⑪ ( 类模板的运算符重载 - 函数实现 写在类外部的不同的 .h 头文件和 .cpp 代码中 )

文章目录 一、类模板的运算符重载 - 函数实现 写在类外部的不同的 .h 头文件和 .cpp 代码中1、分离代码 后的 友元函数报错信息 - 错误示例Student.h 头文件内容Student.cpp 代码文件内容Test.cpp 代码文件内容执行报错信息 2、问题分析 二、代码示例 - 函数实现 写在类外部的不…

设计模式总结-笔记

一个目标:管理变化,提供复用! 两种手段:分解vs.抽象 八大原则: 依赖倒置原则(DIP) 开放封闭原则(OCP) 单一职责原则(SRP) Liskov替换原则&a…

Python pip 镜像源设置指南

文章目录 Python pip 镜像源设置指南前言安装单个包使用PyPI镜像使用镜像升级 pip设为默认pip镜像结语 Python pip 镜像源设置指南 前言 平时在使用 pip 安装一些包的时候速度非常慢,本文介绍如何在 Python3 下设置 PyPI 设置镜像源,本文以给 Python3 设置清华 镜像源举例. …

【JavaEE】Servlet实战案例:表白墙网页实现

一、功能展示 输入信息: 点击提交: 二、设计要点 2.1 明确前后端交互接口 🚓接口一:当用户打开页面的时候需要从服务器加载已经提交过的表白数据 🚓接口二:当用户新增一个表白的时候,…

2024电脑录屏软件排行第一Camtasia喀秋莎

真的要被录屏软件给搞疯了,本来公司说要给新人做个培训视频,想着把视频录屏一下,然后简单的剪辑一下就可以了。可谁知道录屏软件坑这么多,弄来弄去头都秃了,不过在头秃了几天之后,终于让我发现了一个值得“…

Ant Design Vue 树形表格计算盈收金额

树形表格计算 一、盈收金额计算1、根据需要输入的子级位置,修改数据2、获取兄弟节点数据,并计算兄弟节点的金额合计3、金额合计,遍历给所有的父级 一、盈收金额计算 1、根据需要输入的子级位置,修改数据 2、获取兄弟节点数据&am…

求二叉树的最大密度(可运行)

最大密度:二叉树节点数值的最大值 如果没有输出结果,一定是建树错误!!!!!!! 我设置输入的是字符型数据,比较的ASCII值。 输入:FBE###CE### 输…

支付宝生僻字选择器

本文的数据来源于支付宝网页版本生僻字选择器。 let rareWords[{spell: "a",words: ["奡", "靉", "叆"]}, {spell: "b",words: ["仌", "昺", "竝", "霦", "犇", "愊…

CSDN流量卷领取和使用保姆级教程——流量卷,恭喜获得每日任务奖励【1500曝光】可获得新增曝光,阅读转化,点赞转化,新增关注-流量卷,流量卷,流量卷

希望本文能够给您带来一定的帮助,文章粗浅,敬请批评指正! 目录 话不多说,直接上干货: 第一步:流量卷领取教程:点击内容管理:​编辑 第二步:点击首页:​编辑…

【C++】C++11(2)

文章目录 一、新的类功能二、可变参数模板(了解)三、lambda表达式1. C98中的一个例子2.lambda表达式3.lambda表达式语法4.函数对象与lambda表达式 四、包装器1.function包装器2.bind 五、线程库1.thread类的简单介绍2.线程函数参数3.原子性操作库(atomic…

Transformer的一点理解,附一个简单例子理解attention中的QKV

Transformer用于目标检测的开山之作DETR,论文作者在附录最后放了一段简单的代码便于理解DETR模型。 DETR的backbone用的是resnet-50去掉了最后的AdaptiveAvgPool2d和Linear这两层。 self.backbone nn.Sequential(*list(resnet50(pretrainedTrue).children())[:-2…

MyBatis:关联查询

MyBatis 前言关联查询附懒加载对象为集合时的关联查询 前言 在 MyBatis:配置文件 文章中,最后介绍了可以使用 select 标签的 resultMap 属性实现关联查询,下面简单示例 关联查询 首先,先创建 association_role 和 association_…

上海亚商投顾:沪指冲高回落 短剧、地产股集体走强

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 一.市场情绪 三大指数早盘冲高,创业板指盘初涨超1%,午后则集体下行翻绿,北证50一度大涨…

求二叉树中指定节点所在的层数(可运行)

运行环境.cpp 我这里设置的是查字符e的层数,大家可以在main函数里改成自己想查的字符。(输入的字符一定是自己树里有的)。 如果没有输出结果,一定是建树错误!!!!!&…

Go语言常用命令详解(三)

文章目录 前言常用命令go get示例参数说明 go install示例参数说明 go list示例 go mod示例参数说明 go work基本用法示例 go tool示例 go version示例 go vet示例 总结写在最后 前言 接着上一篇继续介绍Go语言的常用命令 常用命令 以下是一些常用的Go命令,这些命…

Hfish安全蜜罐部署

一、Hfish蜜罐介绍 HFish蜜罐官网 HFish是一款社区型免费蜜罐,侧重企业安全场景,从内网失陷检测、外网威胁感知、威胁情报生产三个场景出发,为用户提供可独立操作且实用的功能,通过安全、敏捷、可靠的中低交互蜜罐增加用户在失陷…

Vue3 相较 Vue2 做的重大更新

双向数据绑定方法 vue2 Object.definePropertie() vue3 Proxy VDOM 性能瓶颈突破 做了静态标记,静态内容不会去再对比 通过位运算对比得出其的静态标记情况 Fragments 允许组件多个根节点 vue3 会虚拟一个根节点,但实际不会渲染虚拟的节点 Tree-S…

力扣第463题 岛屿的周长 C++ 深度优先搜索 + 思维判断的边界

题目 463. 岛屿的周长 简单 相关标签 深度优先搜索 广度优先搜索 数组 矩阵 给定一个 row x col 的二维网格地图 grid ,其中:grid[i][j] 1 表示陆地, grid[i][j] 0 表示水域。 网格中的格子 水平和垂直 方向相连(对角线…