监督学习之逻辑回归

逻辑回归(Logistic Regression)

逻辑回归是一种用于二分类(binary classification)问题的统计模型。尽管其名称中有“回归”二字,但逻辑回归实际上用于分类任务。它的核心思想是通过将线性回归的输出映射到一个概率值,以进行类别预测。

1. 模型概述

逻辑回归的基本公式为:

P ( y = 1 ∣ x ) = σ ( z ) = 1 1 + e − z P(y=1|x) = \sigma(z) = \frac{1}{1 + e^{-z}} P(y=1∣x)=σ(z)=1+ez1

其中:

  • ( P ( y = 1 ∣ x P(y=1|x P(y=1∣x) ) 是给定特征 ( x x x ) 时,因变量 ( y y y ) 等于 1 的概率。
  • ( z = β 0 z = \beta_0 z=β0 + β 1 x 1 \beta_1x_1 β1x1 + β 2 x 2 \beta_2 x_2 β2x2 + … \ldots + β n x n \beta_n x_n βnxn ) 是线性组合。
  • ( σ ( z ) \sigma(z) σ(z) ) 是 sigmoid 函数,将输出值映射到 0 0 0 1 1 1之间。
2. Sigmoid 函数

Sigmoid 函数的形状如下:

σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1

  • 当 ( z z z ) 为负时,函数输出接近于 0 0 0;当 ( z z z ) 为正时,函数输出接近于 1 1 1
  • 这种特性使得 sigmoid 函数非常适合用于概率预测。
3. 损失函数

逻辑回归的损失函数为交叉熵损失(cross-entropy loss),用于衡量模型预测与实际标签之间的差异。其公式为:

L ( β ) = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] L(\beta) = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i)] L(β)=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

其中:

  • ( N N N ) 是样本数量。
  • ( y i y_i yi ) 是实际标签。
  • ( y ^ i \hat{y}_i y^i ) 是预测概率。

逻辑回归的损失函数求解通常通过 最大似然估计梯度下降 等优化算法进行。逻辑回归模型中常用的损失函数是 交叉熵损失,目标是通过最小化损失函数来找到最佳的模型参数。

1. 逻辑回归中的损失函数

(1)损失函数

逻辑回归的损失函数基于交叉熵(Cross-Entropy Loss),用于衡量模型预测的概率分布与实际标签之间的差异。对于二分类问题,其形式为:

L ( β ) = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] L(\beta) = - \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] L(β)=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

其中:

  • ( N N N ) 是样本数量。
  • ( y i y_i yi ) 是第 ( i i i ) 个样本的真实标签( 0 0 0 1 1 1)。
  • ( y ^ i = σ ( z i ) \hat{y}_i = \sigma(z_i) y^i=σ(zi) ) 是第 ( i i i ) 个样本的预测概率。
  • ( z i = β 0 + β 1 x i 1 + β 2 x i 2 + ⋯ + β n x i n z_i = \beta_0 + \beta_1 x_{i1} + \beta_2 x_{i2} + \dots + \beta_n x_{in} zi=β0+β1xi1+β2xi2++βnxin ) 是线性组合。
  • ( σ ( z ) \sigma(z) σ(z) ) 是 sigmoid 函数,定义为:
    σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+ez1
    这将线性回归的输出 ( z z z ) 映射到 ( ( 0 0 0, 1 1 1) ) 之间,作为类别为 1 1 1 的预测概率。
(2)如何求解损失函数

求解逻辑回归的损失函数通常使用 梯度下降 等优化方法。目标是找到使损失函数最小的参数 ( β \beta β ),即 最小化交叉熵损失。求解过程可以概括为以下步骤:

** 计算梯度**

为了最小化损失函数,我们需要对每个参数 ( β j \beta_j βj) 计算损失函数的偏导数(即梯度),并通过优化算法(如梯度下降)进行更新。

对于交叉熵损失函数,梯度计算公式为:

