逻辑回归 算法推导与基于Python的实现详解

文章目录

  • 1 逻辑回归概述
  • 2 逻辑回归公式推导与求解
    • 2.1 公式推导
    • 2.2公式求解
  • 3 基于Python的实现
    • 3.1可接收参数
    • 3.2 完整代码示例

1 逻辑回归概述

image.png

逻辑回归(Logistic Regression)是一种用于分类问题的统计学习方法。它基于线性回归的原理,通过将线性函数的输出值映射到[0,1]区间上的概率值,从而进行分类。

逻辑回归的输入是一组特征变量,它通过计算每个特征与对应系数的乘积,加上截距项得到线性函数,然后将该函数的输出值经过sigmoid函数的映射,得到概率值。

逻辑回归常用于二分类问题,即将样本分为两类,如判断一封邮件是否为垃圾邮件。逻辑回归还可以扩展到多分类问题,如将样本分为三类或更多类别。

逻辑回归具有简单、高效、易于理解等优点,在实际应用中被广泛使用,如金融风控、医学诊断、推荐系统等领域。

2 逻辑回归公式推导与求解

2.1 公式推导

P ( y = 1 ∣ x ) = 1 1 + e − ( β 0 + β 1 x 1 + β 2 x 2 + . . . + β p x p ) P(y=1|x) = \frac{1}{1+e^{-(\beta_0+\beta_1x_1+\beta_2x_2+...+\beta_px_p)}} P(y=1∣x)=1+e(β0+β1x1+β2x2+...+βpxp)1
逻辑回归的公式大家可能很熟悉,但并不知道推导过程,事实上这一推导过程也十分简单。

给定输入特征 x \mathbf{x} x,逻辑回归模型的输出 y y y 可以表示为:

y = P ( y = 1 ∣ x ) = σ ( w T x ) y = P(y=1|\mathbf{x}) = \sigma(\mathbf{w}^T \mathbf{x}) y=P(y=1∣x)=σ(wTx)

其中, σ ( z ) \sigma(z) σ(z) 表示 sigmoid 函数,定义为: σ ( z ) = 1 / ( 1 + e − z ) \sigma(z) = 1 / (1 + e^{-z}) σ(z)=1/(1+ez)。这一函数在之后的深度学习中也会经常用到。

至于什么使用sigmoid函数,原因也很简单,很多教材都把这个重要的思考环节忽略了,博主在此进行补充。
我们希望找到一个函数的值域在[0,1]的函数,但是这种函数并不容易找到,线性回归使用的公式 y = θ x y=\theta x y=θx的值域是(-∞,+∞)。因此,我们引入了odds(几率)的概念。 o d d s = P 1 − P odds=\frac{P}{1-P} odds=1PPodds的取值为(0,+∞),而对于log函数,其定义域刚好为(0,+∞),值域为(-∞,+∞)。因此,我们便可以构造一个函数 log ⁡ ( P 1 − P ) = θ x \log(\frac{P}{1-P})=\theta x log(1PP)=θx对P进行求解,即得到了逻辑回归的基本形式 P = 1 1 + e − x P=\frac{1}{1+e^{-x}} P=1+ex1也就是我们所常说的sigmoid函数。

2.2公式求解

为了方便推导,我们假设训练数据集包含 m m m 个样本,每个样本有 n n n 个特征,即 X ∈ R m × n \mathbf{X} \in \mathbb{R}^{m\times n} XRm×n,标签为 y ∈ { 0 , 1 } m y \in \{0, 1\}^m y{0,1}m。为了构建模型,我们需要使用训练数据集求解模型参数 w \mathbf{w} w

我们使用最大似然估计来求解模型参数。最大似然估计的目标是找到一组模型参数 w \mathbf{w} w,使得训练数据集出现的概率最大。设训练数据集中的第 i i i 个样本的输入特征为 x i \mathbf{x}_i xi,输出为 y i y_i yi,其概率表示为:

P ( y i ∣ x i ; w ) = σ ( w T x i ) y i ( 1 − σ ( w T x i ) ) 1 − y i P(y_i|\mathbf{x}_i; \mathbf{w}) = \sigma(\mathbf{w}^T \mathbf{x}_i)^{y_i} (1 - \sigma(\mathbf{w}^T \mathbf{x}_i))^{1-y_i} P(yixi;w)=σ(wTxi)yi(1σ(wTxi))1yi

