【Python机器学习】决策树集成——随机森林

理论知识:

集成是合并多个机器学习模型来构建更强大模型法方法。

随机森林本质上是许多决策树的集合,其中每棵树都和其他数略有不同,随机森林背后的思想是:每棵树的预测可能都比较好,但是可能对部分数据过拟合,如果构造很多树,并且每棵预测的都很好,但都以不同的方式过拟合,那么可以对这些树的结果取平均来降低过拟合。

为了实现这一策略,需要构造很多决策树。每棵树都应该对目标值做出可以接受的预测,还应该与其他树不同。

随机森林中树的随机化方法有两周:一种是通过选择用于构造树的数据点,另一种是通过选择每次划分测试的特征。

想要构造一个随机森林模型,需要确定用于构造的树的个数。比如构造10棵树,这些树在构造时彼此完全独立,算法对这些树进行不同的随机选择,确保树和树之间是有区别的。想要构造一棵树,首先对数据进行自助采样。也就是说,从数据点中有放回的重复抽取样本,样本数与数据点数相同,这样会创建一个与原数据集相同大小的数据集,但是有些数据点会缺失或重复。

基于新数据集构造决策树,但是要对算法稍作修改。在每个叶结点处,算法随机选择特征的一个子集,并对其中一个特征寻找最佳测试,而不是对每个结点都寻找最佳测试。选择的特征个数由max_features参数来控制。每个结点中特征子集的选择是相互独立的,这样树的每个结点可以使用特征的不同子集来做出决策。

由于使用了自助采样,随机森林中构造每棵决策树的数据集都是略有不同的,由于每个结点的特征选择,每棵树的每次划分都是基于特征的不同子集。这两种方法共同确保了随机森林中每棵树都不相同。

构造过程中的一个关键参数是max_features,如果设置max_features=n_features,那么每次划分都要考虑数据集的所有特征,等于在特征选择过程中没有添加随机性,如果max_features=1,那么在划分时就无法选择对哪个特征进行测试,只能对随机选择的某个特征搜索不同的阈值。为了很好的拟合数据,每棵树的深度都要比较大。

想要利用随机森林进行预测,算法首先对森林中的每棵树进行预测,对于回归问题,可以对这些预测结果取均值作为最终结果,对于分类问题,可以采取“软投票”的方式取概率最大的结果作为最终的预测值。

分析随机森林:

import mglearn.plots
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']
X,y=make_moons(n_samples=100,noise=0.25,random_state=3)
X_train,X_test,y_train,y_test=train_test_split(X,y,stratify=y,random_state=42)

forest=RandomForestClassifier(n_estimators=5,random_state=2)
forest.fit(X_train,y_train)

fig,axes=plt.subplots(2,3,figsize=(20,10))
for i,(ax,tree) in enumerate(zip(axes.ravel(),forest.estimators_)):
    ax.set_title('Tree {}'.format(i))
    mglearn.plots.plot_tree_partition(X_train,y_train,tree,ax=ax)
mglearn.plots.plot_2d_separator(forest,X_train,fill=True,ax=axes[-1,-1],alpha=.4)
axes[-1,-1].set_title('随机森林')
mglearn.discrete_scatter(X_train[:,0],X_train[:,1],y_train)
plt.show()

可以看到,5棵树的决策边界大不相同,并且每棵树都犯了一些错误,因为有些训练点实际上没有包含在这些树的训练集里,这是自助采样的结果。

再构造一个包含100棵树的随机森林:

import mglearn.plots
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
import numpy as np

def plot_importances(model):
    n_feature=cancer.data.shape[1]
    plt.barh(range(n_feature),model.feature_importances_,align='center')
    plt.yticks(np.arange(n_feature),cancer.feature_names)
    plt.xlabel('特征重要性')
    plt.ylabel('特征')

plt.rcParams['font.sans-serif'] = ['SimHei']
cancer=load_breast_cancer()
X_train,X_test,y_train,y_test=train_test_split(cancer.data,cancer.target,random_state=0)

forest=RandomForestClassifier(n_estimators=100,random_state=0)
forest.fit(X_train,y_train)

