随机森林,Random Forests Classifiers/Regressor

目录

介绍: 

一、 Random Forests Classifiers(离散型)

1.1 数据处理

1.2建模

1.3特征值权值分析

1.4 特征值的缩减

二、Random Forests Regressor(连续型)

2.1数据处理 

2.2建模

2.3调参


介绍: 

随机森林(Random Forests)是一种集成学习算法,它由多个决策树组成。它在每个决策树的训练过程中引入了随机性,以提高模型的泛化能力和鲁棒性。

随机森林的训练过程如下:

  1. 从训练集中随机选取一部分样本,构建一个决策树。这种随机选取样本的过程叫做自助采样(bootstrap sampling)。
  2. 对于每个决策树的每个节点,从所有特征中随机选取一部分特征,根据这些特征来选择最优的分割点。
  3. 重复以上两个步骤,构建多个决策树。
  4. 预测时,将待预测样本输入到每个决策树中,得到多个预测结果。最终,根据这些预测结果进行投票或平均,确定最终的预测结果。

随机森林在许多方面都表现出良好的性能。它可以用于分类问题和回归问题,并且对于处理高维数据和大型数据集也非常有效。此外,随机森林能够处理缺失数据和不平衡数据,并能够评估特征的重要性。

总的来说,随机森林是一种强大的机器学习算法,它通过组合多个决策树的预测结果来提高模型的性能和鲁棒性。它在实际应用中广泛使用,并且具有很好的可解释性和通用性。

一、 Random Forests Classifiers(离散型)

1.1 数据处理

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

data=pd.read_csv('iris.csv')#离散型

X=data.iloc[:,1:5]
'''结果:
 	Sepal.Length 	Sepal.Width 	Petal.Length 	Petal.Width
0 	5.1 	3.5 	1.4 	0.2
1 	4.9 	3.0 	1.4 	0.2
2 	4.7 	3.2 	1.3 	0.2
3 	4.6 	3.1 	1.5 	0.2
4 	5.0 	3.6 	1.4 	0.2
... 	... 	... 	... 	...
145 	6.7 	3.0 	5.2 	2.3
146 	6.3 	2.5 	5.0 	1.9
147 	6.5 	3.0 	5.2 	2.0
148 	6.2 	3.4 	5.4 	2.3
149 	5.9 	3.0 	5.1 	1.8

150 rows × 4 columns
'''

y=data.iloc[:,-1:]
'''结果:
 	Species
0 	setosa
1 	setosa
2 	setosa
3 	setosa
4 	setosa
... 	...
145 	virginica
146 	virginica
147 	virginica
148 	virginica
149 	virginica

150 rows × 1 columns
'''

1.2建模

from  sklearn.model_selection import train_test_split#将数据分成测试和训练集
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=0)#测试集占百分之三十,random_state=0随机抽取数据集里的成为测试集

from sklearn.ensemble import RandomForestClassifier#import  random forest model

clf=RandomForestClassifier(n_estimators=100)#赋类,100棵树,后面可以调参
clf.fit(X_train,y_train)#训练集喂给这个模型
y_pred=clf.predict(X_test)#预测值

y_pred#预测值
'''结果:
array(['virginica', 'versicolor', 'setosa', 'virginica', 'setosa',
       'virginica', 'setosa', 'versicolor', 'versicolor', 'versicolor',
       'virginica', 'versicolor', 'versicolor', 'versicolor',
       'versicolor', 'setosa', 'versicolor', 'versicolor', 'setosa',
       'setosa', 'virginica', 'versicolor', 'setosa', 'setosa',
       'virginica', 'setosa', 'setosa', 'versicolor', 'versicolor',
       'setosa', 'virginica', 'versicolor', 'setosa', 'virginica',
       'virginica', 'versicolor', 'setosa', 'virginica', 'versicolor',
       'versicolor', 'virginica', 'setosa', 'virginica', 'setosa',
       'setosa'], dtype=object)
'''

from sklearn import metrics
metrics.accuracy_score(y_test,y_pred)#模型的值,y_test,y_pred对比
#结果:0.9777777777777777

1.3特征值权值分析

#特征变量的权值分析
feature_list=list(X.columns)
feature_imp=pd.Series(clf.feature_importances_,index=feature_list).sort_values(ascending=False)
feature_imp#特征值的权重
'''结果:
Petal.Width     0.456188
Petal.Length    0.411471
Sepal.Length    0.106732
Sepal.Width     0.025609
dtype: float64
'''

