自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

导入必要的库:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

准备数据:

class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])
# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

定义模型:

# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

损失函数和优化器:

# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

训练模型及保存:

# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

模型加载及预测:

# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

# 进行预测
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

# 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())

计算精确度和召回率及F1分数:

# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())

print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963785.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何运行Composer安装PHP包 安装JWT库

1. 使用Composer Composer是PHP的依赖管理工具,它允许你轻松地安装和管理PHP包。对于JWT,你可以使用firebase/php-jwt这个库,这是由Firebase提供的官方库。 安装Composer(如果你还没有安装的话): 访问Co…

《Linux服务与安全管理》| 数据库服务器安装和配置

《Linux服务与安全管理》| 数据库服务器安装和配置 目录 《Linux服务与安全管理》| 数据库服务器安装和配置 任务一: 安装PostgreSQL数据库,设置远程登录,客户端可以成功登录并操作数据库。 任务二: 安装MySQL数据库&#xf…

【贪心算法篇】:“贪心”之旅--算法练习题中的智慧与策略(一)

✨感谢您阅读本篇文章,文章内容是个人学习笔记的整理,如果哪里有误的话还请您指正噢✨ ✨ 个人主页:余辉zmh–CSDN博客 ✨ 文章所属专栏:贪心算法篇–CSDN博客 文章目录 一.贪心算法1.什么是贪心算法2.贪心算法的特点 二.例题1.柠…

DRM系列七:Drm之CREATE_DUMB