训练数据集的概率可以表示为:

P ( y ∣ X ; w ) = ∏ i P ( y i ∣ x i ; w ) = ∏ i σ ( w T x i ) y i ( 1 − σ ( w T x i ) ) 1 − y i P(y|\mathbf{X}; \mathbf{w}) = \prod_i P(y_i|\mathbf{x}_i; \mathbf{w}) = \prod_i \sigma(\mathbf{w}^T \mathbf{x}_i)^{y_i} (1 - \sigma(\mathbf{w}^T \mathbf{x}_i))^{1-y_i} P(yX;w)=iP(yixi;w)=iσ(wTxi)yi(1σ(wTxi))1yi

对数似然函数为:

L ( w ) = log ⁡ P ( y ∣ X ; w ) = ∑ i [ y i log ⁡ ( σ ( w T x i ) ) + ( 1 − y i ) log ⁡ ( 1 − σ ( w T x i ) ) ] L(\mathbf{w}) = \log P(y|\mathbf{X}; \mathbf{w}) = \sum_i [y_i \log(\sigma(\mathbf{w}^T \mathbf{x}_i)) + (1-y_i) \log(1 - \sigma(\mathbf{w}^T \mathbf{x}_i))] L(w)=logP(yX;w)=i[yilog(σ(wTxi))+(1yi)log(1σ(wTxi))]

我们的目标是最大化对数似然函数 L ( w ) L(\mathbf{w}) L(w)。使用梯度上升算法来求解最优参数 w \mathbf{w} w。对 L ( w ) L(\mathbf{w}) L(w) 求导,得到: ∂ L ( w ) ∂ w = ∑ i ( σ ( w T x i ) − y i ) x i \frac{\partial L(\mathbf{w})}{\partial \mathbf{w}} = \sum_i(\sigma(\mathbf{w}^T \mathbf{x}_i) - y_i)\mathbf{x}_i wL(w)=i(σ(wTxi)yi)xi 使用梯度上升算法,每次更新 w \mathbf{w} w 的值为: w ← w + α ∑ i ( σ ( w T x i ) − y i ) x i \mathbf{w} \leftarrow \mathbf{w} + \alpha \sum_i (\sigma(\mathbf{w}^T \mathbf{x}_i) - y_i)\mathbf{x}_i ww+αi(σ(wTxi)yi)xi 其中, α \alpha α 是学习率。

3 基于Python的实现

3.1可接收参数

在Python中,使用Scikit-learn库中的LogisticRegression类可以创建逻辑回归模型。

下面是LogisticRegression类的主要参数和方法:

参数:

  • penalty: 惩罚项,可以是‘l1’、‘l2’、‘elasticnet’、‘none’中的一种,默认为‘l2’。
  • C: 正则化系数,用于控制模型的复杂度,C值越小,模型越简单,默认为1.0。
  • solver: 用于优化问题的算法,可以是‘newton-cg’、‘lbfgs’、‘liblinear’、‘sag’、‘saga’中的一种,默认为‘lbfgs’。
  • max_iter: 最大迭代次数,用于控制优化算法的迭代次数,默认为100

3.2 完整代码示例

我们使用sklearn自带的威斯康辛州乳腺癌数据集,进行模型的训练和预测。

  1. 首先导入sklearn等必要的包与数据:
import pandas as pd  
from sklearn.datasets import load_breast_cancer  
from sklearn.linear_model import LogisticRegression  
from sklearn.model_selection import train_test_split  
import matplotlib.pyplot as plt  
import matplotlib as mpl  
  
## 设置字符集,防止中文乱码  
mpl.rcParams['font.sans-serif'] = [u'simHei']  
mpl.rcParams['axes.unicode_minus'] = False  
  
# 加载数据集  
data = load_breast_cancer()  
# 转换为DataFrame  
df = pd.DataFrame(data.data, columns=data.feature_names)  
df['target'] = pd.Series(data.target)

此处skearn自带的数据已经是清洗后的版本,如果使用的是原始威斯康辛州乳腺癌数据集或者其他个人数据集,需要对数据进行查看、清洗与特征初筛。比如特征集可能包含患者的身份ID等无用信息,将此类信息直接删除即可。

同时可以使用常用的函数查看数据集状态,比如

# 检测非数据类型与缺失值  
print(df.info())
# 检查异常值  
print(df.describe())