print('训练集特征:{:.3f}'.format(forest.score(X_train,y_train)))
print('测试集特征:{:.3f}'.format(forest.score(X_test,y_test)))

plot_importances(forest)
plt.show()

 

可以看到在没有任何参数的情况下,随机森林的精度为97.2%,比线性模型或单棵树都要好。可以通过调节max_feature参数进行调整,但是一般情况下,默认参数就已经可以给出很好的结果了。

特征重要性:

 与单棵树相比,随机森林中有更多特征不为0,而且因为算法需要考虑多种可能的解释,随机森林比单棵树更能从总体上把握数据的特征。

随机森林的优缺点:

随机森林拥有决策树所有的优点,同时弥补了决策树的一些缺点随机森林的本质上是随机的,设置不同的随机状态可以彻底改变构建的模型,森林中的树越多,对随机状态选择的鲁棒性就越好,如果希望可以复现,固定random_state是很重要的。

对于维度非常高的稀疏数据,随机森林的表现往往不是很好,对于这种数据,用线性模型会更合适。随机森林需要更大内存,训练和预测的速度比线性模型要慢一些,所以如果时间和内存很重要,也可以选择线性模型。

随机森林需要调节的参数有n_estimators和max_features,可能还包括预剪枝选项(max_dept), n_estimators越大越好,但是越大需要的内存和时间也更多。

max_features决定每棵树的随机性大小,较小的max_features可以降低过拟合,一般来说,max_features使用默认值就很好,对于分类问题,默认值max_features=sqrt(n_features),对于回归来说,默认值max_features= n_features。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/305330.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

若依项目的table列表中对每一个字段增加排序按钮(单体版和前后端分离版)

一、目标:每一个字段都添加上下箭头用来排序 只需要更改前端代码,不需要更改后端代码,后面会讲解原理 二、单体版实现方式: 1.在options中添加sortable:true 2.在需要排序的字段中添加sortable:true 三、前后端分离版 1.el-table上添加@sort-change=“handleSortChange”…

MySQL的导入导出及备份

一.准备导入之前 二.navicat导入导出 ​编辑 三.MySQLdump命令导入导出 四.load data file命令的导入导出 五.远程备份 六. 思维导图 一.准备导入之前 需要注意: 在导出和导入之前,确保你有足够的权限。在进行导入操作之前,确保目标数据…

Tensorflow2.0笔记 - 创建tensor

tensor创建可以基于numpy,list或者tensorflow本身的API。 笔记直接上代码: import tensorflow as tf import numpy as np import matplotlib.pyplot as plttf.__version__#通过numpy创建tensor tensor0 tf.convert_to_tensor(np.ones([2,3])) print(te…

GitHub 一周热点汇总 第4期 (2024/01/01-01/06)

GitHub一周热点汇总第四期 (2023/12/24-12/30),梳理每周热门的GitHub项目,了解热点技术趋势,掌握前沿科技方向,发掘更多商机。2024年到了,希望所有的朋友们都能万事顺遂。 说明一下,有时候本周的热点项目会…

null和undefined的区别

null 和 undefined 是 JavaScript 中的两个基础类型特殊值。它们都表示“空”,但是有一些区别。 一、null 在 JavaScript 内部,null 被视为一个表示空值或缺少值的对象指针。在计算机内存中,它通常被表示为一个指向内存空间的空指针。这意味…

源码开发实践:搭建企业培训APP的技术难题及解决方案

在企业培训源码开发实践中,各位开发者可能遇到各种各样的问题,本文将深入探讨这些挑战,并提供解决方案,助力你顺利搭建企业培训APP。 1.多平台兼容性 企业中员工使用的设备多种多样,包括iOS、Android等不同操作系统。…

电力监控系统在数据中心应用

摘 要:在电力系统的运行过程中,变电站作为整个电力系统的核心,在保证电力系统可靠的运行方面起着至关重要的作用,基于此需对变电站监控系统的特点进行分析,结合变电站监控系统的功能需求,对变电站电力监控系…