feature_imp.index
#结果:Index(['Petal.Width', 'Petal.Length', 'Sepal.Length', 'Sepal.Width'], dtype='object')
sns.barplot(x=feature_imp,y=feature_imp.index)
plt.xlabel('feature importance score')
plt.ylabel('feature')
plt.legend(feature_imp.index)
plt.show()

 

1.4 特征值的缩减

#特征变量的缩减,对于成百上千特征变量的大数据非常有意义
X=data.iloc[:,1:-2]
'''结果:
 	Sepal.Length 	Sepal.Width 	Petal.Length
0 	5.1 	3.5 	1.4
1 	4.9 	3.0 	1.4
2 	4.7 	3.2 	1.3
3 	4.6 	3.1 	1.5
4 	5.0 	3.6 	1.4
... 	... 	... 	...
145 	6.7 	3.0 	5.2
146 	6.3 	2.5 	5.0
147 	6.5 	3.0 	5.2
148 	6.2 	3.4 	5.4
149 	5.9 	3.0 	5.1

150 rows × 3 columns
'''


X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=0)#测试集占百分之三十,random_state=0随机抽取数据集里的成为测试集
clf=RandomForestClassifier(n_estimators=100)#赋类,100棵树,后面可以调参
clf.fit(X_train,y_train)#训练集喂给这个模型
y_pred=clf.predict(X_test)#预测值
metrics.accuracy_score(y_test,y_pred)#模型的值,y_test,y_pred对比
#结果:0.9333333333333333

二、Random Forests Regressor(连续型)

2.1数据处理 

dataset = pd.read_csv('petrol_consumption.csv')
dataset#汽油税,收入,高速费,人口密度,汽油消耗

dataset.info()
'''结果:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48 entries, 0 to 47
Data columns (total 5 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   Petrol_tax                    48 non-null     float64
 1   Average_income                48 non-null     int64  
 2   Paved_Highways                48 non-null     int64  
 3   Population_Driver_licence(%)  48 non-null     float64
 4   Petrol_Consumption            48 non-null     int64  
dtypes: float64(2), int64(3)
memory usage: 2.0 KB
'''

X = dataset.iloc[:,0:4]
y=dataset.iloc[:,4]

#数据差异非常大,需要引入数据标准化
from sklearn.preprocessing import StandardScaler
sc =  StandardScaler()
X_train = sc.fit_transform(X_train)
X_test=sc.transform(X_test)

X_train
'''结果:
array([[-0.60684249, -0.13370363, -0.39371558,  0.71661097],
       [-0.60684249,  0.73650306,  0.12337074,  2.41961586],
       [-0.60684249, -0.08812138,  1.37233744,  0.09273789],
       [ 0.33674156, -0.35747107,  0.14030588, -0.29507511],
       [ 1.28032561,  1.11152071, -0.85491594, -1.1718697 ],
       [-0.60684249, -0.3201765 ,  0.85525111, -0.14332219],
       [-0.60684249, -1.45973288, -0.42137631, -0.29507511],
       [-0.60684249,  0.48165682,  0.66501302,  1.39106835],
       [-0.60684249, -0.12541595, -0.52016464,  0.37938228],
       [ 0.33674156,  0.17915639,  0.87472653, -0.86836388],
       [-0.60684249, -0.31810458,  0.31106856, -0.59858093],
       [ 1.28032561, -1.32505804, -0.20658227, -0.61544236],
       [ 2.22390966,  2.03352542, -1.16990957, -0.16018363],
       [-0.60684249, -1.3312738 , -0.21250957, -0.6828881 ],
       [-1.00314779, -1.15723246,  0.66501302,  0.81777958],
       [-0.60684249, -0.05911449,  0.75674504,  0.46368945],
       [ 0.80853358, -1.50324322, -0.62205774,  1.39106835],
       [ 1.28032561, -0.55637546, -1.19333652, -0.14332219],
       [-0.60684249,  0.94576705,  0.40985689, -0.10959932],
       [-0.60684249, -1.27533194, -0.80919106, -1.22245401],
       [ 0.80853358,  0.44229032, -0.80693304, -0.49741232],
       [ 0.33674156,  1.98587124,  1.80361904, -2.18355578],
       [ 1.28032561, -0.21243662, -0.22351741, -1.0707011 ],
       [-2.49401059, -0.65375573,  3.4728595 , -0.2444908 ],
       [ 0.33674156,  1.28970589, -1.37623605,  0.36252084],
       [ 0.80853358, -0.0404672 ,  0.15018472,  1.62712844],
       [-0.60684249,  0.31383124,  0.85496886, -0.48055089],
       [-0.60684249, -0.03217952, -0.4439565 ,  1.54282127],
       [ 1.28032561,  0.23924209, -0.43351316, -0.16018363],
       [-0.13505047,  1.05557885, -0.88257667, -0.86836388],
       [ 1.28032561, -1.63584614, -0.9884213 , -0.93580962],
       [-1.55042654,  1.77039149, -0.89640703,  1.54282127]])
'''

