鲁棒线性模型估计(Robust linear model estimation)

鲁棒线性模型估计

  • 1.RANSAC算法
    • 1.1 算法的基本原理
    • 1.2 迭代次数N的计算
    • 1.3 参考代码
  • 参考文献

当数据中出现较多异常点时,常用的线性回归OLS会因为这些异常点的存在无法正确估计线性模型的参数:
W = ( X T X ) − 1 X T Y \qquad \qquad W=(X^TX)^{-1}X^TY W=(XTX)1XTY

此时就需要寻找更鲁棒的方法过滤掉异常点,以获得更准确的模型预测参数。

1.RANSAC算法

1.1 算法的基本原理

RANSAC(random sample consensus,随机抽样一致算法)基本原理如下:
———————————————————————————————————————
变量
N \quad N N:迭代次数;
s \quad s s:每次迭代过程,用于估计模型参数最小的样本点数量;
τ \quad \tau τ:每次迭代过程,用于辨别正常点和异常点的阈值;
d \quad d d:每次迭代过程,对辨别到的正常点进行线性拟合所需的最小个数;
b e s t E r r \quad bestErr bestErr:最佳误差,用于判断是否更新模型参数,初始值可设为无穷大值。
b e s t M o d e l \quad bestModel bestModel:最佳模型。

算法

进行N次迭代:

  • 随机抽取 s s s个样本数据,对这 s s s个数据进行线性拟合,得到初始模型 m o d e l 1 model_1 model1
  • 使用 m o d e l 1 model_1 model1,对 s s s个数据中各数据的自变量进行预测,获得对应的 y ^ \hat{y} y^,计算误差值 ( y − y ^ ) 2 (y-\hat{y})^2 (yy^)2。如果误差值小于 τ \tau τ,则为正常点,否则为异常点;
  • 若正常点个数小于 d d d,直接进入下一次循环;
  • 使用正常点进行线性拟合,获取模型 m o d e l 2 model_2 model2
  • 对判断出的正常点求均方差 E r r = 1 n ∑ ( y − y ^ ) 2 Err=\cfrac{1}{n}\sum(y-\hat{y})^2 Err=n1(yy^)2 n n n为正常点的个数;
  • 如果 E r r < b e s t E r r Err<bestErr Err<bestErr,那么设置 b e s t E r r = E r r bestErr=Err bestErr=Err,同时设置 b e s t M o d e l = m o d e l 2 bestModel=model_2 bestModel=model2
    ————————————————————————————————————————

1.2 迭代次数N的计算

迭代次数 N N N的计算满足下式:

1 − ( 1 − ( 1 − e ) s ) N = p \qquad 1-(1-(1-e)^s)^N=p 1(1(1e)s)N=p

其中:
e \qquad e e为数据中异常点的概率;
p \qquad p p N N N次迭代过程中,至少存在一次采样数据全为正常点的期望概率,如设置为0.99。

因此有:
N = log ⁡ ( 1 − p ) log ⁡ ( 1 − ( 1 − e ) s ) \qquad N=\cfrac{\log(1-p)}{\log(1-(1-e)^s)} N=log(1(1e)s)log(1p)

1.3 参考代码

代码来源于https://en.wikipedia.org/wiki/Random_sample_consensus:

from copy import copy
import numpy as np
from numpy.random import default_rng
rng = default_rng()


