机器学习day7

自定义数据集 使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数

代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optimizer
import matplotlib.pyplot as plt
 
class1_points = np.array([[2.1, 1.8],
                          [1.9, 2.4],
                          [2.2, 1.2],
                          [1.8, 1.5],
                          [1.3, 1.7],
                          [1.6, 2.1],
                          [1.7, 1.4]])
 
class2_points = np.array([[3.5, 3.4],
                          [3.8, 2.7],
                          [3.4, 2.9],
                          [3.1, 3.6],
                          [3.9, 2.4],
                          [4.0, 2.8],
                          [3.3, 2.5]])
 
x_train = np.concatenate((class1_points, class2_points), axis=0)
y_train = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))))
 
x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
 
seed = 42
torch.manual_seed(seed)
 
class LogisticRegreModel(nn.Module):
    def __init__(self):
        super(LogisticRegreModel, self).__init__()
        self.fc = nn.Linear(2, 1)
 
    def forward(self, x):
        x = self.fc(x)
        x = torch.sigmoid(x)
        return x
 
model = LogisticRegreModel()
 
cri = nn.BCELoss()
lr = 0.05
optimizer = optimizer.SGD(model.parameters(), lr=lr)
 
fig, (ax1, ax2) = plt.subplots(1, 2)
epoch_list = []
epoch_loss = []
 