2.2建模

from sklearn.ensemble import RandomForestRegressor
regressor = RandomForestRegressor(n_estimators=100,random_state=0)
regressor.fit(X_train,y_train)
y_pred = regressor.predict(X_test)
#结果:array([586.98, 489.34, 626.27, 658.65, 631.72, 614.6 , 612.19, 579.32,
#       465.44, 512.32, 438.32, 657.18, 624.52, 598.46, 534.4 , 555.78])


from sklearn import metrics
print('Mean Absolute Error',metrics.mean_absolute_error(y_test,y_pred))
print('Mean Squared Error',metrics.mean_squared_error(y_test,y_pred))
print('root mean squared error',np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
'''结果:
Mean Absolute Error 50.18562499999999
Mean Squared Error 3964.588093749999
root mean squared error 62.964975134990716
'''

2.3调参

rmse=nestimators=[]#调参
for n in [20,30,50,80,100,200,300,400,500,600,700,800]:
    regressor = RandomForestRegressor(n_estimators=n,random_state=0)
    regressor.fit(X_train,y_train)
    y_pred = regressor.predict(X_test)
    print('-------------------')
    print('n_estimators={}',format(n))
    print('Mean Absolute Error',metrics.mean_absolute_error(y_test,y_pred))
    print('Mean Squared Error',metrics.mean_squared_error(y_test,y_pred))
    print('Root mean squared error',np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
    rmse=np.append(rmse,np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
    nestimators=np.append(nestimators,n)

'''结果:
-------------------
n_estimators={} 20
Mean Absolute Error 56.128125000000004
Mean Squared Error 4606.41578125
Root mean squared error 67.87058111766835
-------------------
n_estimators={} 30
Mean Absolute Error 49.94375000000001
Mean Squared Error 3922.442708333335
Root mean squared error 62.62940769585271
-------------------
n_estimators={} 50
Mean Absolute Error 49.158750000000005
Mean Squared Error 3868.3672749999996
Root mean squared error 62.19619984372035
-------------------
n_estimators={} 80
Mean Absolute Error 50.70390625
Mean Squared Error 4013.614755859375
Root mean squared error 63.35309586641662
-------------------
n_estimators={} 100
Mean Absolute Error 50.18562499999999
Mean Squared Error 3964.588093749999
Root mean squared error 62.964975134990716
-------------------
n_estimators={} 200
Mean Absolute Error 48.34375
Mean Squared Error 3622.057096875
Root mean squared error 60.18352845152069
-------------------
n_estimators={} 300
Mean Absolute Error 49.467708333333334
Mean Squared Error 3789.9574437499987
Root mean squared error 61.5626302536693
-------------------
n_estimators={} 400
Mean Absolute Error 48.489999999999995
Mean Squared Error 3636.7144398437504
Root mean squared error 60.30517755420135
-------------------
n_estimators={} 500
Mean Absolute Error 48.917499999999976
Mean Squared Error 3726.081923499998
Root mean squared error 61.041640897833
-------------------
n_estimators={} 600
Mean Absolute Error 48.97749999999999
Mean Squared Error 3719.864061805555
Root mean squared error 60.990688320476885
-------------------
n_estimators={} 700
Mean Absolute Error 48.50473214285714
Mean Squared Error 3633.144154209183
Root mean squared error 60.275568468569276
-------------------
n_estimators={} 800
Mean Absolute Error 48.12984374999999
Mean Squared Error 3560.1158533203115
Root mean squared error 59.66670640583668
'''
rmse
'''结果:
array([67.87058112, 62.6294077 , 62.19619984, 63.35309587, 62.96497513,
       60.18352845, 61.56263025, 60.30517755, 61.0416409 , 60.99068832,
       60.27556847, 59.66670641])
'''

sns.set_style('whitegrid')
plt.plot(nestimators,rmse,'bo',linestyle='dashed',linewidth=1,markersize=10)#前面x,后面y
plt.xlabel('feature importance score')
plt.ylabel('features')
plt.title("viualizing importeant features")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/290595.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数据库:基础SQL知识+SQL实验1

&#xff08;1&#xff09;基础知识&#xff1a; 1.创建数据库&#xff1a; CREATE DATABASE <database_name> 2.删除数据库&#xff1a; DROP DATABASE <database_name> 3.相关数据类型&#xff1a; [1] 字符串类型 CHAR(n)&#xff1a;固定长度的字符数据…

基于ssm的《数据库系统原理》课程平台的设计与实现论文

目 录 目 录 I 摘 要 III ABSTRACT IV 1 绪论 1 1.1 课题背景 1 1.2 研究现状 1 1.3 研究内容 2 2 系统开发环境 3 2.1 vue技术 3 2.2 JAVA技术 3 2.3 MYSQL数据库 3 2.4 B/S结构 4 2.5 SSM框架技术 4 3 系统分析 5 3.1 可行性分析 5 3.1.1 技术可行性 5 3.1.2 操作可行性 5 3…

[蓝桥杯学习]​树上差分

差分 前缀和 sum_i sum_i-1 a_i 差分 diff_i a_i - a_i-1 差分的好处 点的差分 问题引入 解决问题 要用到差分的思想&#xff0c;每次从叶子向上的回溯&#xff0c;会让父结点子结点的cnt值&#xff0c;但是仅仅这样&#xff0c;还不行 回溯的过程中&#xff0c;LCA被…

【Midjourney】AI绘画新手教程(一)登录和创建服务器,生成第一幅画作

一、登录Discord 1、访问Discord官网 使用柯學尚网&#xff08;亲测非必须&#xff0c;可加快响应速度&#xff09;访问Discord官方网址&#xff1a;https://discord.com 选择“在您的浏览器中打开Discord” 然后&#xff0c;注册帐号、购买套餐等&#xff0c;在此不做缀述。…

[每周一更]-(第56期):不能不懂的网络知识

作为程序员&#xff0c;在网络方面具备一定的知识和技能是非常重要的。以下是一些程序员需要熟练掌握的网络知识&#xff1a; 基础网络概念&#xff1a; IP地址&#xff1a;了解IPv4和IPv6地址的格式和分配方式&#xff0c;以及常见的IP地址分类。子网掩码&#xff1a;理解子…

大数据 - Doris系列《一》- Doris简介

目录 &#x1f436;1.1 Doris 概述 &#x1f436;1.2 OLAP和OLTP&#xff08;面试&#xff09; 1. 应用场景 &#x1f959;联机事务处理OLTP(On-Line Transaction Processing) &#x1f959;联机分析处理OLAP(On-Line Analytical Processing) 2. OLAP和OLTP比较--“用户行…

大数据技术在民生资金专项审计中的应用

一、应用背景 目前&#xff0c;针对审计行业&#xff0c;关于大数据技术的相关研究与应用一般包括大数据智能采集数据技术、大数据智能分析技术、大数据可视化分析技术以及大数据多数据源综合分析技术。其中&#xff0c;大数据智能采集数据技术是通过网络爬虫或者WebService接…

Linux程序、进程以及计划任务(第一部分)

目录 一、程序和进程 1、什么是程序&#xff1f; 2、什么是进程&#xff1f; 3、线程是什么&#xff1f; 4、如何查看是多线程还是单线程 5、进程结束的两种情况&#xff1a; 6、进程的状态 二、查看进程信息的相关命令 1、ps&#xff1a;查看静态的进程统计信息 2、…

WEB:探索开源PDF.js技术应用

1、简述 PDF.js 是一个由 Mozilla 开发的开源 JavaScript 库&#xff0c;用于在浏览器中渲染 PDF 文档。它的目标是提供一个纯粹的前端解决方案&#xff0c;摆脱了依赖插件或外部程序的束缚&#xff0c;使得在任何支持 JavaScript 的浏览器中都可以轻松地显示 PDF 文档。 2、…

git的拉取、提交、合并、解决冲突详细教程

我们在开发中使用git&#xff0c;经常会遇到拉代码&#xff0c;切换分支&#xff0c;提交代码&#xff0c;新建分支&#xff0c;合并代码&#xff0c;解决冲突这些操作&#xff0c;下面我跟大家分享一个好用的git工具来进行这些操作。 首先&#xff0c;我们下载一个git工具 点…

HarmonyOS4 vp单位计算

我们在harmonyOS中设置宽度等单位时 需要在后面写明具体是什么单位 width("100%")这里 我们就写明了是 百分之百 如果不写 直接给数值 width(100)那么 它就会按vp去读 这里就被读为 100vp vp 之前是一种移动端宽度概念 后面鸿蒙重定义了它的概念 计算公式是 px 乘…

实战环境搭建-安装xshell和xftp

安装xshell和xftp的原因是想远程虚拟机&#xff0c;很多时候&#xff0c;直接去操作虚拟机明显不太方便。 所以&#xff0c;我们需要一个能够搭载虚拟机和本地电脑之间的桥梁&#xff0c;哪怕是你们去了企业&#xff0c;也和这个类似&#xff0c;唯一的区别是企业里面更多连接…

Centos 磁盘挂载和磁盘扩容(新加硬盘方式)

步骤总结如下 一、对磁盘进行分区 二、对磁盘进行格式化 三、将磁盘挂载到对应目录 四、做开机自动挂载磁盘 磁盘分区 1.使用命令&#xff1a;fdisk -l 查看磁盘&#xff08;注&#xff1a;正常在Centos7中第一块数据盘标识一般是/dev/sda,第二块数据盘标识一般是/dev/sdb&…

2024年防止内卷和被潜规则,RocketMQ消息中间件实战派上下册上线啦|架构随笔录

2023已经过去啦&#xff0c;作为技术小伙伴一定要做好2024年的规划&#xff0c;只有这样才能够避免内卷和潜规则。 2024年即将是一个重新开始的一年&#xff0c;但是你要说互联网不倦&#xff0c;那是不可能的&#xff0c;就连某大厂都开始走下坡路啦&#xff0c;里面卷的是不…

时间序列平稳性相关检验方法

理解平稳性 一般来说&#xff0c;平稳时间序列是指随着时间的推移具有相当稳定的统计特性的时间序列&#xff0c;特别是在均值和方差方面。平稳性可能是一个比较模糊的概念&#xff0c;将序列排除为不平稳可能比说序列是平稳的更容易。通常不平稳序列有几个特征&#xff1a; …

【Pytorch】学习记录分享13——OCR(Optical Character Recognition,光学字符识别)

[TOC](OCR(Optical Character Recognition,光学字符识别)) 1. OCR资源汇总 OCR(Optical Character Recognition,光学字符识别)指提取图像中的文字信息&#xff0c;通常包括文本检测和文本识别。 文字检测&#xff1a;将图片中的文字区域位置检测出来&#xff08;如图1(b)所示…

怎么寄快递可以便宜一点,怎么领快递优惠券?

随着网购越来越多了&#xff0c;人们对于寄快递的需求也越来越大啦。那么&#xff0c;怎么样寄快递才便宜呢&#xff1f;今天&#xff0c;就让有十年网店经验的小编来告诉你。忒别是最近又临近年关&#xff0c;人民喜悦的心情越来越迫切。亲戚朋友之间互送礼品的往来也越来越密…

C++ 多态向下转型详解

文章目录 1 . 前言2 . 多态3 . 向下转型3.1 子类没有改进父类的方法下&#xff0c;去调用该方法3.2 子类有改进父类的方法下&#xff0c;去调用该方法3.3 子类没有改进父类虚函数的方法下&#xff0c;去调用改方法3.4 子类有改进父类虚函数的方法下&#xff0c;去调用改方法3.5…

【设计模式之美】面向对象分析方法论与实现(二):需求到接口实现的方法论

文章目录 一. 进行面向对象设计1. 划分职责>需要有哪些类2. 定义类及其属性和方法3. 定义类与类之间的交互关系4. 将类组装起来并提供执行入口 二. 如何进行面向对象编程&#xff1f;1. 接口实现2. 辩证思考与灵活应用 【设计模式之美】面向对象分析方法论与实现&#xff08…

【JUC】Volatile关键字+CPU/JVM底层原理

Volatile关键字 volatile内存语义 1.当写一个volatile变量时&#xff0c;JMM会把该线程对应的本地内存中的共享变量值立即刷新回主内存中。 2.当读一个volatile变量时&#xff0c;JMM会把该线程对应的本地内存设置为无效&#xff0c;直接从主内存中读取共享变量 所以volatile…