BitMap解析(一)

文章目录 前言数据结构添加与删除操作 JDK中BitSet源码解析重要成员属性初始化添加数据清除数据获取数据size和length方法集合操作:与、或、异或 前言 为什么称为bitmap? bitmap不仅仅存储介质以及数据结构不同于hashmap,存储的key和value也…

Spring MVC概述及入门

MVC介绍 MVC是一种设计模式,将软件按照模型、视图、控制器来划分: M:(Model)模型层,指工程中的JavaBean,作用是处理数据 数据模型:User、Student,装数据 业务模型&#…

C++ Primer 第五版 中文版 阅读笔记 + 个人思考

C Primer 第五版 中文版 阅读笔记 个人思考 第 10 章 泛型算法10.1 概述练习10.1练习10.2 第 10 章 泛型算法 泛型的体现:容器类型(包括内置数组),元素类型,元素操作方法。 顺序容器定义的操作:insert&a…

Android-多线程

线程是进程中可独立执行的最小单位,也是 CPU 资源(时间片)分配的基本单位,同一个进程中的线程可以共享进程中的资源,如内存空间和文件句柄。线程有一些基本的属性,如id、name、以及priority。 id&#xff1…

API调试怎么做?Apipost快速上手

前言 Apipost是一款支持 RESTful API、SOAP API、GraphQL API等多种API类型,支持 HTTPS、WebSocket、gRPC多种通信协议的API调试工具。除此之外,Apipost 还提供了自动化测试、团队协作、等多种功能。这些丰富的功能简化了工作流程,提高了研发…

antv/x6_2.0学习使用(四、边)

一、添加边 节点和边都有共同的基类 Cell,除了从 Cell 继承属性外,还支持以下选项。 属性名类型默认值描述sourceTerminalData-源节点或起始点targetTerminalData-目标节点或目标点verticesPoint.PointLike[]-路径点routerRouterData-路由connectorCon…

网络流量分析与故障分析

1.网络流量实时分析 网络监控 也snmp协议 交换机和服务器打开 snmp就ok了 MRTG或者是prgt 用于对网络流量进行实时监测,可以及时了解服务器和交换机的流量,防止因流量过大而导致服务器瘫痪或网络拥塞。 原理 通过snmp监控 是一个…

MES/MOM标准之ISA-95基础内容介绍

ISA-95 简称S95,也有称作SP95。ISA-95 是企业系统与控制系统集成国际标准,由国际自动化学会(ISA,International Society of Automation) 在1995年投票通过。该标准的开发过程是由 ANSI(美国国家标准协会) 监督并保证其过程是正确的。ISA-95不…

acwing 并查集

目录 并查集的路径压缩两种方法法一法二 AcWing 240. 食物链AcWing 837. 连通块中点的数量示例并查集自写并查集 并查集的路径压缩两种方法 法一 沿着路径查询过程中,将非根节点的值都更新为最后查到的根节点 int find(int x) {if (p[x] ! x) p[x] find(p[x]);r…

爬取去哪网旅游攻略信息

代码展现: import requests import parsel import csv import time f open(旅游去哪攻略.csv,modea,encodingutf-8,newline) csv_writer csv.writer(f) csv_writer.writerow([标题,浏览量,日期,天数,人物,人均价格,玩法]) for page in range(1,5):url fhttps://…

整理的Binder、DMS、Handler、PMS、WMS等流程图

AMS: Binder: Handler: PMS: starActivity: WMS: 系统启动:

kdump安装及调试策略

本文基于redhat系的操作系统,debian系不太一样,仅提供参考 1.kdump的部署 注:一般很多操作系统在安装时可默认启动kdump。 (1)需要的包 yum install kexec-tools crash kernel-debuginfo (2&#xff0…

python画房子

前言 今天,我们来用Python画房子。 一、第一种 第一种比较简单。 代码: import turtle as t import timedef go(x, y):t.penup()t.goto(x, y)t.pendown() def rangle(h,w):t.left(180)t.forward(h)t.right(90)t.forward(w)t.left(-90)t.forward(h) de…