【机器学习算法】Adaboost原理及实现

Adaboost

一、基本内容

[!note]

实现思路:在每一轮训练中,记录每一次由 f ( x ) = ∑ m = 1 i − 1 α m G m ( x ) f(x) = \sum_{m=1}^{i-1}\alpha_mG_m(x) f(x)=m=1i1αmGm(x)【错误\正确】分类的样本,在加入新的弱学习器中【提高\降低】分类【错误\正确】样本的权值即改变样本的比例,类似过采样与降采样

  • 加法模型:多个弱分类器 G m ( x ) G_m(x) Gm(x)与对应权值 α m \alpha_m αm的叠加:

f ( x ) = ∑ m = 1 M α m G m ( x ) f(x) = \sum_{m=1}^M\alpha_mG_m(x) f(x)=m=1MαmGm(x)

  • 样本权值的初始设计(权值平等):

w 1 , i = 1 N w_{1,i}=\frac{1}{N} w1,i=N1

  • 二分类损失函数,指数损失函数

L ( y , f ( x ) ) = e x p [ − y f ( x ) ] = e x p [ − y ⋅ ( f m − 1 ( x i ) + α m G m ( x ) ) ] L(y,f(x)) = exp[-yf(x)]=exp[-y ·(f_{m-1}(x_i)+\alpha_mG_m(x))] L(y,f(x))=exp[yf(x)]=exp[y(fm1(xi)+αmGm(x))]

二、样本权值更新

​ 可以发现,指数损失函数在【正确\错误】分类的样本的值【小于\大于】1,正好符合Adaboost加法模型的实现思路,在加入新的弱学习器中【提高\降低】分类【错误\正确】样本的权值,所以第 m m m个弱分类器,第 i i i个样本的权值更新可以设计为:
ω m , i = e x p [ − y i f m − 1 ( x i ) ] \omega_{m,i} = exp[-y_if_{m-1}(x_i)] ωm,i=exp[yifm1(xi)]
​ 在Adaboost模型中, f m − 1 ( x i ) = a m − 1 G m − 1 f_{m-1}(x_i) = a_{m-1}G_{m-1} fm1(xi)=am1Gm1,所以,权值更新的公式为:
w m , i = e x p ( y i ( − a m − 1 ) G m − 1 ( x i ) ) w_{m,i} = exp(y_i(-a_{m-1})G_{m-1}(x_i)) wm,i=exp(yi(am1)Gm1(xi))
​ 为了加强不同弱分类器之间的依赖性,在更新权值时是在上一个弱分类器模型的基础上进行更新的:
w m , i = w m − 1 , i ⋅ e x p ( − a m − 1 y i G m − 1 ( x i ) ) w_{m,i} = w_{{m-1},i} · exp(-a_{m-1}y_iG_{m-1}(x_i)) wm,i=wm1,iexp(am1yiGm1(xi))
​ 最后加入 Z m − 1 Z_{m-1} Zm1,得到最终的权值更新式子:
w m , i = w m − 1 , i z m − 1 e x p ( − a m − 1 y i G m − 1 ( x i ) ) w_{m,i} = \frac{w_{m-1,i}}{z_{m-1}}exp(-a_{m-1}y_iG_{m-1}(x_i)) wm,i=zm1wm1,iexp(am1yiGm1(xi))
​ 其中,规范化因子 Z m − 1 Z_{m-1} Zm1表示为:
Z m − 1 = ∑ i = 1 N ω m − 1 , i e x p ( − a m − 1 y i G m − 1 ( x i ) ) Z_{m-1} = \sum_{i=1}^{N}\omega_{m-1,i}exp(-a_{m-1}y_iG_{m-1}(x_i)) Zm1=i=1Nωm1,iexp(am1yiGm1(xi))

[!important]

分类正确时, y i = G m − 1 ( x i ) y_i=G_{m-1}(x_i) yi=Gm1(xi), e x p ( − a m − 1 y i G m − 1 ( x i ) ) = e x p ( − a m − 1 ) < 1 exp(-a_{m-1}y_iG_{m-1}(x_i))=exp(-a_{m-1}) < 1 exp(am1yiGm1(xi))=exp(am1)<1, 其中 a m − 1 > 1 a_{m-1}>1 am1>1,对应正确样本的权值会减少,同理当错误分类时, e x p ( − a m − 1 ) > 1 exp(-a_{m-1}) > 1 exp(am1)>1,对应样本的权值增加

三、弱分类器权值更新