这里官方的数据运行结果如下,可以看到已经不需要做什么修改了。其输出如下:
image.png
如果Type中出现object,通常意味着这一类的某一行存在非数值类型,此时可以使用

df['A'] = pd.to_numeric(df['A'], errors='coerce').astype(float)

这段函数可以将df的A列转为float类型,并将不能转换的数值变为空值。之后与其他缺失值共同使用dropna删除即可。
2. 模型建立与拟合

# 分割数据集  
X = df.drop('target', axis=1)  
y = df['target']  
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)  
  
# 创建逻辑回归模型  
clf = LogisticRegression(max_iter=100)  
  
# 拟合模型  
clf.fit(X_train, y_train)

建立模型时,通常不用修改超参数,也可根据数据集的实际特点或拟合后的一些警告(比如函数最终没有收敛)进行超参数的调整。
3. 模型预测与绘图

# 预测  
y_pred = clf.predict(X_test)  
  
# 评估  
accuracy = clf.score(X_test, y_test)  
  
print("预测结果:", y_pred)  
print("准确率:", accuracy)  
  
plt.plot(range(len(X_test)), y_test, 'ro', markersize=4, zorder=3, label=u'真实值')  
plt.plot(range(len(X_test)), y_pred, 'go', markersize=10, zorder=2, label=u'预测值')  
plt.legend()  
plt.show()

从图片上可以看出,红色圈与绿色圈同时出现的点为预测正确的数据,二者单独出现的点位为预测错误的数据。此处对预测结果的判断,是基于概率大于0.5和小于0.5来分割的。如果我们想达到当大于0.8的概率为0才视为0时(降低漏诊概率),可以使用下面的方法。
image.png

  1. 自定义阈值
print(clf.predict_proba(X_test)[:,0]>0.8)

clf.predict_proba(X_test)用于输出不同类别的概率,其输出型输入下:
image.png
如果想要获取分类为0的可能性大于0.8的数据,提取粗该列的数值加以判断即可。

本篇文章代码可点此下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/5097.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【软考——系统架构师】架构、系分、软设的区别和联系

🔎这里是【软考——系统架构师】,关注我考试轻松过线 👍如果对你有帮助,给博主一个免费的点赞以示鼓励 欢迎各位🔎点赞👍评论收藏⭐️ 文章目录👀三科相同点👀三科不同点--上午题&am…

CISAW-CISDR灾难备份与恢复专业级认证

证书样板: 认证机构 中国网络安全审查技术与认证中心(英文缩写为:CCRC,原为中国信息安全认证中心)于 2006 年由中央机构编制委员会办公室批准成立,为国家市场监督管理总局直属事业单位。依据《网络安全法》 《网络安…

Java每日一练(20230401)

目录 1. 合并K个升序链表 🌟🌟🌟 2. 最长有效括号 🌟🌟🌟 3. 分割回文串 🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 …

Linux系统一键安装最新内核并开启 BBR 脚本

本脚本适用环境 系统支持:CentOS 6,Debian 8,Ubuntu 16 虚拟技术:OpenVZ 以外的,比如 KVM、Xen、VMware 内存要求:≥128M 更新日期:2022 年 5 月 11 日 关于本脚本 1、本脚本已在 蓝易云 上…

后端Springboot框架搭建APi接口开发(第二章)

上一章我讲述了如何使用Mybatis操作数据库。这一章我讲述如何利用Sptring框架搭建API接口 第一节:封装SqlSessionFactory工具类 在API操作数据库大量调用SqlSessionFactory,因此应将SqlSessionFactory封装成工具类供方法随时调用 在文件结构中的util文…

前端直接生成GIF动态图实践

前言去年在博客中发了两篇关于GIF动态生成的博客,GIF图像动态生成-JAVA后台生成和基于FFmpeg的Java视频Mp4转GIF初探,在这两篇博客中都是采用JAVA语言在后台进行转换。使用JAVA的同学经过自己的改造和开发也可以应用在项目上。前段时间有朋友私下问&…

MySQL主从复制、读写分离(MayCat2)实现数据同步

文章目录1.MySQL主从复制原理。2.实现MySQL主从复制(一主两从)。3.基于MySQL一主两从配置,完成MySQL读写分离配置。(MyCat2)1.MySQL主从复制原理。 MySQL主从复制是一个异步的复制过程,底层是基于Mysql数据…

设计模式之门面模式(C++)