∂ L ∂ β j = − 1 N ∑ i = 1 N ( y i − y ^ i ) x i j \frac{\partial L}{\partial \beta_j} = -\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i) x_{ij} βjL=N1i=1N(yiy^i)xij
其中:

  • ( x i j x_{ij} xij ) 是第 ( i i i ) 个样本的第 ( j j j ) 个特征。
  • ( y i y_i yi ) 是第 ( i i i ) 个样本的实际标签。
  • ( y ^ i \hat{y}_i y^i) 是第 ( i i i ) 个样本的预测概率。

使用梯度下降更新参数梯度下降法通过以下公式迭代更新参数:

β j = β j − α ∂ L ∂ β j \beta_j = \beta_j - \alpha \frac{\partial L}{\partial \beta_j} βj=βjαβjL

其中:

  • ( α \alpha α ) 是学习率(控制每次更新步长的大小)。
  • ( ∂ L ∂ β j \frac{\partial L}{\partial \beta_j} βjL ) 是损失函数对参数 ( β j \beta_j βj ) 的梯度。

通过不断更新参数,使得损失函数逐渐减小,直到达到全局或局部最优解。

(3) 代码示例:逻辑回归中的梯度下降

以下是使用 Python 实现逻辑回归梯度下降的示例:

import numpy as np

# Sigmoid 函数
def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# 损失函数 (交叉熵)
def compute_loss(y, y_pred):
    return -np.mean(y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))

# 梯度下降算法
def gradient_descent(X, y, learning_rate=0.1, num_iterations=1000):
    m, n = X.shape
    beta = np.zeros(n)  # 初始化参数
    for i in range(num_iterations):
        z = np.dot(X, beta)
        y_pred = sigmoid(z)
        gradients = np.dot(X.T, (y_pred - y)) / m
        beta -= learning_rate * gradients
        if i % 100 == 0:
            loss = compute_loss(y, y_pred)
            print(f"Iteration {i}: Loss = {loss}")
    return beta

# 示例数据
X = np.array([[1, 2], [1, 3], [2, 2], [2, 3]])  # 样本数据
y = np.array([0, 0, 1, 1])  # 标签数据

# 在样本数据前面加一列 1 用于偏置项 (截距项)
X_bias = np.c_[np.ones(X.shape[0]), X]

# 运行梯度下降求解参数
beta = gradient_descent(X_bias, y)
print("求解得到的参数:", beta)
4. 优缺点

优点

  • 简单易懂:逻辑回归模型简单,易于实现和解释。
  • 概率输出:模型输出的是预测的概率,可以用于更细致的决策。
  • 适用于线性可分问题:在特征与目标变量之间存在线性关系时,表现良好。

缺点

  • 线性假设:假设特征与目标之间存在线性关系,不适用于复杂的非线性关系。辑回归假设特征和类别之间的关系是线性的,对于复杂非线性问题,表现不如其他模型(如决策树、神经网络)。
  • 受特征选择影响:模型对输入特征敏感,需要合适的特征选择和处理。
  • 容易过拟合:在特征数量较多时,可能会发生过拟合,特别是当样本量不足时。
  • 无法解决多分类问题:标准的逻辑回归只适用于二分类问题,若要应用于多分类问题,需要使用 Softmax 回归或一对多策略。

5. 代码示例

以下是使用 Python 的 scikit-learn 库实现逻辑回归的示例:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# 生成示例数据
X, y = make_classification(n_samples=100, n_features=2, n_classes=2, n_informative=2, n_redundant=0, random_state=42)

# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建逻辑回归模型
model = LogisticRegression()
model.fit(X_train, y_train)

# 进行预测
y_pred = model.predict(X_test)

# 评估模型
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

print("准确率:", accuracy)
print("混淆矩阵:\n", conf_matrix)
print("分类报告:\n", class_report)

# 绘制决策边界
plt.scatter(X[:, 0], X[:, 1], c=y, cmap='coolwarm', edgecolors='k')
xlim = plt.gca().get_xlim()
ylim = plt.gca().get_ylim()

xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 100), np.linspace(ylim[0], ylim[1], 100))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.contourf(xx, yy, Z, alpha=0.3, cmap='coolwarm')
plt.title('逻辑回归决策边界')
plt.xlabel('特征1')
plt.ylabel('特征2')
plt.show()

结果
在这里插入图片描述

在这里插入图片描述

6. 总结

逻辑回归是一种简单而有效的分类模型,适合于解决二分类问题。尽管它有一些局限性(如线性假设),但在许多实际应用中,逻辑回归因其易于解释和实现而被广泛使用。通过合适的特征选择和数据处理,逻辑回归能够在很多情况下提供可靠的分类结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/900876.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何限制电脑软件的安装?

1.修改注册表(需谨慎操作,建议备份注册表): 打开“运行”对话框,输入 regedit 打开注册表编辑器。 导航到 HKEY_CURRENT_USER\Software\Microsoft\Windows\CurrentVersion\Policies\Explorer。 创建新的DWORD值&…

2024双11买什么东西比较好?双十一购物清单,双十一囤货清单排名

今年双十一好价确实多,一方面是年底促销,一方面国补也很给力,种草很久的产品趁着这个时间下单最好不过了,不知道各位有哪些心水好物,我今年入手了不少生活用品和数码类产品,下文就挑选几款我觉得特别值得入…

基于Multisim的四人智力竞赛抢答器设计与仿真

1)设计任务 设计一台可供 4 名选手参加比赛的智力竞赛抢答器。 用数字显示抢答倒计时间,由“9”倒计到“0”时,无人抢答,蜂鸣器连续响 1 秒。选手抢答时,数码显示选手组号,同时蜂鸣器响 1 秒,倒…

使用Prometheus对微服务性能自定义指标监控

背景 随着云计算和容器化技术的不断发展,微服务架构逐渐成为现代软件开发的主流趋势。微服务架构将大型应用程序拆分成多个小型、独立的服务,每个服务都可以独立开发、部署和扩展。这种架构模式提高了系统的可伸缩性、灵活性和可靠性,但同时…

Appium中的api(一)

目录 1.基础python代码准备 1--参数的一些说明 2--python内所要编写的代码 解释 2.如何获取包名和界面名 1-api 2-完整代码 代码解释 3.如何关闭驱动连接 4.安装卸载app 1--卸载 2--安装 5.判断app是否安装 6.将应用放到后台在切换为前台的时间 7.UIAutomatorViewer的使用 1--找…

学习笔记——路由——IP组播-PIM-DM(密集模式)前言概述

7、PIM-DM(密集模式) (1)前言 PIM-DM(PIM Dense Mode)使用“推(Push)模式”转发组播报文,一般应用于组播组成员规模相对较小、相对密集的网络。 在实现过程中,它会假设网络中的组成员分布非常稠密,每个网段都可能存在组成员。当有活跃的组…

TLS协议基本原理与Wireshark分析

01背 景 随着车联网的迅猛发展,汽车已经不再是传统的机械交通工具,而是智能化、互联化的移动终端。然而,随之而来的是对车辆通信安全的日益严峻的威胁。在车联网生态系统中,车辆通过无线网络与其他车辆、基础设施以及云端服务进行…

JavaEE----多线程(四)----阻塞队列的介绍和初步实现

文章目录 1.阻塞队列1.1作用一:解耦合1.2作用二:削峰填谷1.3系统里面的阻塞队列的使用1.4实现普通队列1.5在普通队列的基础上面实现阻塞队列1.6设计优化1.7实现初步的生产者消费者模型 1.阻塞队列 阻塞队列的最大意义:就是实现“生产者消费者…

Pyramidal Flow使用指南:快手、北大、北邮,开源可免费商用视频生成模型,快速上手教程

什么是 Pyramidal Flow? Pyramidal Flow 是由快手科技、北京大学和北京邮电大学联合推出的开源视频生成模型,它是完全开源的,发布在 MIT 许可证下,允许商业使用、修改和再分发。该模型能够通过文本描述生成最高10秒、分辨率为128…

