【Sklearn-驯化】一文搞懂很难的EM算法以及实践

【Sklearn-驯化】一文搞懂很难的EM算法以及实践

 
本次修炼方法请往下查看
在这里插入图片描述

🌈 欢迎莅临我的个人主页 👈这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合,智慧小天地!
🎇 免费获取相关内容文档关注:微信公众号,发送 pandas 即可获取
🎇 相关内容视频讲解 B站

🎓 博主简介:AI算法驯化师,混迹多个大厂搜索、推荐、广告、数据分析、数据挖掘岗位 个人申请专利40+,熟练掌握机器、深度学习等各类应用算法原理和项目实战经验

🔧 技术专长: 在机器学习、搜索、广告、推荐、CV、NLP、多模态、数据分析等算法相关领域有丰富的项目实战经验。已累计为求职、科研、学习等需求提供近千次有偿|无偿定制化服务,助力多位小伙伴在学习、求职、工作上少走弯路、提高效率,近一年好评率100%

📝 博客风采: 积极分享关于机器学习、深度学习、数据分析、NLP、PyTorch、Python、Linux、工作、项目总结相关的实用内容。

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

🌵文章目录🌵

    • 🎯 1. 基本介绍
    • 💡 2. 原理介绍
      • 2.1 EM算法
      • 2.2 GMM算法
    • 🚀 3. 代码实践
    • 🚀 4. 注意事项
    • 🚀 5. 总结

下滑查看解决方法

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

  

🎯 1. 基本介绍

  高斯混合模型(Gaussian Mixture Model, GMM)是一种概率模型,它假设数据点是从多个高斯分布中生成的。GMM在聚类分析中非常有用,特别是当数据集包含多个子群体且每个子群体具有不同的特征分布时。

  • 在贝叶斯学派,如果知道观测变量以及其分布函数,我们可以通过极大似然或者极大后验的方式去求取相关的分布函数的参数 θ \theta θ,但是当观测数据中含有未知的隐变量时,此时无法通过上述的方式直接去求取相关的分布参数,而EM算法就是为了去解决这个问题的
  • 相关问题:三个硬币,每次选一个硬币去抛,记录3次的结果,现在想要知道三个硬币都为正面的概率,就是求解三个硬币的模型参数
  • EM算法的应用有Kmeans算法、GMM混合模型
  • em算法的基本思想就是在给定初始参数的情况下,计算隐变量和观测变量的联合分布中隐变量的log期望(也就是建立相关的Q函数),然后在对Q函数进行进行极大似然估计,更新参数 θ \theta θ

💡 2. 原理介绍

2.1 EM算法

  • 输入:观测变量Y,隐变量Z,联合分布 P ( Y , Z ∣ θ ) P(Y,Z|\theta) P(Y,Zθ), 条件分布 P ( Z ∣ Y , θ ) P(Z|Y, \theta) P(ZY,θ),
  • 输出:模型参数 θ \theta θ
  • 选择初始化参数 θ ( 0 ) \theta^{(0)} θ(0),开始迭代
  • E步, 记 θ ( i ) \theta^{(i)} θ(i)为第i次迭代参数 θ \theta θ的估计值,在第i+1次迭代的E步,计算:

Q ( θ , θ ( i ) ) = E Z [ l o g P ( Y , Z ∣ θ ) ∣ Y , θ ( i ) ] Q(\theta, \theta^{(i)})=E_Z[logP(Y,Z|\theta)|Y, \theta^{(i)}] Q(θ,θ(i))=EZ[logP(Y,Zθ)Y,θ(i)]

  • 由于隐变量为离散的变量,因此:

Q ( θ , θ ( i ) ) = ∑ Z l o g P ( Y , Z ∣ θ ) P ( Z ∣ Y , θ ( i ) ) Q(\theta, \theta^{(i)})=\sum_{Z}logP(Y,Z|\theta)P(Z|Y,\theta^{(i)}) Q(θ,θ(i))=ZlogP(Y,Zθ)P(ZY,θ(i))

  • 其中 P ( Z ∣ Y , θ ( i ) ) P(Z|Y,\theta^{(i)}) P(ZY,θ(i))时在给定观测数据Y和参数估计 θ ( i ) \theta^{(i)} θ(i)下隐变量数据Z的条件概率分布
  • M步,求使得 Q ( θ , θ ( i ) ) Q(\theta, \theta^{(i)}) Q(θ,θ(i))极大化的 θ \theta θ,并确定第i+1次迭代的参数估计值 θ ( i + 1 ) \theta^{(i+1)} θ(i+1):

