刘玉琪
跟随
台湾人工智能学院
一、说明
上一篇介绍了基于密度的分群方法——DBSCAN,本篇会介绍另一个分群方法——Mean Shift,与DBSCAN一样不需要预先知道欲分群的数量,而对于分群的形状也没有限制。
然而,这个方法是基于核密度估计(kernel Density Estimation)的演算。可以想象数据是从同一个机率分布的数据集抽取的,而KDE(Kernel Density Estimation)的方法就是去估计数据的分布情况。 Mean Shift算法在许多领域都有成功的应用,例如图像分割、物体追踪等。下面将详细介绍该方法的基本概念、演算法、以及算法实作。
二、基本概念
Mean Shift主要的思想是假设数据集的密度以多个合成的核函数分布,然后随核密度分布,而数据集的所有点只要沿着密度对应的方向移动,直到位于最近最大密度的位置,意即计算密度估计曲线的最大值,便能将数据分群。
- 核密度估计(kernel Densityestimation)
利用核函数(kernel)来得出数据点x_1, x_2, … , x_n的分布来稀疏密度的分布曲线(机率分布),所以对一个数据点x来说,机率的估计可以写成
K为核函数(核函数),d为维度,h为带宽(带宽)。不同的h对核密度估计有很大的影响。太小的h会使得KDE的峰值为数据集的所有点(自成一类);手工则可以缩短一个(休闲一类)。
左图为数据集;右图为验证密度估计曲线带宽为2
左图的带宽为0.05,右图的带宽为5
三、关于核函数(kernel function)
核函数一般以零为中心点的函数,表示为
c_{k, d}是正规化参数,使得函数的积分值为1
最常见的是高斯函数,定义为
常用的核函数(kernel function)。来源:https://en.wikipedia.org/wiki/Kernel_(statistics)
Mean Shift 算法会沿着 KDE 的轻微方向寻找机率上升,因此考虑
令g(s) = -k'(s),则
前一项为核函数,后一项则为均值平移向量
利用迭代的方式更新中心点:
- 计算当前的均值平移向量,m_h(center_old)
- 中心点沿平均偏移量移动做为新的中心点,意即center_new = center_old +mean shift。
直至收敛以找到准确估计收敛的位置。
四、演算
输入:资料集D,以及带宽bandwidth
输出:目标分群集合Clusters
- 从附带分群的数据点中选择一个起始点做为中心。
2.将距离中心点小于带宽的数据点分为同群,记为集合M。
红色点为集合M里的元素
3. 计算从中心点到集合M中每个元素的计算,并做计算平均相加得到平均偏移计算均值 平移向量。
橘色向量即为均值平移向量
4. 中心点沿线平均偏移允许移动做为新的中心点,意即center = center +mean shift。
橘色点为新的中心点(会往KDE的顶部方向移动)
5. 重复步骤2、3、4,直到中心点不再动趋势(否则找到局部极大值)。若该群的中心点已归于先前所分的群中,则将两个群合并为同一群。
6. 重复以上步骤直至所有点均已完成财务状况。
五、算法实操代码
使用Sklearn.cluster.MeanShift套件:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn import datasets
#create datasets
X,y = datasets.make_blobs(n_samples=50, centers=3, n_features=2, random_state= 20, cluster_std = 1.5)
#estimate bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=1000)
#Mean Shift method
model = MeanShift(bandwidth = bandwidth, bin_seeding = True)
model.fit(X)
labels = model.fit_predict(X)
#results visualization
plt.figure()
plt.scatter(X[:,0], X[:,1], c = labels)
plt.axis('equal')
plt.title('Prediction')
plt.show()
右图预测的结果概率会有所不同,由此估计带宽为 2.92
用于影像分割 …
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import rescale
from sklearn.cluster import MeanShift, estimate_bandwidth
import cv2
#load image
img = cv2.imread('AIA.png')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = rescale(img, 0.2)
rows, cols, chs= img.shape
#convert image shape [rows, cols, 3] into [rows*cols, 3]
feature_img = np.reshape(img, [-1, 3])
#estimate bandwidth
bandwidth = estimate_bandwidth(feature_img, quantile=0.2, n_samples=1000)
#Mean Shift method
model = MeanShift(bandwidth = bandwidth, bin_seeding = True)
model.fit(feature_img)
labels = model.fit_predict(feature_img)
#results visualization
fig = plt.figure(figsize = (20, 12))
ax = fig.add_subplot(121)
ax1 = fig.add_subplot(122)
ax.imshow(img)
ax1.imshow(np.reshape(labels, [rows, cols]))
plt.show()