作者:翟天保Steven 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 一、门面模式是什么? 门面模式是一种结构型的软件设计模式,也叫外观模式,它提供了统一的接口去…

程序员的天花板到底有多高?

程序员是很多人眼里的职业天花板,毕竟薪资水平、发展前景摆在那里,而且,一个又会debug又会修电脑(划掉),还会费心为你制造浪漫的程序员,有谁不喜欢呢? 但现实中很多程序员都焦虑担忧…

从0使用TCP手撸http服务器六

html模板设计: 上篇我们说到通过路由去返回不同的html页面,如果每一个页面都是一个数组的话,那么我们可能需要很多数组,里面很多内容都是一样的,这样子会浪费我们单片机很多flash,我们需要把共同的部分固定…

神经网络之反向传播算法(加入Nesterov动量的误差反向传播算法)

文章目录1、Nesterov动量2、结合Nesterov动量的误差反向传播算法3、算法实现3.1 训练过程3.2 测试过程4、参考源码及数据集1、Nesterov动量 在动量法提出不久之后,Nesterov动量也随之被提了出来,此方法属于动量法的进一步发展,与动量法不同的…

基于windows11配置深度学习环境包含WSL2配置ubuntu20.04

基于windows11配置深度学习环境包含WSL2配置ubuntu20.04目录平台WSL2 系统准备(Windows Subsystem of Linux)安装WSL2基于WSL2中的Ubuntu安装CUDA和cudnnCUDA 安装目录 平台 系统 : windows11 专业版 CPU Intel I7 8750hq 显卡:Nvidia GTX1060移动端 显卡驱动版本…

安装Windows11提示这台电脑不符合安装此版本的Windows所需的最低系统要求

现在很多用户都会选择用U盘来安装系统,最新有用户在使用U盘安装Win11系统的时候,结果安装到第一步就提示这台电脑无法运行Windows11,这台电脑不符合安装此版本的Windows所需的最低系统要求。下面小编就来教大家解决此问题的方法。 问题解析&a…

BIO/NIO/AIO/IO多路复用简介

bio、nio、aio、io多路复用、reactor模式io,在将IO的时候,是不是都遇到过这些概念,也有种傻傻分不清?甚至别人在大谈特谈的时候,一会nio,一会io多路复用,一会又搞到reactor模式上去了&#xff1…

一文搞懂原型和原型链

在了解原型和原型链之前首先得明确它俩是什么东西: 原型:prototype 又称显示原型 1、原型是一个普通对象 2、只有构造函数才具备该属性 3、公有属性可操作 隐式原型:__proto__ 1、只有对象(普通对象、函数对象)具备 2、私有的对…

《Spring系列》第2章 解析XML获取Bean

一、基础代码 Spring加载bean实例的代码 public static void main(String[] args) throws IOException {// 1.获取资源Resource resource new ClassPathResource("bean.xml");// 2.获取BeanFactoryDefaultListableBeanFactory factory new DefaultListableBeanFa…

Airtest自动化测试工具实战演练

一开始知道Airtest大概是在年初的时候,当时,看了一下官方的文档,大概是类似Sikuli的一个工具,主要用来做游戏自动化的,通过截图的方式用来解决游戏自动化测试的难题。最近,移动端测试的同事尝试用它的poco库…

iwebsec靶场-命令执行漏洞

漏洞简介 命令执行漏洞(Command Injection)是一种常见的安全漏洞,也被称为代码注入漏洞。它允许攻击者将恶意代码注入到受攻击的应用程序中,从而可以在应用程序的上下文中执行任意命令。 命令执行漏洞通常出现在Web应用程序中&…

好的表单设计应该遵循什么规则?

在数字化时代,表单已经成为了人们生活中不可或缺的一部分。它们可能是网站注册表格、调查问卷、订单表格或者其他类型的表格。无论表单的类型是什么,都必须经过精心设计才能提供良好的用户体验。在本文中,我们将探讨如何设计一份用户体验好的…

Redis缓存双写一致性

目录双写一致性Redis与Mysql双写一致性canal配置流程代码案例双写一致性理解缓存操作细分缓存一致性多种更新策略挂牌报错,凌晨升级先更新数据库,在更新缓存先删除缓存,在更新数据库先更新数据库,在删除缓存延迟双删策略总结双写一致性 Redis与Mysql双写一致性 canal 主要是…