本系列文章基于linux 5.15 DRM驱动的显存由GEM(Graphics execution management)管理。 一、创建流程 创建buf时,user层提供需要buf的width,height以及bpp(bite per pixel),然后调用drmIoctl(fd, DRM_IOCTL_MODE_CREATE_DUMB, &…

Python从0到100(八十七):CNN网络详细介绍及WISDM数据集模型仿真

前言: 零基础学Python:Python从0到100最新最全教程。 想做这件事情很久了,这次我更新了自己所写过的所有博客,汇集成了Python从0到100,共一百节课,帮助大家一个月时间里从零基础到学习Python基础语法、Pyth…

WPF进阶 | WPF 动画特效揭秘:实现炫酷的界面交互效果

WPF进阶 | WPF 动画特效揭秘:实现炫酷的界面交互效果 前言一、WPF 动画基础概念1.1 什么是 WPF 动画1.2 动画的基本类型1.3 动画的核心元素 二、线性动画详解2.1 DoubleAnimation 的使用2.2 ColorAnimation 实现颜色渐变 三、关键帧动画深入3.1 DoubleAnimationUsin…

实践网络安全:常见威胁与应对策略详解

📝个人主页🌹:一ge科研小菜鸡-CSDN博客 🌹🌹期待您的关注 🌹🌹 引言 在数字化转型的浪潮中,网络安全的重要性已达到前所未有的高度。无论是个人用户、企业,还是政府机构…

web-SQL注入-CTFHub

前言 在众多的CTF平台当中,作者认为CTFHub对于初学者来说,是入门平台的不二之选。CTFHub通过自己独特的技能树模块,可以帮助初学者来快速入门。具体请看官方介绍:CTFHub。 作者更新了CTFHub系列,希望小伙伴们多多支持…

LabVIEW如何高频采集温度数据?

在LabVIEW中进行高频温度数据采集时,选择合适的传感器(如热电偶或热电阻)和采集硬件是关键。下面是一些建议,帮助实现高效的温度数据采集: 1. 传感器选择: 热电偶(Thermocouple)&am…

Deep Sleep 96小时:一场没有硝烟的科技保卫战

2025年1月28日凌晨3点,当大多数人还沉浸在梦乡时,一场没有硝烟的战争悄然打响。代号“Deep Sleep”的服务器突遭海量数据洪流冲击,警报声响彻机房,一场针对中国关键信息基础设施的网络攻击来势汹汹! 面对美国发起的这场…

Golang Gin系列-9:Gin 集成Swagger生成文档

文档一直是一项乏味的工作(以我个人的拙见),但也是编码过程中最重要的任务之一。在本文中,我们将学习如何将Swagger规范与Gin框架集成。我们将实现JWT认证,请求体作为表单数据和JSON。这里唯一的先决条件是Gin服务器。…

三、递推关系与母函数,《组合数学(第4版)》卢开澄 卢华明

文章目录 一、似函数、非函数1.1 母函数1.2 母函数的简单应用1.3 整数拆分1.4 Ferrers 图像1.5 母函数能做什么1.6 递推关系1.6.1 Hanoi 问题1.6.2 偶数个5怎么算 1.7 Fibonacci 序列1.7.1 Fibonacci 的奇妙性质1.7.2 Fibonacci 恒等式1.7.3 Fibonacci 的直接表达式1.7.4 Fibon…

97,【5】buuctf web [极客大挑战 2020]Greatphp

进入靶场 审代码 <?php // 关闭所有 PHP 错误报告&#xff0c;防止错误信息泄露可能的安全隐患 error_reporting(0);// 定义一个名为 SYCLOVER 的类 class SYCLOVER {// 定义类的公共属性 $sycpublic $syc;// 定义类的公共属性 $loverpublic $lover;// 定义魔术方法 __wa…

手机上运行AI大模型(Deepseek等)

最近deepseek的大火&#xff0c;让大家掀起新一波的本地部署运行大模型的热潮&#xff0c;特别是deepseek有蒸馏的小参数量版本&#xff0c;电脑上就相当方便了&#xff0c;直接ollamaopen-webui这种类似的组合就可以轻松地实现&#xff0c;只要硬件&#xff0c;如显存&#xf…

JAVA安全—反射机制攻击链类对象成员变量方法构造方法

前言 还是JAVA安全&#xff0c;哎&#xff0c;真的讲不完&#xff0c;太多啦。 今天主要是讲一下JAVA中的反射机制&#xff0c;因为反序列化的利用基本都是要用到这个反射机制&#xff0c;还有一些攻击链条的构造&#xff0c;也会用到&#xff0c;所以就讲一下。 什么是反射…

数据库和数据表的创建、修改、与删除

1.标识符命名规则 数据库名、表名不得超过30个字符&#xff0c;变量名限制为29个 必须只能包含A-Z,a-z,0-9,_共63个字符 数据库名、表名、字段名等对象名中间不能包含空格 同一个MySQL软件中&#xff0c;数据库不能同名&#xff1b;同一个库中&#xff0c;表不能重名&#…

(电脑版)植物大战僵尸幼儿园版本,开启你的冒险之旅!

欢迎来到植物大战僵尸中文版&#xff0c;园长Jen已准备好迎接你的挑战&#xff01;在这个充满乐趣和策略的游戏中&#xff0c;你将体验到多种游戏模式&#xff0c;每种模式都带来不同的挑战和乐趣。 游戏模式&#xff1a; 冒险模式&#xff1a;踏上刺激的冒险旅程&#xff0c;…

SpringBoot中关于knife4j 中的一些相关注解

1、效果图 对比可以明显的看到加了注解与没有加注解所表现出来的效果不同&#xff08;加了注解的更加明了清晰&#xff09; 2、实现效果 Tag注解‌用于为测试方法或测试类添加标签&#xff0c;以便在执行测试时根据标签进行过滤。使用Tag注解可以更灵活地控制测试的执行&#…

双目标定与生成深度图

基于C#联合Halcon实现双目标定整体效果 一&#xff0c;标定 1&#xff0c;标定前准备工作 &#xff08;获取描述文件与获取相机参数&#xff09; 针对标准标定板可以直接调用官方提供描述文件&#xff0c;也可以自己生成描述文件后用PS文件打印 2&#xff0c;相机标定 &…

操作系统1.4

一、操作系统内核 二、微内核与大内核 三、补充