​ 目标损失函数:
L ( y , f ( x ) ) = e x p [ − y f ( x ) ] = e x p [ − y ⋅ ( f m − 1 ( x i ) + α m G m ( x ) ) ] L(y,f(x)) = exp[-yf(x)]=exp[-y ·(f_{m-1}(x_i)+\alpha_mG_m(x))] L(y,f(x))=exp[yf(x)]=exp[y(fm1(xi)+αmGm(x))]
​ 在模型优化更新权重的过程中,并不是与传统模型一样采用梯度下降法,因为弱分类器的数量多,更新的参数多,难以实现,在Adaboost模型中采用的前向分布算法,只更新当前弱分类器 G m G_m Gm的参数,优化目标:
( a m , G m ( x ) = a r g m a x a , G ∑ i = 1 N e x p [ − y i ( f m − 1 ( x i ) + α G m ( x i ) ) ] ) (a_m,G_m(x) = argmax_{a,G} \sum_{i=1}^Nexp[-y_i(f_{m-1}(x_i) + \alpha G_m(x_i))]) (am,Gm(x)=argmaxa,Gi=1Nexp[yi(fm1(xi)+αGm(xi))])
​ 对 a a a求导的结果,表示损失最小的 α \alpha α
α m = 1 2 l o g 1 − e m e m \alpha_m = \frac{1}{2}log\frac{1-e_m}{e_m} αm=21logem1em

​ 其中 e m e_m em表示误差率:
e m ′ = ∑ i = 1 N P ( G m ( x i ≠ y i ) ) = ∑ i = 1 N ω m i I ( G m ( x i ≠ y i ) ) = ∑ G m ( x i ) ≠ y i ω m i e_m' = \sum_{i=1}^N P(G_m(x_i \neq y_i))= \sum_{i=1}^N\omega_{mi}I(G_m(x_i \neq y_i))= \sum_{G_m(x_i) \neq y_i}\omega_{mi} em=i=1NP(Gm(xi=yi))=i=1NωmiI(Gm(xi=yi))=Gm(xi)=yiωmi

​ 最后还需要实现归一化:
e m = ∑ G m ( x i ) ≠ y i ω m i ∑ i = 1 N ω i m e_m = \frac{\sum_{G_m(x_i) \neq y_i}\omega_{mi}}{\sum_{i=1}^N}\omega_i^{m} em=i=1NGm(xi)=yiωmiωim

四、代码实现
  • 定义弱分类器,采用决策树
model = DecisionTreeClassifier(max_depth=1)
model.fit(X, y, sample_weight=w)
y_pred = model.predict(X)
  • 权重更新

权重初始化 w 1 , i = 1 N w_{1,i}=\frac{1}{N} w1,i=N1

# 权重初始化
w = np.ones(n_samples) / n_samples # 初始化权重

误差率计算 e m = ∑ G m ( x i ) ≠ y i ω m i ∑ i = 1 N ω i m e_m = \frac{\sum_{G_m(x_i) \neq y_i}\omega_{mi}}{\sum_{i=1}^N}\omega_i^{m} em=i=1NGm(xi)=yiωmiωim

弱分类器 α \alpha α更新 α m = 1 2 l o g 1 − e m e m \alpha_m = \frac{1}{2}log\frac{1-e_m}{e_m} αm=21logem1em

# 误差率计算以及更新权重
err = np.sum(w * (y_pred != y)) / np.sum(w)
alpha = 0.5 * np.log((1 - err) / max(err, 1e-10))

样本权重更新 w m , i = w m − 1 , i z m − 1 e x p ( − a m − 1 y i G m − 1 ( x i ) ) w_{m,i} = \frac{w_{m-1,i}}{z_{m-1}}exp(-a_{m-1}y_iG_{m-1}(x_i)) wm,i=zm1wm1,iexp(am1yiGm1(xi))

# 更新样本权重
norm = np.sum(w)
w = w * np.exp(-alpha * y * y_pred)
w /= norm 
  • 加法模型预测: f ( x ) = ∑ m = 1 M α m G m ( x ) f(x) = \sum_{m=1}^M\alpha_mG_m(x) f(x)=m=1MαmGm(x)
for alpha, model in zip(self.alphas, self.models):
    pred += alpha * model.predict(X)
return np.sign(pred)
  • 完整代码:
class AdaBoost:
    def __init__(self, n_estimators=50):
        self.n_estimators = n_estimators
        self.alphas = []  # 每个弱分类器的权重
        self.models = []  # 弱分类器列表

    def fit(self, X, y):
        n_samples, n_features = X.shape
        # 初始化样本权重
        w = np.ones(n_samples) / n_samples # 初始化权重

        for _ in range(self.n_estimators):
            # 使用样本权重训练一个弱分类器
            model = DecisionTreeClassifier(max_depth=1)
            model.fit(X, y, sample_weight=w)
            y_pred = model.predict(X)

            # 计算分类误差率
            err = np.sum(w * (y_pred != y)) / np.sum(w)
            if err >= 0.5:
                break

            # 计算弱分类器的权重
            alpha = 0.5 * np.log((1 - err) / max(err, 1e-10))
            self.alphas.append(alpha)
            self.models.append(model)
            norm = np.sum(w)
            # 更新样本权重
            w = w * np.exp(-alpha * y * y_pred)
            w /= norm

    def predict(self, X):
        pred = np.zeros(X.shape[0])
        for alpha, model in zip(self.alphas, self.models):
            pred += alpha * model.predict(X)
        return np.sign(pred)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/925108.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年信号处理与神经网络应用(SPNNA 2024)

会议官网&#xff1a;www.spnna.org 会议时间&#xff1a;2024年12月13-15日 会议地点&#xff1a;中国武汉

canal同步数据教程

canal简介 官网&#xff1a;https://github.com/alibaba/canal 主要是基于 MySQL 数据库增量日志解析&#xff0c;提供增量数据订阅和消费&#xff0c;是一个实时同步的方案。 基于日志增量订阅和消费的业务包括 数据库镜像数据库实时备份索引构建和实时维护(拆分异构索引、…

【网络安全 | 漏洞挖掘】绕过SAML认证获得管理员面板访问权限

未经许可,不得转载。 文章目录 什么是SAML认证?SAML是如何工作的?SAML响应结构漏洞结果什么是SAML认证? SAML(安全断言标记语言)用于单点登录(SSO)。它是一种功能,允许用户在多个服务之间切换时无需多次登录。例如,如果你已经登录了facebook.com,就不需要再次输入凭…

【Redis】Redis介绍

目录 1.Redis是什么? 2. Redis特性 2.1 速度快 2.2 基于键值对的数据结构服务器 2.3 丰富的功能 2.4 简单稳定 2.5 客户端语言多 2.6 持久化 2.7 主从复制 2.8 高可用和分布式 3. Redis使用场景 3.1 缓存(Cache) 3.2 排行榜系统 3.3 计数器应用 3.4 社交网络 …

【HarmonyOS学习日志(10)】一次开发,多端部署之功能级一多开发,工程级一多开发

功能级一多开发 SysCap机制介绍 HarmonyOS使用SysCap机制&#xff08;即SystemCapability&#xff09;&#xff0c;可以帮助开发者仅关注设备的系统能力&#xff0c;而不用考虑成百上千种具体的设备类型。 在过去&#xff0c;开发不同设备上的应用就用不同设备的SDK进行开发&…

vue3 与 spring-boot 完成跨域访问

spring-boot&#xff0c;写一个接口用于前端访问&#xff0c;并且给接口设置跨域访问&#xff0c;这里我前端的域名为 localhost:5173 RestController CrossOrigin(origins "http://localhost:5173") public class Vue3Controller {GetMapping("/vue")pu…

机器学习-神经网络(BP神经网络前向和反向传播推导)

1.1 神经元模型 神经网络(neural networks)方面的研究很早就已出现,今天“神经网络”已是一个相当大的、多学科交叉的学科领域.各相关学科对神经网络的定义多种多样,本书采用目前使用得最广泛的一种,即“神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它的组织能够…

如何通过智能生成PPT,让演示文稿更高效、更精彩?

在快节奏的工作和生活中&#xff0c;我们总是追求更高效、更精准的解决方案。而在准备演示文稿时&#xff0c;PPT的制作往往成为许多人头疼的问题。如何让这项工作变得轻松且富有创意&#xff1f;答案或许就在于“AI生成PPT”这一智能工具的广泛应用。我们就来聊聊如何通过这些…

丹摩|丹摩智算平台使用教学指南

本指南旨在为新用户提供一个详细的操作步骤和实用的入门指导&#xff0c;帮助大家快速上手丹摩智算平台。 一、平台简介 丹摩智算平台是一款强大的数据分析和计算平台&#xff0c;支持多种编程语言&#xff0c;提供丰富的数据处理和机器学习工具。无论您是数据分析师、开发者…