epoches = 1000
for epoch in range(1, epoches + 1):
    y_pre = model(x_train_tensor)
    loss = cri(y_pre, y_train_tensor.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
    if epoch % 50 == 0 or epoch == 1:
        print(f"epoch:{epoch},loss:{loss.item()}")
        w1, w2 = model.fc.weight.data[0]
        b = model.fc.bias.data[0]
 
        slope = -w1 / w2
        intercept = -b / w2
 
        x_min, x_max = 0, 5
        x = np.array([x_min, x_max])
        y = slope * x + intercept
 
        ax1.clear()
        ax1.plot(x, y, 'r')
        ax1.scatter(x_train[:len(class1_points), 0], x_train[:len(class1_points), 1])
        ax1.scatter(x_train[len(class1_points):, 0], x_train[len(class1_points):, 1])
 
        ax2.clear()
        epoch_list.append(epoch)
        epoch_loss.append(loss.item())
        ax2.plot(epoch_list, epoch_loss, 'b')
        plt.pause(1)

from sklearn.metrics import accuracy_score, recall_score, f1_score

# 进行预测
with torch.no_grad():
    y_pred = model(x_train_tensor)
    y_pred_class = (y_pred > 0.5).float()  # 使用阈值0.5将概率转化为二分类结果

# 计算准确率、召回率和F1分数
accuracy = accuracy_score(y_train, y_pred_class)
recall = recall_score(y_train, y_pred_class)
f1 = f1_score(y_train, y_pred_class)

print(f"准确率: {accuracy:.4f}")
print(f"召回率: {recall:.4f}")
print(f"F1分数: {f1:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963755.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

手机上运行AI大模型(Deepseek等)

最近deepseek的大火,让大家掀起新一波的本地部署运行大模型的热潮,特别是deepseek有蒸馏的小参数量版本,电脑上就相当方便了,直接ollamaopen-webui这种类似的组合就可以轻松地实现,只要硬件,如显存&#xf…

JAVA安全—反射机制攻击链类对象成员变量方法构造方法

前言 还是JAVA安全,哎,真的讲不完,太多啦。 今天主要是讲一下JAVA中的反射机制,因为反序列化的利用基本都是要用到这个反射机制,还有一些攻击链条的构造,也会用到,所以就讲一下。 什么是反射…

数据库和数据表的创建、修改、与删除

1.标识符命名规则 数据库名、表名不得超过30个字符,变量名限制为29个 必须只能包含A-Z,a-z,0-9,_共63个字符 数据库名、表名、字段名等对象名中间不能包含空格 同一个MySQL软件中,数据库不能同名;同一个库中,表不能重名&#…

(电脑版)植物大战僵尸幼儿园版本,开启你的冒险之旅!

欢迎来到植物大战僵尸中文版,园长Jen已准备好迎接你的挑战!在这个充满乐趣和策略的游戏中,你将体验到多种游戏模式,每种模式都带来不同的挑战和乐趣。 游戏模式: 冒险模式:踏上刺激的冒险旅程,…

SpringBoot中关于knife4j 中的一些相关注解

1、效果图 对比可以明显的看到加了注解与没有加注解所表现出来的效果不同(加了注解的更加明了清晰) 2、实现效果 Tag注解‌用于为测试方法或测试类添加标签,以便在执行测试时根据标签进行过滤。使用Tag注解可以更灵活地控制测试的执行&#…

双目标定与生成深度图

基于C#联合Halcon实现双目标定整体效果 一,标定 1,标定前准备工作 (获取描述文件与获取相机参数) 针对标准标定板可以直接调用官方提供描述文件,也可以自己生成描述文件后用PS文件打印 2,相机标定 &…

操作系统1.4

一、操作系统内核 二、微内核与大内核 三、补充

Leetcode—598. 区间加法 II【简单】

2025每日刷题&#xff08;206&#xff09; Leetcode—598. 区间加法 II 实现代码 class Solution { public:int maxCount(int m, int n, vector<vector<int>>& ops) {int ans m * n;int x ops.size();if(ops.empty()) {return ans;}int xm ops[0][0], ym …

【机器学习篇】K-Means 算法详解:从理论到实践的全面解析

网罗开发 &#xff08;小红书、快手、视频号同名&#xff09; 大家好&#xff0c;我是 展菲&#xff0c;目前在上市企业从事人工智能项目研发管理工作&#xff0c;平时热衷于分享各种编程领域的软硬技能知识以及前沿技术&#xff0c;包括iOS、前端、Harmony OS、Java、Python等…

GWO优化LSBooST回归预测matlab

灰狼优化算法&#xff08;Grey Wolf Optimizer&#xff0c;简称 GWO&#xff09;&#xff0c;是一种群智能优化算法&#xff0c;由澳大利亚格里菲斯大学的 Mirjalii 等人于 2014 年提出。该算法的设计灵感源自灰狼群体的捕食行为&#xff0c;核心思想是模仿灰狼社会的结构与行为…

图形学笔记 - 5-光线追踪 - 辐射度量学

文章目录 辐射度量学辐射能和通量&#xff08;功率&#xff09;Radiant Energy and Flux (Power)辐射强度 Radiant Intensity辐照度Irradiance朗伯余弦定律Lambert’s Cosine Law Radiance辐亮度Incident Radiance入射辐亮度Exiting Radiance出射辐亮度 双向反射分布函数 Bidir…

低代码产品表单渲染架构

在React和Vue没有流行起来的时候&#xff0c;低代码产品的表单渲染设计通常会使用操作Dom的方式实现。 下面是一个表单的例子&#xff1a; 产品层 用户通过打开表单&#xff0c;使用不同业务场景业务下的表单页面&#xff0c;中间的Render层就是技术实现。 每一个不同业务的表单…

游戏引擎 Unity - Unity 设置为简体中文、Unity 创建项目

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

Sentinel 断路器在Spring Cloud使用

文章目录 Sentinel 介绍同类对比微服务雪崩问题问题原因问题解决方案请求限流线程隔离失败处理服务熔断解决雪崩问题的常见方案有哪些&#xff1f; Sentineldocker 安装账号/ 密码项目导入簇点链路请求限流线程隔离Fallback服务掉线时的处理流程 服务熔断 Sentinel 介绍 随着微…

我们信仰AI?从神明到人工智能——信任的进化

信任的进化&#xff1a; 信任是我们最宝贵的资产。而现在&#xff0c;它正像黑色星期五促销的廉价平板电视一样&#xff0c;被一点点拆解。在过去&#xff0c;世界很简单&#xff1a;人们相信晚间新闻、那些满是灰尘书籍的教授&#xff0c;或者手持病历、眉头紧锁的医生。而如…

51单片机 05 矩阵键盘

嘻嘻&#xff0c;LCD在RC板子上可以勉强装上&#xff0c;会有一点歪。 一、矩阵键盘 在键盘中按键数量较多时&#xff0c;为了减少I/O口的占用&#xff0c;通常将按键排列成矩阵形式&#xff1b;采用逐行或逐列的“扫描”&#xff0c;就可以读出任何位置按键的状态。&#xf…

OpenCV4,快速入门,Windows 系统 OpenCV4.10.0 开发环境搭建

文章目录 1. 摘要2. OpenCV 4.10.02. 环境配置2.1 配置包含目录2.2 配置库目录2.3 配置链接器2.4 配置环境变量并重启VS 3. 编写测试代码 1. 摘要 本文详细介绍在Windows系统下配置OpenCV 4.10开发环境的步骤&#xff0c;适用于Visual Studio 2022等C开发场景。内容涵盖&#…

unity学习26:用Input接口去监测: 鼠标,键盘,虚拟轴,虚拟按键

目录 1 用Input接口去监测&#xff1a;鼠标&#xff0c;键盘&#xff0c;虚拟轴&#xff0c;虚拟按键 2 鼠标 MouseButton 事件 2.1 鼠标的基本操作 2.2 测试代码 2.3 测试情况 3 键盘Key事件 3.1 键盘的枚举方式 3.2 测试代码同上 3.3 测试代码同上 3.4 测试结果 4…

FreeRTOS从入门到精通 第十三章(信号量)

参考教程&#xff1a;【正点原子】手把手教你学FreeRTOS实时系统_哔哩哔哩_bilibili 一、信号量知识回顾 1、概述 &#xff08;1&#xff09;信号量是一种解决同步问题的机制&#xff0c;可以实现对共享资源的有序访问&#xff0c;FreeRTOS中使用的是二值信号量、计数型信号…

业务系统文件上传和互传如何做到高效又安全?

在当今数字化的商业环境中&#xff0c;企业越来越依赖于高效的业务系统来处理日常运营。无论是金融、医疗、能源还是政府机构&#xff0c;内部各类业务系统如OA、供应商管理系统等都需要终端用户上传各种类型的文件附件。然而&#xff0c;在确保这些文件能够快速而准确地上传和…