机器学习随记(5)—决策树

手搓决策树:用决策树将其应用于分类蘑菇是可食用还是有毒的任务

温馨提示:下面为不完全代码,只是每个步骤代码的实现,需要完整跑通代码的同学不建议花时间看;适合了解决策树各个流程及代码实现的同学复习使用。


1 数据

1.1 one-hot编码数据集

1.2数据集:

X_train = np.array([[1,1,1],[1,0,1],[1,0,0],[1,0,0],[1,1,1],[0,1,1],[0,0,0],[1,0,1],[0,1,0],[1,0,0]])
y_train = np.array([1,1,0,0,1,0,0,1,1,0])
The shape of X_train is: (10, 3)
The shape of y_train is:  (10,)

10条数据,3个特征xi+1个目标y

2 计算熵

  • 计算𝑝1,这是可食用示例的一部分(即具有 value = in y
  • 然后计算熵:

 代码:

def compute_entropy(y):
    """
    Computes the entropy for 
    
    Args:
       y (ndarray): Numpy array indicating whether each example at a node is
           edible (`1`) or poisonous (`0`)
       
    Returns:
        entropy (float): Entropy at that node
        
    """
    # You need to return the following variables correctly
    entropy = 0.
    
    ### START CODE HERE ###
    if len(y) != 0:
        p1 = np.count_nonzero(y == 1)/len(y)
        if p1 != 0 and p1 != 1:
            entropy = -p1*np.log2(p1) - (1-p1)*np.log2(1-p1)
    ### END CODE HERE ###        
    
    return entropy

3 拆分数据集(分裂)

  • 该函数接收训练数据、该节点的数据点索引列表以及要拆分的特征。
  • 它拆分数据并返回左右分支的索引子集。
  • 例如,假设我们从根节点 (so node_indices = [0,1,2,3,4,5,6,7,8,9]) 开始,我们选择在特征上进行拆分0,即示例是否有棕色帽。
    • left_indices = [0,1,2,3,4,7,9]然后函数的输出是right_indices = [5,6,8]

split_dataset()下图所示的功能

  • 对于中的每个索引node_indices
    • X如果该特征在该索引处的值为1,则将该索引添加到left_indices
    • X如果该特征在该索引处的值为0,则将该索引添加到right_indices
def split_dataset(X, node_indices, feature):
    """
    Splits the data at the given node into
    left and right branches
    
    Args:
        X (ndarray):             Data matrix of shape(n_samples, n_features)
        node_indices (ndarray):  List containing the active indices. I.e, the samples being considered at this step.
        feature (int):           Index of feature to split on
    
    Returns:
        left_indices (ndarray): Indices with feature value == 1
        right_indices (ndarray): Indices with feature value == 0
    """
    
    # You need to return the following variables correctly
    left_indices = []
    right_indices = []
    
    ### START CODE HERE ###
    X_f = X[:,feature]
    for i in node_indices:
        if X_f[i] == 1:
            left_indices.append(i)
        elif X_f[i] == 0:
            right_indices.append(i)
    ### END CODE HERE ###
    return left_indices, right_indices

 4 计算信息增益

  • 𝐻(𝑝node1)是节点处的熵
  • 𝐻(𝑝left1) 和𝐻(𝑝right1)是由分裂产生的左分支和右分支的熵
  • 𝑤分别是左右分支的示例比例
def compute_information_gain(X, y, node_indices, feature):
    
    """
    Compute the information of splitting the node on a given feature
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.
   
    Returns:
        cost (float):        Cost computed
    
    """    
    # Split dataset
    left_indices, right_indices = split_dataset(X, node_indices, feature)
    
    # Some useful variables
    X_node, y_node = X[node_indices], y[node_indices]
    X_left, y_left = X[left_indices], y[left_indices]
    X_right, y_right = X[right_indices], y[right_indices]
    
    # You need to return the following variables correctly
    information_gain = 0
    
    ### START CODE HERE ###
    
    # Weights 
    wl = len(X_left)/len(X_node)
    wr = len(X_right)/len(X_node)
    #Weighted entropy
    Hn = compute_entropy(y_node)
    Hl = compute_entropy(y_left)
    Hr = compute_entropy(y_right)
    #Information gain                                                   
    information_gain = Hn-(wl*Hl+wr*Hr)
    ### END CODE HERE ###  
    
    return information_gain

5 获得最佳划分(分裂)

get_best_split()如下所示的功能。

  • 该函数接收训练数据以及该节点的数据点索引
  • 函数的输出给出最大信息增益的特征
    • 您可以使用该compute_information_gain()函数迭代特征并计算每个特征的信息
def get_best_split(X, y, node_indices):   
    """
    Returns the optimal feature and threshold value
    to split the node data 
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.

    Returns:
        best_feature (int):     The index of the best feature to split
    """    
    
    # Some useful variables
    num_features = X.shape[1]
    
    # You need to return the following variables correctly
    best_feature = -1
    gain_max = 0
    ### START CODE HERE ###
    for i in range(num_features):
        gain_ = compute_information_gain(X, y, node_indices, i)
        if gain_ > gain_max:
            gain_max = gain_
            best_feature = i
    ### END CODE HERE ##    
    return best_feature

6 构建树

在上面实现的函数来生成决策树,方法是连续选择最佳特征进行拆分,直到达到停止条件(最大深度为 2)。

tree = []

def build_tree_recursive(X, y, node_indices, branch_name, max_depth, current_depth):
    """
    Build a tree using the recursive algorithm that split the dataset into 2 subgroups at each node.
    This function just prints the tree.
    
    Args:
        X (ndarray):            Data matrix of shape(n_samples, n_features)
        y (array like):         list or ndarray with n_samples containing the target variable
        node_indices (ndarray): List containing the active indices. I.e, the samples being considered in this step.
        branch_name (string):   Name of the branch. ['Root', 'Left', 'Right']
        max_depth (int):        Max depth of the resulting tree. 
        current_depth (int):    Current depth. Parameter used during recursive call.
   
    """ 

    # Maximum depth reached - stop splitting
    if current_depth == max_depth:
        formatting = " "*current_depth + "-"*current_depth
        print(formatting, "%s leaf node with indices" % branch_name, node_indices)
        return
   
    # Otherwise, get best split and split the data
    # Get the best feature and threshold at this node
    best_feature = get_best_split(X, y, node_indices) 
    tree.append((current_depth, branch_name, best_feature, node_indices))
    
    formatting = "-"*current_depth
    print("%s Depth %d, %s: Split on feature: %d" % (formatting, current_depth, branch_name, best_feature))
    
    # Split the dataset at the best feature
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    
    # continue splitting the left and the right child. Increment current depth
    build_tree_recursive(X, y, left_indices, "Left", max_depth, current_depth+1)
    build_tree_recursive(X, y, right_indices, "Right", max_depth, current_depth+1)
build_tree_recursive(X_train, y_train, root_indices, "Root", max_depth=2, current_depth=0)

(本示例问题来源Andrew NG 机器学习公开课)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/17546.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Redis7】Redis7 持久化(重点:RDB与AOF重写机制)

【大家好,我是爱干饭的猿,本文重点介绍Redis7 持久化(重点:RDB与AOF重写机制)。 后续会继续分享Redis7和其他重要知识点总结,如果喜欢这篇文章,点个赞👍,关注一下吧】 …

每天看一个fortran文件(7)之寻找cesm边界层高度计算代码

自从我把我的代码移到了单独的F90 中,从tphysac.F90中调用后。我发现有很多的变量我没办法调用了,特别是边界层中原先已经算好的变量,比如说我想调用原来模式中的pblh,但是那是在vertical diffusion.F90中计算的,在tphysac中根本调…

[python] 协程学习从0到1,配合案例,彻底理解协程,耗费资源不增加,效果接近多线程

文章目录 前言1.python 生成器1.1 python 生成器概述1.2 关键字yield/yield from1.3 next/send函数1.4 StopInteration异常1.5 利用生成器实现生产者-消费者模型1.6 生成器和协程的关系 2.生成器协程调度器3.python事件驱动编程4.实现协程调度器5.python 协程生态 前言 多进程…

ShardingCore安装笔记

由于本人采用Visual Studio的nuget管理器安装ShardingCore经常出现网络错误的问题,所以采用离线包的方式安装插件。 nueget包下载地址:NuGet Gallery | ShardingCore 7.7.1.8 ShardingCore使用版本7.7.1.7 1、下载各种依赖文件,并存放到系…

JAVA11新特性

JAVA11新特性 概述 2018年9月26日,Oracle官方发布JAVA11.这是JAVA大版本周期变化后的第一个长期支持版本,非常值得关注.最新发布的JAVA11将带来ZGC HttpClient等重要特性,一共17个需要我们关注的JEP,参考文档http://openjdk.java.net/projects/jdk/11/ 181:基于嵌套的访问控制…

HashMap 底层原理///HashMap详解

HashMap超详细探讨 总述从集合出发探讨HashMapCollection接口继承树Map接口继承树 从Map接口来,一步一步深入其中Map接口概述Map接口:常用方法Map接口 Map实现类之一:HashMapHashMap的存储结构HashMap源码中的重要常量JDK 1.8之前JDK 1.8总结…

《花雕学AI》28:革命性的 ChatGPT for SEO——让您的排名飙升 50%!

引言: 如果您想写篇有吸引力的文章,或者您是一个博客和网站的拥有者,那么您一定知道 SEO(搜索引擎优化)的重要性。SEO 可以帮助您提高相应的流量、转化率和收入,但是 SEO 也是一个复杂和耗时的过程&#x…

S3C6410 中的 cascaded irqdomain 之 gpio

文章目录 VIC 中断 与 gpio 中断 的硬件拓扑图描述linux cascaded irq domainirq domain 初始化时获取 IRQ number(软件中断号) 时中断发生时如何调试linux irq domain 实例 VIC domain 与 gpio domain 的硬件拓扑语言描述VIC 与 INT_EINTx 的关系INT_EINTx 与 GPIO的关系INT_E…

数电中需要注意的问题

逻辑函数表达式之间的相互转换 (更多请详见PPT)若题目要求用: 与非门实现逻辑函数,则画卡诺图圈出值为1的然后化简 或非门实现逻辑函数,则画卡诺图圈出值为0的然后化简 与或非门实现逻辑函数,则画卡诺图圈…

c++ 构造函数与析构函数

c构造函数: (1)构造函数名必须与类名相同 (2)无返回值 (3)访问权限符一般设置为public (4)无自定义构造函数,编译器提供默认构造函数,构造函数只调用一次 c析构函数: (1)析构函数与类名相同,前面加~符号 (2)析构…

纯比例控制为什么会存在稳态误差,用纯增益系统举例

warning: 本文仅为个人思考,非常不严谨甚至可能会出现严重错误,请读者仔细甄别,若本文真的存在严重错误,恳请评论区纠正,我看到将会考虑修改或者删除文章 纯比例控制存在稳态误差是由其本质(控制逻辑&#…

精炼计算机网络——物理层(一)

文章目录 前言2.1物理层的基本概念2.2 数据通信的基础知识2.2.1 数据通信系统的模型2.2.3 信道的极限容量 2.3 物理层下面的传输媒体2.3.1 导引型传输媒体2.3.2 非导引型传输媒体 总结 前言 经历了两篇文章的学习,相信读者们一定对计算机网络有了一个基础的了解。接…

一个新的ubuntu

1 安装cmake 方法一:现成的教程 ubuntu安装cmake_yuanzhoulvpi的博客-CSDN博客 方法二:自己总结的 安装openssl系统:sudo apt-get install libssl-dev 安装:sudo apt-get install build-essential 以上是防止安装cmake时缺少文…

PS磨皮插件portraiture最新版磨皮工具

Portraiture是一款智能磨皮插件,为Photoshop和Lightroom添加一键磨皮美化功能,快速对照片中皮肤、头发、眉毛等部位进行美化,无需手动调整,大大提高P图效率。全新4版本,升级AI算法,并独家支持多人及全身模式…

I2C工作流程

FM33A0XX的I2C接口只用作主机,且不支持多主机,因此挂在总线上的其他设备都是从机。总线上总是由主机提供同步时钟SCL,SDA数据流方向可以是主机发送从机接收,或者从机发送主机接收。 数据发送流程 1、主机发起 START 时序 2、主机…

C++之基础总结

目录 POD类型左值和右值静态全局变量(static)类型转换const/constexprconstconstexpr C中的关键字union基础知识点编译与函数参数入栈总结一些常见用法归纳: POD类型 平凡的和标准布局的——貌似和深度探索C对象模型中关于按位拷贝冲突 平凡的定义:符合…

Camtasia2023最好用的电脑屏幕录制软件

Camtasia2023是市场上最好的录像机和屏幕录制软件之一。强大的软件视频编辑程序的Camtasia 适用于Windows和iOS。 它支持多种流行的媒体格式,并对您创建的视频提供令人印象深刻的控制范围。3000多万专业人士在全球范围内使用Camtasia展示产品,教授课程&a…

文字的显示

文字的显示 文章目录 文字的显示1.文字编码方式2.英文和汉字的点阵显示3.显示中文“中”和“A”show_font.c结果 1.文字编码方式 数字>代表什么->显示为什么 GBK国标拓展 下列代码用不同编码方式保存utf-8.c ansi.c #include <stdio.h>int main(int argc ,char *…

网络编程之 Socket 套接字(使用数据报套接字和流套接字分别实现一个小程序(附源码))

文章目录 1. 什么是网络编程2. 网络编程中的基本概念1&#xff09;发送端和接收端2&#xff09;请求和响应3&#xff09;客户端和服务端4&#xff09;常见的客户端服务端模型 3. Socket 套接字1&#xff09;Socket 的分类2&#xff09;Java 数据报套接字通信模型3&#xff09;J…

基于Open3D的点云处理2-Open3D的IO与数据转换

三维数据类型 点云 某个坐标系下的点数据集&#xff0c;每个点包括三维坐标X&#xff0c;Y&#xff0c;Z、颜色、分类值、强度值、时间等信息&#xff1b; 储存格式&#xff1a;pts、LAS、PCD、xyz、asc、ply等&#xff1b;Mesh 多边形网格&#xff0c;常见的是三角网格&#…