θ ( i ) = a r g m a x θ Q ( θ , θ ( i ) ) \theta^{(i)}=\underset{\theta}{argmax}Q(\theta, \theta^{(i)}) θ(i)=θargmaxQ(θ,θ(i))

  • 不断的重复上述的过程,知道算法收敛
  • 对于这种非全局最优算法的停止条件,基本都是某个参数不在变化就停止,或者就是某个条件不再变化就算法停止迭代
  • 对于EM算法的收敛性是通过Jensen不等式得出的:

t f ( x 1 ) + ( 1 − t ) f ( x 2 ) ⩾ f ( t x 1 + ( 1 − t ) x 2 ) tf(x_1)+(1-t)f(x_2) \geqslant f(tx_1+(1-t)x_2) tf(x1)+(1t)f(x2)f(tx1+(1t)x2)
f ( ∑ i = 1 M λ i x i ) ⩽ ∑ i = 1 M λ i f ( x i ) f(\sum_{i=1}^{M}\lambda_ix_i) \leqslant \sum_{i=1}^{M}\lambda_if(x_i) f(i=1Mλixi)i=1Mλif(xi)
其中第二个式子为第一个式子的泛化形式。

2.2 GMM算法

  • EM算法一个很大的应用就是GMM模型,该模型反应的是一堆由多个高斯混合模型生成的数据,每个样本是来自哪个高斯分布,这是个典型的隐变量模型
  • 输入高斯混合模型表达式如下:

P ( y ∣ θ ) = ∑ k = 1 K α k ϕ ( y ∣ θ k ) P(y|\theta)=\sum_{k=1}^{K}\alpha_k\phi(y|\theta_k) P(yθ)=k=1Kαkϕ(yθk)

  • 在这个问题里,需要估计参数 θ \theta θ
  • 明确该问题中的隐变量,写出相关log联合分布函数:

定义 r j k r_{jk} rjk,其中其值为1或者0,为1表示第j个观测变量来自第k个高斯分布

  • 因此,可以得出相关的联合分布为:

P ( y , r ∣ θ ) = ∏ j = 1 N P ( y j , r j 1 , r j 2 . . . , r j k ∣ θ ) P(y,r|\theta)=\prod_{j=1}^{N}P(y_j, r_{j1}, r_{j2}...,r_{jk}|\theta) P(y,rθ)=j=1NP(yj,rj1,rj2...,rjkθ)
P ( y , r ∣ θ ) = ∏ k = 1 K ∏ j = 1 N [ α k ϕ ( y i ∣ θ k ) ] r j k P(y,r|\theta)=\prod_{k=1}^{K}\prod_{j=1}^{N}[\alpha_k\phi(y_i|\theta_k)]^{r_{jk}} P(y,rθ)=k=1Kj=1N[αkϕ(yiθk)]rjk
l o g P ( y , r ∣ θ ) = ∑ k = 1 K [ n k l o g α k + ∑ j = 1 N r j k [ l o g ( 1 2 π ) − l o g σ k − 1 2 σ k 2 ( y i − u k ) 2 ] ] logP(y,r|\theta)=\sum_{k=1}^{K}[n_klog_{\alpha_k}+\sum_{j=1}^{N}r_{jk}[log(\frac{1}{\sqrt{2\pi}})-log \sigma_k-\frac{1}{2\sigma_k^2}(y_i-u_k)^2]] logP(y,rθ)=k=1K[nklogαk+j=1Nrjk[log(2π 1)logσk2σk21(yiuk)2]]

  • 对上述的Q函数求取期望 E ( Q r j k ) E(Q_{r_{jk}}) E(Qrjk),则:
  • E ( r j t ∣ y , θ ) = α k ϕ ( y i ∣ θ k ) ∑ k = 1 K α k ϕ ( y i ∣ θ k ) E(r_{jt}|y, \theta)=\frac{\alpha_k\phi(y_i|\theta_k)}{\sum_{k=1}^{K}\alpha_k\phi(y_i|\theta_k)} E(rjty,θ)=k=1Kαkϕ(yiθk)αkϕ(yiθk),将其代入到Q函数中,可到只含有参数 θ \theta θ的Q函数表达式
  • 通过M步对Q函数进行极大似然估计,则可以得出参数 θ \theta θ的更新:

u ^ k = ∑ j = 1 N r ^ j k y j ∑ j = 1 N r ^ j k \hat{u}_k=\frac{\sum_{j=1}^{N}\hat{r}_{jk}y_j}{\sum_{j=1}^{N}\hat{r}_{jk}} u^k=j=1Nr^jkj=1Nr^jkyj
σ ^ k = ∑ j = 1 N r ^ j k ( y j − u k ) 2 ∑ j = 1 N r ^ j k \hat{\sigma}_k=\frac{\sum_{j=1}^{N}\hat{r}_{jk}(y_j-u_k)^2}{\sum_{j=1}^{N}\hat{r}_{jk}} σ^k=j=1Nr^jkj=1Nr^jk(yjuk)2
α ^ k = ∑ j = 1 N r ^ j t N \hat{\alpha}_k=\frac{\sum_{j=1}^{N}\hat{r}_{jt}}{N} α^k=Nj=1Nr^jt

  • 然后不断的重复更新上述的过程,即可得到最终的参数 θ \theta θ

🚀 3. 代码实践

  为了更加清晰的看明白gmm算法的执行过程,我们构建相关的数据进行执行,下面为具体的实例:

from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

# 生成模拟数据
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.60, random_state=0)
# 创建GMM模型实例
gmm = GaussianMixture(n_components=4, random_state=42)

# 训练模型
gmm.fit(X)
# 聚类标签
predicted_labels = gmm.predict(X)
# 绘制聚类结果
plt.scatter(X[:, 0], X[:, 1], c=predicted_labels, cmap='viridis', marker='o')
plt.show()

# 打印对数似然值和BIC
print(f"对数似然值: {gmm.score(X)}")
print(f"BIC: {gmm.bic(X)}")

🚀 4. 注意事项

  • GMM模型的初始化参数 n_components 需要根据数据集的实际情况进行选择。
  • GMM对初始参数敏感,可能需要多次运行以获得稳定的结果。
  • GMM模型的聚类数 K 选择可能需要依赖领域知识或模型选择准则。

🚀 5. 总结

  GMM是一种强大的聚类工具,适用于数据点来自多个高斯分布的情况。scikit-learn提供了易于使用的GMM实现,允许我们通过指定高斯分布的数量来对数据进行聚类。通过本博客的代码示例,我们学习了如何创建GMM模型,进行训练、聚类标签分配、预测和评估。希望这篇博客能够帮助你更好地理解GMM模型,并将其应用于实际的聚类任务中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/767472.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Unity扩展 Text支持超链接文本

重点提示:当前的文本扩展支持多个超链接,支持修改超链接规则和支持修改超链接颜色。 近期在邮件文本中用到了超链接。最初是在邮件窗口中新加一个按钮用来超链接跳转,之后发现效果表现不如直接在文本中添加,后经过几个小时的资料…

两步解决Hugging Face下载模型速度慢/连接超时/无法下载问题

博主使用的配置是 x86_64 Linux服务器 第一步设置代理镜像: export HF_ENDPOINT=https://hf-mirror.com 第二步(使用代码时,删除引号): --token参数表示下载的模型是否需要登录验证(部分模型需要token) huggingface-cli download --token "获取的hf_*******token…

如何选择优质模型?SD3性能究竟如何?

遇到难题不要怕!厚德提问大佬答! 厚德提问大佬答12 厚德提问大佬答第十二期 你是否对AI绘画感兴趣却无从下手?是否有很多疑问却苦于没有大佬解答带你飞?从此刻开始这些问题都将迎刃而解!你感兴趣的话题,厚德…

vivado联合modelsim仿真

一. 编译Vivado仿真库 打开Vivado,Tools -> Compile Simulation Libraries 二. 设置仿真工具和库路径 因为新建工程的默认仿真工具是Vivado Simulator,所以要使用Modelsim仿真,每个新工程都要设置一次,方法如下: …

2024年【安全生产监管人员】考试题及安全生产监管人员试题及解析

题库来源:安全生产模拟考试一点通公众号小程序 安全生产监管人员考试题根据新安全生产监管人员考试大纲要求,安全生产模拟考试一点通将安全生产监管人员模拟考试试题进行汇编,组成一套安全生产监管人员全真模拟考试试题,学员可通…

现代工作场所中的睡岗检测算法应用

在现代职场环境中,员工的工作状态直接影响到公司的整体效益。睡岗现象,即员工在工作时间内打瞌睡或睡觉,不仅降低了生产力,还可能带来安全隐患。因此,如何有效地检测和预防睡岗行为成为了企业管理中的一个重要课题。随…

真的,今年我劝各位真的别轻易离职

前言 在当前的职业浪潮中,把握时机深入学习大模型技术无疑是明智之举。随着行业对这类专业人才的渴求日益增长,无论是巨头企业还是新兴初创,都在加速推进大模型的应用实践,而真正能将技术有效落地并转化为业务价值的专家却供不应…

详细解释下flutter初始示例的代码