class RANSAC:
    def __init__(self, n=10, k=100, t=0.05, d=10, model=None, loss=None, metric=None):
        self.n = n              # `n`: Minimum number of data points to estimate parameters
        self.k = k              # `k`: Maximum iterations allowed
        self.t = t              # `t`: Threshold value to determine if points are fit well
        self.d = d              # `d`: Number of close data points required to assert model fits well
        self.model = model      # `model`: class implementing `fit` and `predict`
        self.loss = loss        # `loss`: function of `y_true` and `y_pred` that returns a vector
        self.metric = metric    # `metric`: function of `y_true` and `y_pred` and returns a float
        self.best_fit = None
        self.best_error = np.inf

    def fit(self, X, y):
        for _ in range(self.k):
            ids = rng.permutation(X.shape[0])

            maybe_inliers = ids[: self.n]
            maybe_model = copy(self.model).fit(X[maybe_inliers], y[maybe_inliers])

            thresholded = (
                self.loss(y[ids][self.n :], maybe_model.predict(X[ids][self.n :]))
                < self.t
            )

            inlier_ids = ids[self.n :][np.flatnonzero(thresholded).flatten()]

            if inlier_ids.size > self.d:
                inlier_points = np.hstack([maybe_inliers, inlier_ids])
                better_model = copy(self.model).fit(X[inlier_points], y[inlier_points])

                this_error = self.metric(
                    y[inlier_points], better_model.predict(X[inlier_points])
                )

                if this_error < self.best_error:
                    self.best_error = this_error
                    self.best_fit = better_model

        return self

    def predict(self, X):
        return self.best_fit.predict(X)

def square_error_loss(y_true, y_pred):
    return (y_true - y_pred) ** 2


def mean_square_error(y_true, y_pred):
    return np.sum(square_error_loss(y_true, y_pred)) / y_true.shape[0]


class LinearRegressor:
    def __init__(self):
        self.params = None

    def fit(self, X: np.ndarray, y: np.ndarray):
        r, _ = X.shape
        X = np.hstack([np.ones((r, 1)), X])
        self.params = np.linalg.inv(X.T @ X) @ X.T @ y
        return self

    def predict(self, X: np.ndarray):
        r, _ = X.shape
        X = np.hstack([np.ones((r, 1)), X])
        return X @ self.params


if __name__ == "__main__":

    regressor = RANSAC(model=LinearRegressor(), loss=square_error_loss, metric=mean_square_error)

    X = np.array([-0.848,-0.800,-0.704,-0.632,-0.488,-0.472,-0.368,-0.336,-0.280,-0.200,-0.00800,-0.0840,0.0240,0.100,0.124,0.148,0.232,0.236,0.324,0.356,0.368,0.440,0.512,0.548,0.660,0.640,0.712,0.752,0.776,0.880,0.920,0.944,-0.108,-0.168,-0.720,-0.784,-0.224,-0.604,-0.740,-0.0440,0.388,-0.0200,0.752,0.416,-0.0800,-0.348,0.988,0.776,0.680,0.880,-0.816,-0.424,-0.932,0.272,-0.556,-0.568,-0.600,-0.716,-0.796,-0.880,-0.972,-0.916,0.816,0.892,0.956,0.980,0.988,0.992,0.00400]).reshape(-1,1)
    y = np.array([-0.917,-0.833,-0.801,-0.665,-0.605,-0.545,-0.509,-0.433,-0.397,-0.281,-0.205,-0.169,-0.0531,-0.0651,0.0349,0.0829,0.0589,0.175,0.179,0.191,0.259,0.287,0.359,0.395,0.483,0.539,0.543,0.603,0.667,0.679,0.751,0.803,-0.265,-0.341,0.111,-0.113,0.547,0.791,0.551,0.347,0.975,0.943,-0.249,-0.769,-0.625,-0.861,-0.749,-0.945,-0.493,0.163,-0.469,0.0669,0.891,0.623,-0.609,-0.677,-0.721,-0.745,-0.885,-0.897,-0.969,-0.949,0.707,0.783,0.859,0.979,0.811,0.891,-0.137]).reshape(-1,1)

    regressor.fit(X, y)

    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(1, 1)
    ax.set_box_aspect(1)

    plt.scatter(X, y)

    line = np.linspace(-1, 1, num=100).reshape(-1, 1)
    plt.plot(line, regressor.predict(line), c="peru")
    plt.show(block=True)

结果如下:
在这里插入图片描述

参考文献

[1] https://www.cse.psu.edu/~rtc12/CSE486/lecture15.pdf
[2] https://en.wikipedia.org/wiki/Random_sample_consensus
[3] Overview of the RANSAC Algorithm
[4] https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/529484.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【docker】Docker 简介

