机器学习——感知机模型

机器学习系列文章

入门必读:机器学习介绍


文章目录

  • 机器学习系列文章
  • 前言
  • 1. 感知机
    • 1.1 感知机定义
    • 1.2 感知机学习策略
  • 2. 代码实现
    • 2.1 构建数据
    • 2.2 编写函数
    • 2.3 迭代
  • 3. 总结


前言

大家好,大家好✨,这里是bio🦖。这次为大家带来的是感知机模型。下面跟我一起来了解感知机模型吧!

感知机 (Perceptron) 是二类分类的线性分类模型 ,其输入为实例的特征向量 ,输出为实例的类别 ,分别为 +1-1。1957年,由康奈尔航空实验室(Cornell Aeronautical Laboratory)弗兰克·罗森布拉特 (Frank Rosenblatt)提出。它可以被视为一种最简单形式的前馈神经网络,是一种二元线性分类器。在人工神经网络领域中,感知机也被指为单层的人工神经网络,以区别于较复杂的多层感知机。在这里插入图片描述


1. 感知机

1.1 感知机定义

感知器使用特征向量来表示二元分类器,把矩阵上的输入 x \mathcal{x} x(实数值向量)映射到输出值 y \mathcal{y} y 上(一个二元的值)。
f ( x ) = { + 1 ,    i f   w ⋅ x + b > 0 − 1 ,    e l s e f(x) = \begin{cases} +1,\,\, if\,w\cdot x+b>0\\ -1,\,\,else\\ \end{cases} f(x)={+1,ifwx+b>01,else

w \mathcal{w} w 是实数的表示权重的向量, w ⋅ x \mathcal{w} \cdot \mathcal{x} wx 是点积。 b \mathcal{b} b 是偏置,一个不依赖于任何输入值的常数。


1.2 感知机学习策略

假设训练数据集是线性可分的 ,如下图所示。感知机学习的目标是求得一个能够将训练集正实例点和负实例点完全正确分开的直线 L \mathcal{L} L。 为了找出这样的超平而 , 即确定感知机模型参数 w \mathcal{w} w b \mathcal{b} b ,需要确定一个学习策略 , 即定义损失函数并将损失函数极小化 。

损失函数的一个选择是误分类数据点的数量 。 但是这样的损失函数不是参数 w \mathcal{w} w b \mathcal{b} b 的连续可导函数,不易优化 。 损失函数的另一个选择是误分类数据点到直线 L \mathcal{L} L 的总距离。感知机所采用的就是后者 。
在这里插入图片描述

  • 对于错误分类的数据点 ( x i , y i ) (\mathcal{x_i, y_i}) (xi,yi),总有:
    − y i ⋅ ( w ⋅ x i + b ) > 0 \mathcal{-y_i\cdot(w\cdot x_i+b) > 0} yi(wxi+b)>0
  • 错误分类点到直线 L \mathcal{L} L 的距离为:
    1 ∥ w ∥ ∣ w ⋅ x i + b ∣ \mathcal{\frac{1}{\|w\|} \vert w \cdot x_i +b \vert} w1wxi+b
  • 假设直线 L \mathcal{L} L 的误分类点集合为 m \mathcal{m} m , 那么所有误分类点到直线 L \mathcal{L} L 的总距离为:
    − 1 ∥ w ∥ ∑ i m y i ⋅ ( w ⋅ x i + b ) \mathcal{-\frac{1}{\|w\|} \sum_{i}^{m}y_i\cdot(w \cdot x_i +b)} w1imyi(wxi+b)
  • 不考虑 1 ∥ w ∥ \mathcal{\frac{1}{\|w\|}} w1,感知机的损失函数为:
    K ( w , b ) = − ∑ i m y i ⋅ ( w ⋅ x i + b ) \mathcal{K(w, b)= - \sum_{i}^{m}y_i\cdot(w \cdot x_i +b)} K(w,b)=imyi(wxi+b)

显然,损失函数 K \mathcal{K} K 是非负的。如果没有误分类点,损失函数值是 0 。而且,误分类点越少,误分类点离超平面越近,损失函数值就越小 。

而感知机的优化算法采用的是随机梯度下降算法 (Stochastic Gradient Descent)(后续更新),用误分类数据驱动损失函数 K \mathcal{K} K 不断减小。本文将采取二维数据,来展示感知机的工作过程。


2. 代码实现

2.1 构建数据

首先创建二维数据,并用线性回归模型拟合出直线 L \mathcal{L} L。代码如下:

import numpy as np
from sklearn.datasets import make_classification
from sklearn import linear_model
import matplotlib.pyplot as plt
import random

# two-dimention data
td_data = make_classification(n_samples=20, n_features=2, n_informative=2, n_redundant=0, n_clusters_per_class=1, random_state=24)

td_data = list(td_data)
td_data[1] = np.array([1 if i == 0 else -1 for i in td_data[1]])
td_data = tuple(td_data)

# visualized data
fig, ax = plt.subplots()
scatter = ax.scatter(td_data[0][:, 0], td_data[0][:, 1], c=td_data[1], alpha=0.6, cmap="cool")
legend_1 = ax.legend(*scatter.legend_elements(), title="Classes", loc="upper left")
ax.add_artist(legend_1)
ax.set_xlabel("Feature_1")
ax.set_ylabel("Feature_2")

# add minimal residual sum of squares line as gold standard
reg = linear_model.LinearRegression()

# reshape for model fitting
reg.fit(td_data[0][:, 0].reshape(-1, 1), td_data[0][:, 1].reshape(-1, 1))
print(f"the intercept is {reg.intercept_[0]} and the coefficient is {reg.coef_[0][0]}")
formula = f"f(x)={round(reg.coef_[0][0], 2)}*x1-x2{round(reg.intercept_[0], 2)}"

# create a x axis for plotting
create_x_axis = np.linspace(min(td_data[0][:, 0]), max(td_data[0][:, 0]), 100).reshape(-1, 1)
predicted_value = reg.predict(create_x_axis)

ax.plot(create_x_axis, predicted_value, c="gold", alpha=0.8, label=formula)
handles, labels = ax.get_legend_handles_labels()
legend_2 = ax.legend(handles, labels, loc="lower right")

plt.show()

根据代码输出的结果可知,由线性回归模型拟合出的直线 L = 0.53 x 1 + x 2 − 0.6 \mathcal{L = 0.53x_1+x_2-0.6} L=0.53x1+x20.6
在这里插入图片描述

2.2 编写函数

接下来编写可复用的函数,减少代码编写量。partial_derivative_w 函数用于对变量 w \mathcal{w} w 求偏导,partial_derivative_b 函数用于对变量 b \mathcal{b} b 求偏导,decision_funtion 函数用于决策是否继续进行迭代,plot_function 函数绘制迭代结果图。

# take the partial derivative of w and b
def partial_derivative_w(data_point, label_point):
    # feature_1 * feature_2 * y
    result_w_1 = data_point[0] * label_point
    result_w_2 = data_point[1] * label_point
    return [result_w_1, result_w_2]

def partial_derivative_b(label_point):
    # label
    result_b = label_point
    return result_b

# decision function. w and b will be change if exist data point make 
def decision_funtion(weight_1, weigh_2, intercept):
    # if y*(w*x+b) < 0, the data point is wrongly classified.
    result = td_data[1] * ((td_data[0][:, 0] * weight_1) + (td_data[0][:, 1] * weight_2) + intercept)

    if len(result[np.where(result < 0)]) != 0:
        print(result)
        wrong_dp_index = np.where(result == result[np.where(result < 0)][0])[0][0]
        wrong_dp = td_data[0][wrong_dp_index]
        wrong_lb = td_data[1][wrong_dp_index]
        return [True, wrong_dp, wrong_lb]
    else:
        print("interation end")
        return [False, None, None]

def plot_function(weight_1, weight_2, intercept):
    fig, ax = plt.subplots()
    scatter = ax.scatter(td_data[0][:, 0], td_data[0][:, 1], c=td_data[1], alpha=0.6)
    ax.legend(*scatter.legend_elements(), title="Classes")
    ax.set_xlabel("Feature_1")
    ax.set_ylabel("Feature_2")
    b = intercept/weight_2
    hyperplane = [(-(weight_1/weight_2) * i) - b for i in create_x_axis]
    ax.plot(create_x_axis, hyperplane, c='green', alpha=0.5)
    plt.show()

2.3 迭代

设置特征一的初始权重为 0,特征二的初始权重为 0,初始截距为 0,学习率为 0.1,迭代次数为1000次,随机从数据中选择一个数据点作为分类错误数据点后开始迭代。

# initiate weight, intercept and learning rate
weight_1 = 0
weight_2 = 0
intercept = 0
learn_rate = 0.1


# iteration times
iteration_times = 1000

# random value in two dimention data
random_index = random.randint(0, 19)
feature_point = td_data[0][random_index]
label_point = td_data[1][random_index]
# it is not correctly classified for any data point resulting in loss function equte 0.

for iteration in range(iteration_times):
    # w1 = w0 + (learn_rate * y * x)
    new_weight_1 = weight_1 + (learn_rate * partial_derivative_w(feature_point, label_point)[0])
    new_weight_2 = weight_2 + (learn_rate * partial_derivative_w(feature_point, label_point)[1])
    # b1 = b0 + learn_rate * y
    new_intercept = intercept + (learn_rate * partial_derivative_b(label_point))
    
    # decision
    decision_condition, wrong_dp, wrong_lp = decision_funtion(new_weight_1, new_weight_2, new_intercept)
    
    if decision_condition:
        weight_1 = new_weight_1
        weight_2 = new_weight_2
        intercept = new_intercept
        # wrong data point
        feature_point = wrong_dp
        label_point = wrong_lp
        print(f"The {iteration + 1} iteration\tweight_1={weight_1}\tweight_2={weight_2}\tintercept={intercept}\n")
        plot_function(weight_1, weight_2, intercept)
    else:
        print(f"The {iteration + 1} iteration\tweight_1={new_weight_1}\tweight_2={new_weight_2}\tintercept={new_intercept}\n")
        plot_function(new_weight_1, new_weight_2, new_intercept)
        break

迭代结果如下表所示,在迭代到第八次的时候,感知机模型成功将所有数据点正确分类。

迭代次数效果图片
1在这里插入图片描述
2在这里插入图片描述
3在这里插入图片描述
4在这里插入图片描述
5在这里插入图片描述
6在这里插入图片描述
7在这里插入图片描述
8在这里插入图片描述

3. 总结

以上就是本次更新的全部内容。关于本次内容有一下缺点:

  • 用于迭代的错误分类数据点没有被绘制出来。
  • 由于跳过了大部分数学知识,内容衔接没有做好。
  • 迭代数据点完全随机,复现过程可能不同。

后续将会更新:

  • 感知机模型的数学解释
  • 随机梯度算法的解释
  • 可视化迭代过程的错误分类数据点(可能)

喜欢本次内容的小伙伴麻烦👍点赞+👍关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/438250.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

lvs+keepalive

虚拟路由冗余协议(Virtual Router Redundancy Protocol&#xff0c;简称VRRP) VRRP能够在不改变组网的情况下&#xff0c;将多台路由器虚拟成一个虚拟路由器&#xff0c;通过配置虚拟路由器的IP地址为默认网关&#xff0c;实现网关的备份。 协议版本: VRRPv2&#xff08;常用&…

日常生活小技巧 -- USR-TCP232-M4(读取IP)

下载&#xff1a;[Configuration Software]USR-TCP232-M4_V2.3.4.106

Vue2+3

vue相关介绍 Vue的两种使用方式&#xff1a; 1、vue核心包开发 场景&#xff1a;局部模块改造 2、vue核心包&vue插件工程化开发 场景&#xff1a;整站开发 概念&#xff1a;vue是用于构建用户界面的渐进式框架 创建vue实例 创建Vue实例&#xff0c;初始化渲染步骤&am…

Unity类银河恶魔城学习记录8-4 P80 Blackhole ability state源代码

Alex教程每一P的教程原代码加上我自己的理解初步理解写的注释&#xff0c;可供学习Alex教程的人参考 此代码仅为较上一P有所改变的代码 【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili Entity.cs using System.Collections; using System.Collections.Generic…

Windows系统获取硬盘读写速度(固态、机械通用)

管理员身份运行cmd&#xff0c;注意不要用powershell&#xff0c;可能会闪退 运行命令 winsat disk -drive C 获得C盘的读写速度 同理 winsat disk -drive D 可获得D盘的读写性能

【产品文档分类及撰写路径】

一、产品文档的分类 产品文档根据所处阶段和面相对象的差异大致可以分为三类&#xff1a; 商业需求文档 (BRD)&#xff1a;商业需求文档是面向公司高层和项目组&#xff0c;目的是为了获得资金、资源支持。市场需求文档 (MRD)&#xff1a;市场需求文档是面向运营和市场销售人员…

Python读取.csv/.txt文件并画图

文章目录 读取.txt文件并画出每列数据读取.csv文件并画出每列数据python绘图的基本知识补充 读取.txt文件并画出每列数据 main.ipynb //注意文件类型为.ipynbimport matplotlib.pyplot as plt import numpy as npdata np.loadtxt(result.txt) fig plt.figure(dpi100,figsize…

哪个牌子宠物空气净化器好?质量好的宠物空气净化器推荐

即使我们很爱自家的宠物&#xff0c;但我们也无法否认处理房间里飘荡的宠物毛发和皮屑&#xff0c;以及那些令人不快的气味&#xff08;比如地毯上的意外和垃圾桶里的气味&#xff09;的挑战。对于过敏患者来说&#xff0c;这几乎是无法忍受的。寻找有效的方法来减少这些问题对…

四个有用的Android开发技巧

一. 通过堆栈快速定位系统版本 这个地方主要分享大家两个个技巧&#xff0c;通过问题堆栈简快速定位当前系统版本&#xff1a; 1. 快速区分当前系统版本是Android10以下&#xff0c;还是Android10及以上&#xff1b; 首先Android10及以上引入了一个新的服务Service&#xff…

程序异常结束退出 无输出 无显式报错日志 爆栈

需求 开一个很大的数组&#xff08;300万&#xff09; ❗ 错误示例 #include <stdio.h>int main() {int size 3000000;int a[size];a[size-1] 999;printf("%d",a[size-1]);return 0; }&#x1f60b; 解决方案 局部变量存储在栈空间 &#xff08;较小&…

【数据结构】哈希

在一个数据序列中查找某一个数据元素&#xff0c;是数据管理时经常涉及的&#xff0c;通常以比较的方式来完成&#xff0c;典型的案例有无序序列的暴力查找&#xff08;O(N)&#xff09;、有序序列的二分查找&#xff08;O(logN)&#xff09;、平衡搜索树&#xff08;O(logN)&a…

CentOS网络故障排查秘笈:实战指南

前言 作为一名热爱折腾 Linux 的技术达人&#xff0c;我深知网络故障会让人抓狂&#xff01;在这篇文章里&#xff0c;我和你分享了我的心得体会&#xff0c;从如何分析问题、识别瓶颈&#xff0c;到利用各种神器解决网络难题。不管你是新手小白还是老鸟大神&#xff0c;这里都…

Get了!原来朋友圈定时发布如此简单!

你是不是也有这样的烦恼&#xff1f;微信号太多&#xff0c;有时候会顾不上发朋友圈&#xff1f; 别担心&#xff01;微信管理系统来帮你解决这个问题&#xff0c;实现朋友圈定时发布&#xff01;让我们一起来看看如何利用微信管理系统的强大功能来提升我们的社交媒体效率吧。…

Material Studio 中 DMol3 计算材料吸附能

1.先导入Cif文件 2.切表面 3.沿着你要切的晶面切 4.扩胞 5.加真空层&#xff08;一般加10埃&#xff09; 现在就是这样的了 6.然后对其结构优化&#xff08;高斯几何优化&#xff09; 7.再在体系上加原子或者想要的材料 8.Outmal文件中最后的Ef就是整个体系的能量&#xff0…

WPF中如何设置自定义控件(四)

在上一节讲到过将自定义界面进行左右分割,左边作为登录的用户信息和菜单&#xff1b;右边作为详情展示界面。 这节准备在上一节的基础上对左侧部分进行细分&#xff0c;完善功能&#xff0c;左侧的内容准备做两部分填充&#xff1a;一是登录用户的信息&#xff0c;一是菜单。因…

牛客网 华为机试 坐标移动

本题是需要将输入的字符串&#xff0c;得到移动位置的信息&#xff0c;同时要判断移动信息的合法性。 所以我们可以考虑先通过正则表达式过滤得到正确的字符串。 正确的字符串应该以ADWS其中一个字母开头&#xff0c;然后后面接着1个或者2个&#xff08;0-9&#xff09;的数字。…

护眼台灯哪个品牌最好?2024五款主流台灯实测推荐!

对很多家长来说&#xff0c;孩子的健康比什么都重要。不过现在的儿童青少年近视率却非常高&#xff0c;正因为如此护眼台灯就一直是家长十分关注的灯具。可如今市场中却存在很多劣质产品&#xff0c;忽视产品做工以及选材用料等问题&#xff0c;导致照明体验感差、使用不方便&a…

大华IPC网络摄像机如何保存视频

一、背景 通常网络相机&#xff08;IPC&#xff09;不会自带存储功能&#xff0c;需要接入录像机&#xff08;NVR&#xff09;进行保存。 其中NVR也分软件存储及硬件存储&#xff0c;这里不提&#xff0c;这边单独说FTP存储 二、配置前提 要配置FTP存储需要&#xff1a;①网络…

java017 - Java抽象类

1、概述 一般情况&#xff0c;动物是抽象的&#xff0c;所以不能被new,比如你在Animal类中定义一个成员方法eat,你不能定义具体内容&#xff0c;比如吃鱼或者吃白菜&#xff0c;因为动物是抽象的。 一个没有方法体的方法&#xff0c;应该定义为抽象方法&#xff0c;而类中如果…

PaddleSeg分割框架解读[01] readme解读

简介 PaddleSeg是基于飞桨PaddlePaddle的端到端图像分割套件,内置45+模型算法及140+预训练模型,支持配置化驱动和API调用开发方式,打通数据标注、模型开发、训练、压缩、部署的全流程,提供语义分割、交互式分割、Matting、全景分割四大分割能力,助力算法在医疗、工业、遥…