详细解释下flutter初始示例的代码 main 首句导入需要的包 类似于其他语言的import main函数为入口函数 包裹MyApp类 MyApp 这个类继承自无状态类 可见myapp不管理任何状态 build方法是所有widget内必须实现的方法 此处返回一个 ChangeNotferiProvider 可以看到它用于管理应…

理解 REST API 和 GraphQL 的区别

你可能听说过 GraphQL,但对它与 REST 的区别还不完全确定。今天我们将介绍 REST 和 GraphQL 的一些基本原理,以及它们的不同使用场景。 GraphQL 作为 REST API 的替代品越来越受欢迎,不过它不一定是完全的“替代品”。 根据你的使用情景&am…

LLM笔记:训练大模型之并行化

1 数据并行 最常见的并行化手段主要是把数据分成多个块,然后每个节点就可以在本地独立的跑各自的数据任务,最后再和其他节点通信,进而汇总最后的结果好处就是计算效率高,每个节点可以独自计算自己的任务且这种方法易于实现缺点就…

【Python基础篇】一篇文章入门Python,进入Python的世界

文章目录 0.前言1.打印(Hello,World)2.创建变量3.打印升级3.1 打印一句话中间加变量3.2 sep设置分隔符3.3 end和换行 4. 注释 0.前言 大家好,我是小辰,前几天做了个重大的决定,学习python。 首先&#xff0…

博途S7-1500PLC“虚轴“编程应用

1、CODESYS如何添加虚轴 如何添加虚轴(AM400PLC)-CSDN博客文章浏览阅读164次。EtherCAT运动控制总线启用的时候,选择EtherCAT总线任务周期。选择好后,选择点击添加。https://rxxw-control.blog.csdn.net/article/details/139898985虚轴是利用软件算法实现的运动控制轨迹规划…

外挂级OCR神器:免费文档解析、表格识别、手写识别、古籍识别、PDF转Word

智能文档解析:大模型友好的文档解析工具 PDF转Markdown 支持将任意格式的文件(图片、PDF、Doc/Docx、网页等)解析为Markdown或Json格式,以对LLM友好的方式呈现。 更高速度:100页PDF最快1.5s完成解析 更大…

DEBOPIE框架:打造最好的ChatGPT交易机器人

本文介绍了如何利用 DEBOPIE 框架并基于 ChatGPT 创建高效交易机器人,并强调了在使用 AI 辅助交易时需要注意的限制以及操作步骤。原文: Build the Best ChatGPT Trading Bots with my “DEBOPIE” Framework 如今有大量文章介绍如何通过 ChatGPT 帮助决定如何以及在…

Hi3861 OpenHarmony嵌入式应用入门--TCP Server

本篇使用的是lwip编写tcp服务端。需要提前准备好一个PARAM_HOTSPOT_SSID宏定义的热点,并且密码为PARAM_HOTSPOT_PSK LwIP简介 LwIP是什么? A Lightweight TCP/IP stack 一个轻量级的TCP/IP协议栈 详细介绍请参考LwIP项目官网:lwIP - A Li…

6.7、函数的分文件编写

mian函数部分代码 #include <iostream> using namespace std; #include <string> #include "swap.h"//函数的分文件编写 //实现两个数字进行交换的函数//函数的声明 //void swap(int a,int b); //函数的定义 //void swap(int a, int b) //{ // int temp…

9. Revit API UI: UIView、UIDocument、框选聚焦

9. Revit API UI: UIView、UIDocument、框选聚焦 UI命名空间下的API&#xff0c;到这里差不多就要讲完了&#xff0c;同Application那篇所讲的几个类与接口&#xff0c;都是带UI的对应了一个不带UI的&#xff0c;如UIApplication和Application&#xff0c;作用呢&#xff0c;也…

Python基于PyQt5和卷积神经网络分类模型(CNN分类算法)实现时装类别识别系统GUI界面项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 随着深度学习技术的发展&#xff0c;计算机视觉领域取得了显著的进步&#xff0c;特别是在图像分类、目…

Java8环境安装(jdk1.8安装)详细教程

Java 8环境安装&#xff08;jdk1.8安装&#xff09;详细教程 Java 8&#xff08;也称为JDK 1.8&#xff09;&#xff0c;是Oracle公司于2014年3月发布的一个重要的Java语言版本。这个版本自发布以来&#xff0c;因其众多的新特性和改进&#xff0c;被认为是Java语言发展历程中…

渗透测试之注入

命令注入 命令注入相关分隔符&#xff1a; 字符说明;仅限Linux环境&#xff0c;用于隔开命令&#xff0c;按顺序执行|前面命令的输出结果作为后面命令的输入内容||前提是前面的命令执行失败&#xff0c;和&&号相反&前后两条命令依次执行&&前提是前面的命…