自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测

代码1实现逻辑回归并保存模型

import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

data = [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6],
        [0.4, 34.0], [0.8, 62.3]]

data = np.array(data)

x_data = data[:, 0]
y_data = data[:, 1]
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)

dataset = TensorDataset(x_train, y_train)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

print(x_train)

criterion = nn.MSELoss()
model = nn.Sequential(nn.Linear(1, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epochs = 500
for n in range(1, epochs + 1):
    epoch_loss = 0
    for batch_x, batch_y in dataloader:
        y_pred = model(batch_x.unsqueeze(1))
        batch_loss = criterion(y_pred.squeeze(1), batch_y)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        epoch_loss = epoch_loss + batch_loss.item()
        avg_loss = epoch_loss / len(dataloader)
        if n % 10 == 0 or n == 1:
            print(f'epoch:{n}, loss:{epoch_loss}')
            torch.save(model, 'entire_model.pth')

运行结果:

epoch:1, loss:1784.125
epoch:1, loss:5864.14892578125
epoch:10, loss:1340.3311767578125
epoch:10, loss:2131.239013671875
epoch:20, loss:473.2213439941406
epoch:20, loss:750.4476013183594
epoch:30, loss:146.10401916503906
epoch:30, loss:306.4013214111328
epoch:40, loss:93.8193588256836
epoch:40, loss:161.09381866455078
epoch:50, loss:22.086835861206055
epoch:50, loss:113.13406944274902
epoch:60, loss:66.36778259277344
epoch:60, loss:95.92328262329102
epoch:70, loss:37.97149658203125
epoch:70, loss:90.1829833984375
epoch:80, loss:27.833377838134766
epoch:80, loss:88.26276016235352
epoch:90, loss:18.019649505615234
epoch:90, loss:86.4774284362793
epoch:100, loss:36.294681549072266
epoch:100, loss:86.4249382019043
epoch:110, loss:25.34766960144043
epoch:110, loss:85.6535472869873
epoch:120, loss:71.55767059326172
epoch:120, loss:85.6304121017456
epoch:130, loss:54.34508514404297
epoch:130, loss:85.70595932006836
epoch:140, loss:58.45751953125
epoch:140, loss:85.98778343200684
epoch:150, loss:24.026874542236328
epoch:150, loss:85.38119125366211
epoch:160, loss:31.197525024414062
epoch:160, loss:85.36103820800781
epoch:170, loss:23.816781997680664
epoch:170, loss:85.37735176086426
epoch:180, loss:66.44145202636719
epoch:180, loss:86.07975769042969
epoch:190, loss:49.096153259277344
epoch:190, loss:85.98376846313477
epoch:200, loss:38.83055877685547
epoch:200, loss:87.02980041503906
epoch:210, loss:22.55113410949707
epoch:210, loss:85.6132755279541
epoch:220, loss:60.618438720703125
epoch:220, loss:85.66439247131348
epoch:230, loss:24.166812896728516
epoch:230, loss:85.43827819824219
epoch:240, loss:36.66695022583008
epoch:240, loss:85.72342681884766
epoch:250, loss:50.92716979980469
epoch:250, loss:86.27684783935547
epoch:260, loss:37.27833557128906
epoch:260, loss:85.69609069824219
epoch:270, loss:50.637638092041016
epoch:270, loss:86.2179069519043
epoch:280, loss:60.93098068237305
epoch:280, loss:85.41929817199707
epoch:290, loss:34.782196044921875
epoch:290, loss:85.72705841064453
epoch:300, loss:30.515146255493164
epoch:300, loss:85.36332130432129
epoch:310, loss:33.87446594238281
epoch:310, loss:85.8970718383789
epoch:320, loss:72.44877624511719
epoch:320, loss:85.54687786102295
epoch:330, loss:61.19231414794922
epoch:330, loss:85.43900299072266
epoch:340, loss:48.75373840332031
epoch:340, loss:85.7229118347168
epoch:350, loss:33.820648193359375
epoch:350, loss:85.89838409423828
epoch:360, loss:34.31058883666992
epoch:360, loss:85.568359375
epoch:370, loss:42.243125915527344
epoch:370, loss:86.17256927490234
epoch:380, loss:42.217655181884766
epoch:380, loss:86.17264938354492
epoch:390, loss:29.57950210571289
epoch:390, loss:86.71274185180664
epoch:400, loss:65.79289245605469
epoch:400, loss:86.1904239654541
epoch:410, loss:26.13401222229004
epoch:410, loss:85.56365013122559
epoch:420, loss:28.22481918334961
epoch:420, loss:86.30517959594727
epoch:430, loss:66.69523620605469
epoch:430, loss:85.44241905212402
epoch:440, loss:46.568904876708984
epoch:440, loss:87.0429573059082
epoch:450, loss:49.70370101928711
epoch:450, loss:86.14195251464844
epoch:460, loss:59.23515701293945
epoch:460, loss:85.56475257873535
epoch:470, loss:43.026187896728516
epoch:470, loss:86.16817855834961
epoch:480, loss:62.85814666748047
epoch:480, loss:86.68252944946289
epoch:490, loss:26.73046875
epoch:490, loss:86.20927047729492
epoch:500, loss:25.716747283935547
epoch:500, loss:85.53901290893555

代码2 加载模型进行预测:

import torch

entire_model = torch.load('entire_model.pth')
entire_model.eval()
x_test = torch.tensor([1.8], dtype=torch.float32)
with torch.no_grad():
    y_pred = entire_model(x_test)
print(y_pred)
print(entire_model)

结果:

tensor([93.7325])
Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/960702.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真。选择回归法进行最大功率点的追踪,使用光强和温度作为影响因素,电压作为输出进行建模。…

【数据结构】 并查集 + 路径压缩与按秩合并 python

目录 前言模板朴素实现路径压缩按秩合并按树高为秩按节点数为秩 总结 前言 并查集的基本实现通常使用森林来表示不同的集合,每个集合用一棵树表示,树的每个节点有一个指向其父节点的指针。 如果一个节点是它自己的父节点,那么它就是该集合的代…

Flutter android debug 编译报错问题。插件编译报错

下面相关内容 都以 Mac 电脑为例子。 一、问题 起因:(更新 Android studio 2024.2.2.13、 Flutter SDK 3.27.2) 最近 2025年 1 月 左右,我更新了 Android studio 和 Flutter SDK 再运行就会出现下面的问题。当然 下面的提示只是其…

CSAPP学习:前言

前言 本书简称CS:APP。 背景知识 一些基础的C语言知识 如何阅读 Do-做系统 在真正的系统上解决具体的问题,或是编写和运行程序。 章节 2025-1-27 个人认为如下章节将会对学习408中的操作系统与计算机组成原理提供帮助,于是先凭借记忆将其简单…

动态规划DP 数字三角型模型 方格取数(题目详解+C++代码实现)

方格取数 原题链接 AcWing 1027. 方格取数 题目描述 设有 NN 的方格图,我们在其中的某些方格中填入正整数,而其它的方格中则放入数字0。 如下图所示: 某人从图中的左上角 A 出发,可以向下行走,也可以向右行走&…

【Linux】20.基础IO(2)

文章目录 2. 理解文件系统2.1 inode2.2 如何理解目录2.3 硬链接2.4 软链接2.5 硬链接和软链接的区别 2. 理解文件系统 2.1 inode 我们使用ls -l的时候看到的除了看到文件名,还看到了文件元数据。 ydk_108iZuf68hz06p6s2809gl3i1Z:~/108/lesson20$ ll total 8 drw…

read+write实现:链表放到文件+文件数据放到链表 的功能

思路 一、 定义链表: 1 节点结构(数据int型) 2 链表操作(创建节点、插入节点、释放链表、打印链表)。 二、链表保存到文件 1打开文件 2遍历链表、写文件: 遍历链表,write()将节点数据写入文件。…

图漾相机-ROS2-SDK-Ubuntu版本编译(新版本)

官网编译文档链接: https://doc.percipio.xyz/cam/latest/getstarted/sdk-ros2-compile.html 国内gitee下载SDK链接: https://gitee.com/percipioxyz 国外github下载SDK链接: https://github.com/percipioxyz 1.Camport ROS2 SDK 介绍 1.1 …

C# 添加、替换、提取、或删除Excel中的图片

在Excel中插入与数据相关的图片,能将关键数据或信息以更直观的方式呈现出来,使文档更加美观。此外,对于已有图片,你有事可能需要更新图片以确保信息的准确性,或者将Excel 中的图片单独保存,用于资料归档、备…

智能风控 数据分析 groupby、apply、reset_index组合拳

目录 groupby——分组 本例 apply——对每个分组应用一个函数 等价用法 reset_index——重置索引 使用前​编辑 注意事项 groupby必须配合聚合函数、 关于agglist 一些groupby试验 1. groupby对象之后。sum(一个列名) 2. groupby对象…

浅析百度AOI数据与高德AOI数据的差异性

目录 前言 一、AOI属性数据 1、百度AOI数据 2、高德AOI数据 二、AOI矢量边界 1、百度AOI空间范围 2、高德AOI空间范围 三、数据获取频次和难易程度 1、接口限制 2、数据转换成本 四、总结 前言 在当今数字化时代,地理信息数据的精准性和丰富性对于城市规划…

通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)

大家对于智能体代理Agent一定已经非常熟悉,自主代理(Autonomous Agents) 目前在AI行业极其热门并具有巨大的潜力,能够显著提升开发者日常的工作效率、自动化日常琐碎、重复性任务,并生成全新的内容。Agent可以理解用户…

汇编的使用总结

一、汇编的组成 1、汇编指令(指令集) 数据处理指令: 数据搬移指令 数据移位指令 位运算指令 算术运算指令 比较指令 跳转指令 内存读写指令 状态寄存器传送指令 异常产生指令等 2、伪指令 不是汇编指令,但是可以起到指令的作用,伪…

S4 HANA定义税码(FTXP)

本文主要介绍在S4 HANA OP中S4 HANA定义税码相关设置。具体请参照如下内容: 定义税码(FTXP) 以上界面是根据国家的“定价过程”确定的。蓝色的行项目表示目前已经激活的行项目。 不可抵扣进项税一般用于采购业务中,因此用在进项税码中。 消费税和营业…

Git进阶笔记系列(01)Git核心架构原理 | 常用命令实战集合

读书笔记:卓越强迫症强大恐惧症,在亲子家庭、职场关系里尤其是纵向关系模型里,这两种状态很容易无缝衔接。尤其父母对子女、领导对下属,都有望子成龙、强将无弱兵的期望,然而在你的面前,他们才是永远强大的…

多级缓存(亿级并发解决方案)

多级缓存(亿级流量(并发)的缓存方案) 传统缓存的问题 传统缓存是请求到达tomcat后,先查询redis,如果未命中则查询数据库,问题如下: (1)请求要经过tomcat处…

场景设计学习-积分系统

场景设计-积分系统 1.概念和规则 积分:用户在网站的各种交互行为都可以产生积分,积分值与行为类型有关天梯榜:按照每个用户的总积分排序得到的排行榜,称为天梯榜。排名靠前的有奖励。天梯榜每个自然月为一个赛季,月初…

ML基础3-sklearn中的1个简单的分类器例子

Scikit-learn(通常缩写为sklearn)是一个流行的Python机器学习库,用于数据挖掘和数据分析任务。它建立在NumPy、SciPy和matplotlib等科学计算/可视化库的基础上,提供了丰富的工具和算法,用于处理各种机器学习问题&#…

The Simulation技术浅析(二):模型技术

一、物理模型(Physical Models) 1. 概述 物理模型基于物理定律和原理,通过模拟现实世界中物理系统的行为和相互作用来构建模型。物理模型通常用于工程、物理和化学等领域,用于预测系统在不同条件下的表现。 2. 关键技术 力学定律:例如牛顿运动定律,用于模拟物体的运动…

006 mybatis关联查询(一对一、一对多)

文章目录 一对一查询SQL语句方法一:resultType方法二:resultMap创建扩展po类Mapper映射文件Mapper接口测试代码小结 一对多查询SQL语句修改po类Mapper映射文件Mapper接口测试代码 注意:因为一个订单信息只会是一个人下的订单,所以…