本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在AI学习笔记:
AI学习笔记(12)---《H-K算法介绍与MATLAB实现》
H-K算法介绍与MATLAB实现
目录
1. 前言
2.相关知识
2.1H-K 算法分类模型
2.2样本的规范化
2.3最小二乘准则函数
2.4伪逆
3.H-K 算法原理
4.编程实现
4.1伪逆的计算MATLAB代码
4.2 H-K 算法求解MATLAB代码
4.3 测试输入
1. 前言
理解并掌握 H-K 算法的基本原理,并针对 H-K算法实现需要掌握的前置知识进行练习实现。
掌握 H-K 算法的求解步骤,并编程实现一个使用 H-K 算法求解分类模型的例子。
2.相关知识
2.1H-K 算法分类模型
H-K 算法与单层感知器算法具有一定的相似性,都属于线性分类的几何分类方法,即要求输入的样本集本身是线性可分的。但不同的是,H-K 算法对于线性不可分的样本集可以进行判别,这一点是单层感知器所不具备的。
H-K 算法的出发点是以均方误差为准则函数而建立起来的。H-K 算法的分类模型与单层感知器一样,引入权值向量 w∈Rn
来作为分类模型, H-K 算法学习的过程就是求解 w∈Rn
,如果样本集 x∈X
线性可分为两个集合X+,X−
,那么应该存在一个权值向量 w
使得
其中,我们将 y
称为线性可分类的标签值。
2.2样本的规范化
我们将标签值 y
值为负的样本称为负样本,如果我们对样本集进行这样的操作:
那么分类问题可以写为如下形式:
wx′>0
此时,由x′
构成的经过以上处理过的样本集称为规范化的样本向量。此时,分类问题变成了求解不等式 wx′>0
的问题。
2.3最小二乘准则函数
求解不等式 wx′>0
的问题可以进一步转化为 wx′=b(b>0)
的问题,同时对所有的样本进行联立,得到方程组 wX=B
,其中向量 B
的每个分量 bi>0
。针对这样一个有多个可能解的方程组,我们可以求它的最小二乘解。即使得
取极小值(上式即为 H-K 算法的最小二乘准则函数)。此时可以使用梯度下降法或者伪逆法进行求解。为了下一关更好的介绍 H-K 算法,这里你需要了解伪逆的求法。
2.4伪逆
伪逆(pseudo inverse)是对于矩阵逆的一种推广,满足一定性质的矩阵都可以成为矩阵A的伪逆。伪逆具有存在且唯一的良好性质,可以用于解决最小二乘和最小范数问题。矩阵 X
的伪逆计算方式如下:
3.H-K 算法原理
通过对分类问题的一步步抽象,H-K 算法求解分类问题变成了一个线性方程组求解的问题,即:
其中 w
是要求解的参数,X
是所有样本规范化之后得到矩阵,B
是所有分量大于 0 的向量。
如果对于任意 B
这个方程组都无解,那么说明样本集线性不可分。
以上线性方程式可以通过最小二乘变成一个优化问题,即找到 w
使得
尽可能的小。用 J
对 w
求偏导,
其中,X#=(XTX)−1XT
是 X
的伪逆。
因此,求 w
就相当于求满足前述条件的 B
。
H-K 算法步骤
H-K 算法步骤也就是通过迭代一步一步的逼近出理想的 B
。
- 步骤 1:随机初始化
B
,其每个分量均为正值; - 步骤 2:计算
w=X#B
; - 步骤 3:计算
e=Xw−b
; - 步骤 4:根据
e
的取值情况作进一步处理:- 若分量全部大于等于 0,算法结束,求解完成;
- 若分量有正有负,则跳到步骤 5 继续迭代;
- 若分量全部小于 0,停止迭代,样本线性不可分。
- 步骤 5:完成更新计算 并跳转到步骤 2; 其中
∣e∣
是向量的模运算,即二范数。
4.编程实现
4.1伪逆的计算MATLAB代码
function [result] = step1_mission(samples, labels)
for i = 1:length(labels) % 遍历每一个标签
if labels(i) < 0 % 如果标签小于0
samples(i,:) = -samples(i,:); % 对应样本取反
end
end
result = inv(samples.' * samples) * samples.'; % 计算伪逆矩阵
end
4.2 H-K 算法求解MATLAB代码
function [w, b, flag] = step2_mission(samples, labels)
samples_X = diag(labels) * samples; % 将样本矩阵与标签对角矩阵相乘,生成样本变换矩阵 samples_X
X = pinv(samples_X); % 计算样本变换矩阵的伪逆矩阵 X
[~, num] = size(labels); % 获取标签的数量,即样本数量 num
b = ones(num, 1); % 初始化长度为 num 的全 1 向量 b
while(1 == 1) % 无限循环,直到满足终止条件
w = X * b; % 计算权重向量 w
e = samples_X * w - b; % 计算误差向量 e
flagt = 0; % 初始化正误差计数器
flagf = 0; % 初始化负误差计数器
for i = 1:num % 遍历每一个样本
if (e(i) > -0.01) % 若误差大于 -0.01,正误差计数器加一
flagt = flagt + 1;
end
if (e(i) < -0.01) % 若误差小于 -0.01,负误差计数器加一
flagf = flagf + 1;
end
end
if (flagt == 0) % 若没有正误差,设置标志为 -1 并跳出循环
flag = -1;
break
end
if (flagf == 0) % 若没有负误差,设置标志为 1 并跳出循环
flag = 1;
break
end
b = b + 0.1 * (e + norm(e)); % 更新向量 b,步长为 0.1
end
end
4.3 测试输入
MATLAB终端输入下面指令
测试step1_mission(samples, labels)
samples = [-7.82 -4.58 -3.97; -6.68 3.16 2.71; 4.36 -2.91 2.09; 6.72 0.88 2.80; -8.64 3.06 3.50; -6.87 0.57 -5.45; 4.47 -2.62 5.76; 6.73 -2.01 4.18; -7.71 2.34 -6.33; -6.91 -0.49 -5.68; 6.18 2.81 5.82; 6.72 -0.93 -4.04; -6.25 -0.26 0.56; -6.94 -1.22 1.13; 8.09 0.20 2.25; 6.81 0.17 -4.15; -5.19 4.24 4.04; -6.38 -1.74 1.43; 4.08 1.30 5.33; 6.27 0.93 -2.78];
labels = [-1.0000 1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 1.0000 -1.0000 1.0000 1.0000];
s1 = samples(1:6,1:end);
l1 = labels(1:6);
s2 = samples(1:10,1:end);
l2 = labels(1:10);
s3 = samples(10:end,1:end);
l3 = labels(10:end);
s4 = samples;
l4 = labels;
disp('task1');
r1 = step1_mission(s1, l1)
disp('task2');
r1 = step1_mission(s2, l2)
disp('task3');
r1 = step1_mission(s3, l3)
disp('task4');
r1 = step1_mission(s4, l4)
测试step2_mission(samples, labels)
samples = [-7.82 -4.58 -3.97; -6.68 3.16 2.71; 4.36 -2.91 2.09; 6.72 0.88 2.80; -8.64 3.06 3.50; -6.87 0.57 -5.45; 4.47 -2.62 5.76; 6.73 -2.01 4.18; -7.71 2.34 -6.33; -6.91 -0.49 -5.68; 6.18 2.81 5.82; 6.72 -0.93 -4.04; -6.25 -0.26 0.56; -6.94 -1.22 1.13; 8.09 0.20 2.25; 6.81 0.17 -4.15; -5.19 4.24 4.04; -6.38 -1.74 1.43; 4.08 1.30 5.33; 6.27 0.93 -2.78];
labels = [-1.0000 1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 -1.0000 -1.0000 1.0000 1.0000 1.0000 -1.0000 1.0000 1.0000];
w1 = [1.0 1.0 1.0];
w2 = [-1.0 1.0 -1.0];
w3 = [-1.0 0.0 2.0];
w4 = [1.0 1.0 -1.0];
b1 = 1.0;
b2 = 2.0;
b3 = 3.0;
b4 = 4.0;
eta1 = 0.1;
eta2 = 0.2;
eta3 = 0.3;
eta4 = 0.5;
s1 = samples(5:15,1:end);
l1 = labels(5:15);
s2 = samples(1:10,1:end);
l2 = labels(1:10);
s3 = samples(10:end,1:end);
l3 = labels(10:end);
disp('task1');
[w,b,f] = step2_mission(s1, l1)
disp('task2');
[w,b,f] = step2_mission(s2, l2)
disp('task3');
[w,b,f] = step2_mission(s3, l3)
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者私信联系作者。