Docker 简介 什么是虚拟化、容器化?为什么要虚拟化、容器化&#xff1f;虚拟化实现方式应用程序执行环境分层虚拟化常见类别虚拟机容器JVM 之类的虚拟机 常见虚拟化实现主机虚拟化(虚拟机)实现容器虚拟化实现容器虚拟化实现原理容器虚拟化基础之 NameSpace 什么是虚拟化、容器…

人体跟随小车(旭日x3派、yolov5、目标检测)

人体跟随小车&#xff08;yolov5、目标检测&#xff09; 前言最终结果接线实现注意 前言 上板运行的后处理使用cython封装了&#xff0c;由于每个版本的yolo输出的形状不一样&#xff0c;这里只能用yolov5-6.2这个版本。 ①训练自己的模型并部署于旭日x3派参考&#xff1a; ht…

RuntimeError: Library cublas64_12.dll is not found or cannot be loaded

运行guillaumekln/faster-whisper-large-v2模型进行语音识别的时候报错了 RuntimeError: Library cublas64_12.dll is not found or cannot be loaded 代码&#xff1a; from faster_whisper import WhisperModelmodel WhisperModel("H:\\model\\guillaumekln\\faster…

【C++】优先级队列(priority_queue)的用法与实现

目录 一、概念&#xff1a; 二、仿函数&#xff08;Functor&#xff09;&#xff1a; 概念&#xff1a; 应用&#xff1a; 三、底层实现&#xff1a; 基本操作&#xff1a; 完整代码&#xff1a; 测试示例&#xff1a; 一、概念&#xff1a; 优先级队列&#xff08;pri…

PostgreSQL入门到实战-第六弹

PostgreSQL入门到实战 PostgreSQL查询语句(三)官网地址PostgreSQL概述PostgreSQL中ORDER BY理论PostgreSQL中ORDER BY实操更新计划 PostgreSQL查询语句(三) 官网地址 声明: 由于操作系统, 版本更新等原因, 文章所列内容不一定100%复现, 还要以官方信息为准 https://www.post…

tcp的全连接队列和半连接队列满时,客户端再connect发生的情况

首先简单介绍下tcp的全连接队列(accept queue)和半连接队列(syn queue)&#xff0c; 1.客户端发起syn请求时&#xff0c;服务端收到该请求&#xff0c;将其放入到syn queue&#xff0c;然后回复acksyn给客户端。 2.客户端收到acksyn&#xff0c;再发送ack给服务端。 3. 服务端从…

3、最大池化maxinmum pooling

了解有关最大池化特征提取的更多信息。 简介 在第二课中,我们开始讨论卷积神经网络(convnet)的基础如何进行特征提取。我们了解了这个过程中的前两个操作是在带有 relu 激活的 Conv2D 层中进行的。 在这一课中,我们将看一下这个序列中的第三个(也是最后一个)操作:通过…

3dmax渲染十几个小时怎么办?3dmax怎么多机渲染

当使用3ds Max进行渲染作业时&#xff0c;如果发现单张图像的渲染时间长达十数小时&#xff0c;这可能是由于计算机硬件配置较低或渲染场景过于复杂所致。为了缩短渲染时间并提高效率&#xff0c;我们可以考虑采用多台计算机进行协同渲染。下面&#xff0c;让我们一起探讨如何通…

MyBatis操作数据库(2)

MyBatis XML配置文件 MyBatis开发有两种方式: 1.注解 2.xml 上面我们学习了注解的方式, 下面来学习xml的方式 使用MyBatis的注解方式, 主要是为了完成一些简单的增删改查功能, 而下面我们介绍的xml方式, 则一般用于写一些比较复杂的sql语句. MyBatis XML的方式需要以下两步: …

《荒野大镖客》游戏提示emp.dll文件丢失如何解决?

