决策树实验分析(分类和回归任务,剪枝,数据对决策树影响)

目录

1. 前言

2. 实验分析

        2.1 导入包

        2.2 决策树模型构建及树模型的可视化展示

        2.3 概率估计

        2.4 绘制决策边界

        2.5 决策树的正则化(剪枝)

        2.6 对数据敏感

        2.7 回归任务

        2.8 对比树的深度对结果的影响

        2.9 剪枝


1. 前言

        本文主要分析了决策树的分类和回归任务,对比一系列的剪枝的策略对结果的影响,数据对于决策树结果的影响。

        介绍使用graphaviz这个决策树可视化工具

2. 实验分析

        2.1 导入包

#1.导入包
import os
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12
import warnings
warnings.filterwarnings('ignore')

        2.2 决策树模型构建及树模型的可视化展示

        下载安装包:https://graphviz.gitlab.io/_pages/Download/Download_windows.html

         选择一款安装,注意安装时要配置环境变量

        注意这里使用的是鸢尾花数据集,选择花瓣长和宽两个特征

#2.建立树模型
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
iris = load_iris()
X = iris.data[:,2:] # petal legth and width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2)
tree_clf.fit(X,y)
#3.树模型的可视化展示
from sklearn.tree import export_graphviz
export_graphviz(
    tree_clf,
    out_file='iris_tree.dot',
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True
)

        然后就可以使用graphviz包中的dot.命令工具将此文件转换为各种格式的如pdf,png,如 dot -Tpng iris_tree.png -o iris_tree.png

        可以去文件系统查看,也可以用python展示

from IPython.display import Image
Image(filename='iris_tree.png',width=400,height=400)

        分析:value表示每个节点所有样本中各个类别的样本数,用花瓣宽<=0.8和<=1.75 作为根节点划分,叶子节点表示分类结果,结果执行少数服从多数策略,gini指数随着分类进行在减小。

        2.3 概率估计

        估计类概率 输入数据为:花瓣长5厘米,宽1.5厘米的花。相应节点是深度为2的左节点,因此决策树因输出以下概率:

        iris-Setosa为0%(0/54)

        iris-Versicolor为90.7%(49/54)

        iris-Virginica为9.3%(5/54)

        

#4.概率估计
print(tree_clf.predict_proba([[5,1.5]]))
print(tree_clf.predict([[5,1.5]]))

        2.4 绘制决策边界

        

#5.绘制决策边界
from matplotlib.colors import ListedColormap

def plot_decision_boundary(clf,X,y,axes=[0,7.5,0,3],iris=True,legend=False,plot_training=True):
    #找两个特征 x1 x2
    x1s = np.linspace(axes[0],axes[1],100)
    x2s = np.linspace(axes[2],axes[3],100)
    #构建棋盘
    x1,x2 = np.meshgrid(x1s,x2s)
    #在棋盘中构建待测试数据
    X_new = np.c_[x1.ravel(),x2.ravel()]
    #将预测值算出来
    y_pred = clf.predict(X_new).reshape(x1.shape)
    #选择颜色
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    #绘制并填充不同的区域
    plt.contourf(x1,x2,y_pred,alpha=0.3,cmap=custom_cmap)

    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contourf(x1,x2,y_pred,alpha=0.8,cmap=custom_cmap2)
    
    #可以把训练数据展示出来
    if plot_training:
        plt.plot(X[:,0][y==0],X[:,1][y==0],'yo',label='Iris-Setosa')
        plt.plot(X[:,0][y==1],X[:,1][y==1],'bs',label='Iris-Versicolor')
        plt.plot(X[:,0][y==2],X[:,1][y==2],'g^',label='Iris-Virginica')
    if iris:
        plt.xlabel('Petal length',fontsize = 14)
        plt.ylabel('Petal width',fontsize = 14)

    else:
        plt.xlabel(r'$x_1$',fontsize=18)
        plt.ylabel(r'$x_2$',fontsize=18)
    if legend:
        plt.legend(loc='lower right',fontsize=14)
    