Python学习第十天--处理CSV文件和JSON数据

CSV&#xff1a;简化的电子表格&#xff0c;被保存为纯文本文件 JSON&#xff1a;是一种数据交换格式&#xff0c;易于人阅读和编写&#xff0c;同时也易于机器解析和生成&#xff0c;以JavaScript源代码的形式将信息保存在纯文本文件中 一、csv模块 CSV文件中的每行代表电…

mini-spring源码分析

IOC模块 关键解释 beanFactory&#xff1a;beanFactory是一个hashMap, key为beanName, Value为 beanDefination beanDefination: BeanDefinitionRegistry&#xff0c;BeanDefinition注册表接口&#xff0c;定义注册BeanDefinition的方法 beanReference&#xff1a;增加Bean…

2024年9月中国电子学会青少年软件编程(Python)等级考试试卷(六级)答案 + 解析

一、单选题 1、下面代码运行后出现的图像是&#xff1f;&#xff08; &#xff09; import matplotlib.pyplot as plt import numpy as np x np.array([A, B, C, D]) y np.array([30, 25, 15, 35]) plt.bar(x, y) plt.show() A. B. C. D. 正确答案&#xff1a;A 答案…

UniApp开发实战:常见报错解析与解决方案

UniApp开发实战&#xff1a;常见报错解析与解决方案 病例1、TypeError: undefined is not an object (evaluating ‘this. s c o p e . scope. scope.getAppWebview’) 需求&#xff1a;获取页面示例&#xff0c;动态修改头部搜索框内容&#xff0c;获取页面实例时候报错unde…

BGP对等体建立方法--实验

目录 实验拓扑图 实验要求&#xff1a; 第一步、IP地址规划 第二步、配置接口IP地址 第三步、AS 200使用IGP OSPF实现网络互通 第四步、建立BGP对等体关系 1、R1与R2使用直连链路建立EBGP关系。 2、R2与R4使环回建立非直连IBGP关系。 3、R4与R5使用环回建立EBGP关系。…

(已解决)wps无法加载此加载项程序mathpage.wll

今天&#xff0c;在安装Mathtype的时候遇到了点问题&#xff0c;如图所示 尝试了网上的方法&#xff0c;将C:\Users\Liai_\AppData\Roaming\Microsoft\Word\STARTUP路径中的替换为32位的Mathtype加载项。但此时&#xff0c;word又出现了问题 后来知道了&#xff0c;这是因为64位…

Vue+Element Plus实现自定义表单弹窗

目录 一、基本框架 1.父组件index.vue 2.子组件FormPop.vue 二、细节补充 1&#xff09;input、textarea、select、input number 2&#xff09;daterange、date、monthrange 3&#xff09;数据定义 4&#xff09;没改样式的效果 5&#xff09;最终效果 三、最终代码 …

【插入排序】:直接插入排序、二分插入排序、shell排序

【插入排序】&#xff1a;直接插入排序、二分插入排序、shell排序 1. 直接插入排序1.1 详细过程1.2 代码实现 2. 二分插入排序2.1 详细过程2.2 代码实现 3. shell排序3.1 详细过程3.2 代码实现 1. 直接插入排序 1.1 详细过程 1.2 代码实现 public static void swap(int[]arr,…

PHP 生成分享海报

因为用户端有多个平台&#xff0c;如果做分享海报生成&#xff0c;需要三端都来做&#xff0c;工作量比较大。 所以这个艰巨的任务就光荣的交给后端了。经过一定时间的研究和调试&#xff0c;最终圆满完成了任务&#xff0c;生成分享海报图片实现笔记如下。 目录 准备字体文件…

MySQL底层概述—5.InnoDB参数优化

大纲 1.内存相关参数优化 (1)缓冲池内存大小配置 (2)配置多个Buffer Pool实例 (3)Chunk(块)大小配置 (4)InnoDB缓存性能评估 (5)Page管理相关参数 (6)Change Buffer相关参数优化 2.日志相关参数优化 (1)日志缓冲区相关参数配置 (2)日志文件参数优化 3.IO线程相关参数…

05_JavaScript注释与常见输出方式

JavaScript注释与常见输出方式 JavaScript注释 源码中注释是不被引擎所解释的&#xff0c;它的作用是对代码进行解释。lavascript 提供两种注释的写法:一种是单行注释&#xff0c;用//起头:另一种是多行注释&#xff0c;放在/*和*/之间。 //这是单行注释/* 这是 多行 注释 *…