铜业机器人剥片 - SNK施努卡

SNK施努卡有色行业电解车间铜业机器人剥片 铜业机器人剥片技术是针对传统人工剥片效率低下、工作环境恶劣及生产质量不稳定的痛点而发展起来的自动化解决方案。 面临人工剥片的诸多挑战,包括低效率、工作环境差、人员流动大以及产品质量控制不精确等问题。 人工剥片…

Idea基于JRbel实现项目热部署修改Java、Xml文件无需重启项目

Idea基于JRbel实现项目热部署修改Java、Xml文件无需重启项目 1.JRbel服务安装2.JRbel插件安装3.JRbel配置 1.JRbel服务安装 直接装插件的话,需要用到一个服务地址,服务下载链接:(现在没时间搞,会尽快加上)…

合合信息亮相PRCV大会,探讨生成式AI时代的内容安全与系统构建加速

一、前言 在人工智能技术的飞速发展下,生成式AI已经成为推动社会进步的重要力量。然而,随着技术的不断进步,内容安全问题也日益凸显。如何确保在享受AI带来的便利的同时,保障信息的真实性和安全性,已经成为整个行业待解…

Jmeter自动化实战

前言 由于系统业务流程很复杂,在不同的阶段需要不同的数据,且数据无法重复使用,每次造新的数据特别繁琐,故想着能不能使用jmeter一键造数据 创建录制模板 录制模板参考 首先创建一个录制模板 因为会有各种请求头,cookies,签名,认证信息等原因,导致手动复制粘贴的的全面导致接…

Flutter TextField和Button组件开发登录页面案例

In this section, we’ll go through building a basic login screen using the Button and TextField widgets. We’ll follow a step-bystep approach, allowing you to code along and understand each part of the process. Let’s get started! 在本节中,我们…

NVIDIA发布Nemotron-70B-Instruct,超越GPT-4o和Claude 3.5的AI模型

一、Nemotron-70B-Instruct 是什么 Nemotron-70B-Instruct 是由 NVIDIA 基于 Meta 的 Llama 3.1-70B 模型开发的先进大语言模型(LLM)。该模型采用了新颖的神经架构搜索(Neural Architecture Search,NAS)方法和知识蒸馏…

【华为HCIP实战课程二十】OSPF特殊区域NSSA配置详解,网络工程师

一、NSSA(Not So Stubby Area)区域 在NSSA区域内可以拥有ASBR,并且重分发进入OSPF的路由是以7类LSA形式存在,该类型的LSA只能存在于NSSA区域内不接收5类LSA,ABR过滤外部进入该区域的4 5类LSA,可以引入外部…

题解 力扣 LeetCode 739 每日温度 C++

题目传送门: 739. 每日温度 - 力扣(LeetCode)https://leetcode.cn/problems/daily-temperatures/description/ 思路: 就是单调栈的思路,具体见代码 不知道单调栈的,可以看我的这篇文章: 数…

web3对象如何连接以太网络节点

实例化web3对象 当我们实例化web3对象,我们一般开始用本地址,如下 import Web3 from web3 var web3 new Web3(Web3.givenProvider || ws://localhost:5173)我们要和以太网进行交互,所以我们要将’ws://localhost:5173’的本地地址换成以太…

【Linux学习】(6)编译器gcc/g++

前言 本节重点:掌握gcc/g编译器的使用,并了解其过程,原理 一、Linux编译器-gcc/g使用 1. gcc/g的基本使用 在前面我们学习了vim,知道如何在Linux中编写代码。但又是如何编译代码的?——在Linux中我们编译代码使用的是…

UDP(用户数据报协议)端口监控

随着网络的扩展,确保高效的设备通信对于优化网络功能变得越来越重要。在这个过程中,端口发挥着重要作用,它是实现外部设备集成的物理连接器。通过实现数据的无缝传输和交互,端口为网络基础设施的顺畅运行提供了保障。端口使数据通…