plt.figure(figsize=(8,4))
plot_decision_boundary(tree_clf,X,y)
plt.plot([2.45,2.45],[0,3],'k-',linewidth=2)
plt.plot([2.45,7.5],[1.75,1.75],'k--',linewidth=2)
plt.plot([4.95,4.95],[0,1.75],'k:',linewidth=2)
plt.plot([4.85,4.85],[1.75,3],'k:',linewidth=2)
plt.text(1.40,1.0,'Depth=0',fontsize=15)
plt.text(3.2,1.80,'Depth=1',fontsize=13)
plt.text(4.05,0.5,'(Depth=2)',fontsize=11)
plt.title('Decision Tree decision boundareies')

plt.show()

        
    

        可以看出三种不同颜色的代表分类结果,Depth=0可看作第一刀切分,Depth=1,2 看作第二刀,三刀,把数据集切分。

        2.5 决策树的正则化(剪枝)

        决策树的正则化

        DecisionTreeClassifier类还具有一些其他的参数类似地限制了决策树的形状

        min-samples_split(节点在分割之前必须具有的样本数)

        min-samples_leaf(叶子节点必须具有的最小样本数)

        max-leaf_nodes(叶子节点的最大数量)

        max_features(在每个节点处评估用于拆分的最大特征数)

        max_depth(树的最大深度)