emp.dll它作为一种动态链接库&#xff08;DLL&#xff09;文件&#xff0c;在Windows操作系统中扮演着重要角色。当打开一个程序时&#xff0c;操作系统会将程序的代码和数据加载到内存中&#xff0c;并创建一个进程来运行该程序。在这个过程中&#xff0c;emp.dll负责将这些代…

OpenHarmony开发-连接开发板调试应用

在 OpenHarmony 开发过程中&#xff0c;连接开发板进行应用调试是一个关键步骤&#xff0c;只有在真实的硬件环境下&#xff0c;我们才能测试出应用更多的潜在问题&#xff0c;以便后续我们进行优化。本文详细介绍了连接开发板调试 OpenHarmony 应用的操作步骤。 首先&#xf…

实现几何对象按照一定距离向外缓冲

1、首先&#xff0c;确保你已经引入了Turf.js库。你可以通过在HTML文件中添加以下代码来引入 <script src"https://cdn.jsdelivr.net/npm/turf/turf6.5.0/turf.min.js"></script>2、使用turf.buffer实现几何对象按照设定距离扩充 let originalCoordinat…

【MATLAB源码-第183期】基于matlab的图像处理GUI很全面包括滤波,灰度,边缘提取,RGB亮度调节,二值化等。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 1. RGB颜色亮度调整 1.1 RGB颜色模型 RGB颜色模型是一种加色模型&#xff0c;使用红色&#xff08;R&#xff09;、绿色&#xff08;G&#xff09;、蓝色&#xff08;B&#xff09;三种颜色的不同组合来表示各种颜色。每种…

每日OJ题_两个数组dp⑤_力扣10. 正则表达式匹配

目录 力扣10. 正则表达式匹配 解析代码 力扣10. 正则表达式匹配 10. 正则表达式匹配 难度 困难 给你一个字符串 s 和一个字符规律 p&#xff0c;请你来实现一个支持 . 和 * 的正则表达式匹配。 . 匹配任意单个字符* 匹配零个或多个前面的那一个元素 所谓匹配&#xff0c…

部署 GlusterFS 群集

目录 一、GFS部署 1.1.环境 1.2.更改节点名称 1.3.节点进行磁盘挂载&#xff0c;安装本地源 1.4.添加节点创建集群 1.5.根据规划创建卷 1.6. 部署gluster客户端 1.7. 破坏性测试 挂起 node2 节点或者关闭glusterd服务来模拟故障 复制卷&#xff0c;在node3和no…

基于springboot+vue+Mysql的药品商超管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

c++ 谷歌glog日志库使用

效果如图&#xff1a; 本次使用qt环境演示&#xff0c;相关库文件和头文件下载链接&#xff1a;https://download.csdn.net/download/bangtanhui/89108477 将相关库文件和头文件&#xff0c;丢到工程目录下 然后需要在工程pro文件当中引入库文件和头文件&#xff1a; …

LMDeploy 推理部署工具

一. 大模型部署面临的挑战 1. 计算量巨大 大模型参数量巨大&#xff0c;前向推理时需要进行大量计算。 2. 内存开销巨大 大模型在推理过程中&#xff0c;以FP16为例&#xff0c;20B模型仅加载参数就需40G显存&#xff0c;175B模型更是需要350G显存。同时在推理过程中&#xff…

JVM内存模型深度剖析

JDK体系结构 Java语言的跨平台特性 JDK整体结构及内存模型 JVM虚拟机 JVM主要由以下三个部分组成 类装载子系统:负责将Java类文件加载到运行时数据区中.并在运行时由类加载器创建Java类对象.运行时数据区:运行时数据区是JVM用于存储数据的内存区域.它包括方法区,堆,栈,本地方…

使用VPN时,Java程序无法访问远程网络的解决办法

应用场景&#xff1a; 电脑连接VPN之后&#xff0c;Java程序无法连接远程服务&#xff0c;比如第三方接口、远程数据库连接、远程微服务等。我个人遇到的情况有连接海康威视SDK&#xff0c;influxdb以及一些微服务。 解决办法&#xff1a; 启动Java时加入参数&#xff1a;-D…