#6.决策树正则化
from sklearn.datasets import make_moons
X,y = make_moons(n_samples=100,noise=0.25,random_state=53)
plt.plot(X[:,0],X[:,1],"b.")
tree_clf1 = DecisionTreeClassifier(random_state=42)
tree_clf2 = DecisionTreeClassifier(random_state=42,min_samples_leaf=4)
tree_clf1.fit(X,y)
tree_clf2.fit(X,y)
plt.figure(figsize=(12,4))
plt.subplot(121)
plot_decision_boundary(tree_clf1,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('no restriction')
plt.subplot(122)
plot_decision_boundary(tree_clf2,X,y,axes=[-1.5,2.5,-1,1.5],iris=False)
plt.title('min_samples_leaf={}'.format(tree_clf2.min_samples_leaf))

        可以看出在没有加限制条件之前,分类器要考虑每个点,模型变得复杂,容易过拟合。其他的一些参数读者可以自行尝试。

        2.6 对数据敏感

        决策树对于数据是很敏感的

        

#6.对数据敏感
np.random.seed(6)
Xs = np.random.rand(100,2) - 0.5
ys = (Xs[:,0] > 0).astype(np.float32) * 2

angle = np.pi /4
rotation_matrix = np.array([[np.cos(angle),-np.sin(angle)],[np.sin(angle),np.cos(angle)]])
Xsr = Xs.dot(rotation_matrix)
 
tree_clf_s = DecisionTreeClassifier(random_state=42)
tree_clf_sr = DecisionTreeClassifier(random_state=42)
tree_clf_s.fit(Xs,ys)
tree_clf_sr.fit(Xsr,ys)

plt.figure(figsize=(11,4))
plt.subplot(121)
plot_decision_boundary(tree_clf_s,Xs,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')

plt.subplot(122)
plot_decision_boundary(tree_clf_sr,Xsr,ys,axes=[-0.7,0.7,-0.7,0.7],iris=False)
plt.title('Sensitivity to training set rotation')

plt.show()

         这里是把数据又旋转了45度,然而决策边界并没有也旋转45度,却是变复杂了。可以看出,对于复杂的数据,决策树是很敏感的。

        2.7 回归任务

#7.回归任务 
np.random.seed(42)
m = 200
X = np.random.rand(m,1)
y = 4 * (X-0.5)**2
y = y + np.random.randn(m,1) /10
plt.plot(X,y,'b.')
from sklearn.tree import DecisionTreeRegressor
tree_reg = DecisionTreeRegressor(max_depth=2)
tree_reg.fit(X,y)
from sklearn.tree import export_graphviz
export_graphviz(
    tree_reg,
    out_file='regression_tree.dot',
    feature_names=['X1'],
    rounded=True,
    filled=True
)
from IPython.display import Image
Image(filename='regression_tree.png',width=400,height=400)

 

         回归任务,这里的衡量标准就变成了均方误差。

        2.8 对比树的深度对结果的影响

#8.对比树的深度对结果的影响
from sklearn.tree import DecisionTreeRegressor
tree_reg1 = DecisionTreeRegressor(random_state=42,max_depth=2)
tree_reg2 = DecisionTreeRegressor(random_state=42,max_depth=3)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)

def plot_regression_predictions(tree_reg,X,y,axes=[0,1,-0.2,1],ylabel='$y$'):
    x1 = np.linspace(axes[0],axes[1],500).reshape(-1,1)
    y_pred = tree_reg.predict(x1)
    plt.axis(axes)
    plt.xlabel('$X_1$',fontsize =18)
    if ylabel:
        plt.ylabel(ylabel,fontsize = 18,rotation=0)
    plt.plot(X,y,'b.')
    plt.plot(x1,y_pred,'r.-',linewidth=2,label=r'$\hat{y}$')


plt.figure(figsize=(11,4))
plt.subplot(121)

plot_regression_predictions(tree_reg1,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    plt.plot([split,split],[-0.2,1],style,linewidth = 2)
plt.text(0.21,0.65,'Depth=0',fontsize= 15)
plt.text(0.01,0.2,'Depth=1',fontsize= 13)
plt.text(0.65,0.8,'Depth=0',fontsize= 13)
plt.legend(loc='upper center',fontsize = 18)
plt.title('max_depth=2',fontsize=14)
plt.subplot(122)
plot_regression_predictions(tree_reg2,X,y)
for split,style in ((0.1973,'k-'),(0.0917,'k--'),(0.7718,'k--')):
    plt.plot([split,split],[-0.2,1],style,linewidth = 2)
for split in (0.0458,0.1298,0.2873,0.9040):
    plt.plot([split,split],[-0.2,1],linewidth = 1)
plt.text(0.3,0.5,'Depth=2',fontsize= 13)
plt.title('max_depth=3',fontsize=14)

plt.show()

        不同的树的深度,对于结果产生极大的影响

        2.9 剪枝

        

#9.加一些限制
tree_reg1 = DecisionTreeRegressor(random_state=42)
tree_reg2 = DecisionTreeRegressor(random_state=42,min_samples_leaf=10)
tree_reg1.fit(X,y)
tree_reg2.fit(X,y)

x1 = np.linspace(0,1,500).reshape(-1,1)
y_pred1 = tree_reg1.predict(x1)
y_pred2 = tree_reg2.predict(x1)

plt.figure(figsize=(11,4))

plt.subplot(121)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred1,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('No restrctions',fontsize =14)

plt.subplot(122)
plt.plot(X,y,'b.')
plt.plot(x1,y_pred2,'r.-',linewidth=2,label=r'$\hat{y}$')
plt.axis([0,1,-0.2,1.1])
plt.xlabel('$x_1$',fontsize=18)
plt.ylabel('$y$',fontsize=18,rotation=0)
plt.legend(loc='upper center',fontsize =18)
plt.title('min_samples_leaf={}'.format(tree_reg2.min_samples_leaf),fontsize =14)

plt.show()

        一目了然。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/424032.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

matplotlib——散点图和条形图(python)

散点图 需求 我们获得北京2016年三月和十月每天白天最高气温&#xff0c;我们现在需要找出气温随时间变化的某种规律。 代码 # 导入库 from matplotlib import pyplot as plt import random# 解决中文乱码 import matplotlib matplotlib.rc("font",family"F…

详细讲解Docker架构的原理、功能以及如何使用

一、简介 1、了解docker的前生LXC LXC为Linux Container的简写。可以提供轻量级的虚拟化&#xff0c;以便隔离进程和资源&#xff0c;而且不需要提供指令解释机制以及全虚拟化的其他复杂性。相当于C中的NameSpace。容器有效地将由单个操作系统管理的资源划分到孤立的组中&…

如何解决线程安全问题(synchronized、原子性、产生线程不安全的原因,锁的特性,加锁的方式等等干货)

文章目录 &#x1f490;线程不安全的示例&#x1f490;锁的特性&#x1f490;产生线程不安全的原因&#xff1a;&#x1f490;加锁的三种方式 &#x1f490;线程不安全的示例 对于线程安全问题&#xff0c;这里用一个例子进行讲解&#x1f447;&#xff1a; 我现在定义一个变…

Image Fusion via Vision-Language Model【文献阅读】

阅读目录 文献阅读AbstractIntroduction3. Method3.1. Problem Overview3.2. Fusion via Vision-Language Model 4. Vision-Language Fusion Datasets5. Experiment5.1Infrared and Visible Image Fusion 6. Conclusion个人总结 文献阅读 原文下载&#xff1a;https://arxiv.or…

串及BF朴素查找算法(学习整理):

关于串的相关定义&#xff1a; 串&#xff1a;用‘ ’表示的字符序列空串&#xff1a;包含零个字符的串子串&#xff1a;包含传本身和空串的子串 eg: abc(,a,b,c,ab,bc,ac,abc)共7个&#xff1a;串的长度的阶乘1&#xff08;空串&#xff09;真子串&#xff1a;不包含自身的所…

Java进阶-IO(3)

话接上回&#xff0c;继续java IO的学习。上一次说完了字符流的读写数据&#xff0c;这次将基础部分剩余的一点内容看完。 一、流按功能分类 1、系统流 1.1 概述 系统流的类为 java.lang.System。Sytem 类封装了 Java 程序运行时的 3 个系统流。 System.in&#xff1a;标…

腾讯云幻兽帕鲁服务器中,如何检查并确保所有必要的配置文件(如PalWorldSettings.ini和WorldOption.sav)正确配置?

腾讯云幻兽帕鲁服务器中&#xff0c;如何检查并确保所有必要的配置文件&#xff08;如PalWorldSettings.ini和WorldOption.sav&#xff09;正确配置&#xff1f; 登录腾讯云控制台&#xff1a;登录轻量云控制台&#xff0c;找到部署了幻兽帕鲁的服务器&#xff0c;单击实例卡片…

基于BP-Adaboost的预测与分类,附MATLAB代码免费获取

今天为大家带来一期基于BP-Adaboost的预测与分类。代码中的BP可以替换为任意的机器学习算法。 原理详解 BP-AdaBoos模型先通过 AdaBoost集成算法串行训练多个基学习器并计算每个基学习 器的权重系数,接着将各个基学习器的预测结果进行线性组合,生成最终的预测结果。关于更多的原…

关于编写测试用例的一些思考

测试用例是QA同学的基本功&#xff0c;每个人都有一套编写测试用例的体系&#xff0c;本文是作者结合自身的工作经验以及阅读一些测试相关的书籍后的一些看法&#xff0c;欢迎大家一起讨论学习。 测试设计 测试用例格式 面试中一些常见的问题 1.APP测试与服务端测试的区别&am…

计算机设计大赛 深度学习火车票识别系统

文章目录 0 前言1 课题意义课题难点&#xff1a; 2 实现方法2.1 图像预处理2.2 字符分割2.3 字符识别部分实现代码 3 实现效果4 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 图像识别 火车票识别系统 该项目较为新颖&#xff0c;适…

StarRocks实战——首汽约车实时数仓实践

目录 前言 一、引入背景 二、OLAP引擎选型 三、架构演进 四、实时数仓构建 五、业务实践价值未来规划 原文大佬的这篇首汽约车实时数仓实践有借鉴意义&#xff0c;这里摘抄下来用作学习和知识沉淀。 前言 首汽约车&#xff08;以下简称“首约”&#xff09;是首汽集团打造…

滑动窗口问题

日升时奋斗&#xff0c;日落时自省 目录 一、长度最小的子数组 二、无重复字符的最长子串 三、最大连续1的个数III 四、将x减到0的最小操作数 五、水果成篮 六、找到字符串中所有字母异位词 七、串联所有单词的⼦串 八、最小覆盖子串 注&#xff1a;滑动窗口其实很类似…

图片按照宽度进行居中裁剪,缩放大小

要求 文件存放在img_folder_path中 裁剪要求&#xff1a; 图片大小以高度为基准。居中裁剪 缩放要求&#xff1a; 图片缩放到512大小 图片另存到save_file_path路径中 代码 import numpy as np import cv2 import os from tqdm import tqdm#原图片存放位置 img_folder_p…

操作系统原理与实验——实验三优先级进程调度

实验指南 运行环境&#xff1a; Dev c 算法思想&#xff1a; 本实验是模拟进程调度中的优先级算法&#xff0c;在先来先服务算法的基础上&#xff0c;只需对就绪队列到达时间进行一次排序。第一个到达的进程首先进入CPU&#xff0c;将其从就绪队列中出队后。若此后队首的进程的…

Spring重点记录

文章目录 1.Spring的组成2.Spring优点3.IOC理论推导4.IOC本质5.IOC实现&#xff1a;xml或者注解或者自动装配&#xff08;零配置&#xff09;。6.hellospring6.1beans.xml的结构为&#xff1a;6.2.Spring容器6.3对象的创建和控制反转 7.IOC创建对象方式7.1以有参构造的方式创建…

【硬件相关】RDMA网络类别及基础介绍

文章目录 一、前言1、RDMA网络协议2、TCP/IP网络协议 二、RDMA类别1、IB2、RoCE3、iWARP 三、RDMA对比1、优缺点说明a、性能b、扩展性c、维护难度 2、总结说明 一、前言 roce-vs-infiniband-vs-tcp-ip RoCE、IB和TCP等网络的基本知识及差异对比 分布式存储常见网络协议有TCP/IP…

【【C语言简单小题学习-1】】

实现九九乘法表 // 输出乘法口诀表 int main() {int i 0;int j 0;for (i 1; i < 9; i){for (j 1; j < i;j)printf("%d*%d%d ", i , j, i*j);printf("\n"); }return 0; }猜数字的游戏设计 #define _CRT_SECURE_NO_WARNINGS 1 #include<stdi…

c语言--qsort函数(详解)

目录 一、定义二、用qsort函数排序整型数据三、用qsort排序结构数据四、qsort函数的模拟实现 一、定义 二、用qsort函数排序整型数据 #include<stdio.h> scanf_S(int *arr,int sz) {for (int i 0; i < sz; i){scanf("%d", &arr[i]);} } int int_cmp(c…

点云数据结构化与体素化理论学习

一、PCD点云数据存储格式的进一步认识 &#xff08;一&#xff09;PCD点云存储格式相较于其它存储格式&#xff08;如PLY、STL、OBJ、X3D等&#xff09;的优势[1] &#xff08;1&#xff09;具有存储和处理有组织的点云数据集的能力&#xff0c;这对于实时应用和增强现实及机器…

GEE入门篇|图像处理(三):阈值处理、掩膜和重新映射图像

阈值处理、掩膜和重新映射图像 本章前一节讨论了如何使用波段运算来操作图像&#xff0c; 这些方法通过组合图像内的波段来创建新的连续值。 本期内容使用逻辑运算符对波段或索引值进行分类&#xff0c;以创建分类图像。 1.实现阈值 实现阈值使用数字&#xff08;阈值&#xf…