Numpy 实现ID3决策树

Numpy 实现ID3决策树

# 定义节点类 二叉树
class Node:
    def __init__(self, root=True, label=None, feature_name=None, feature=None):
        self.root = root
        self.label = label
        self.feature_name = feature_name
        self.feature = feature
        self.tree = {}
        self.result = {
            'label:': self.label,
            'feature': self.feature,
            'tree': self.tree
        }

    def __repr__(self):
        return '{}'.format(self.result)

    def add_node(self, val, node):
        self.tree[val] = node

    def predict(self, features):
        if self.root is True:
            return self.label
        return self.tree[features[self.feature]].predict(features)


class DTree:
    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon
        self._tree = {}

    # 熵
    @staticmethod
    def calc_ent(datasets):
        data_length = len(datasets)
        label_count = {}
        for i in range(data_length):
            label = datasets[i][-1]
            if label not in label_count:
                label_count[label] = 0
            label_count[label] += 1
        ent = -sum([(p / data_length) * log(p / data_length, 2)
                    for p in label_count.values()])
        return ent

    # 经验条件熵
    def cond_ent(self, datasets, axis=0):
        data_length = len(datasets)
        feature_sets = {}
        for i in range(data_length):
            feature = datasets[i][axis]
            if feature not in feature_sets:
                feature_sets[feature] = []
            feature_sets[feature].append(datasets[i])
        cond_ent = sum([(len(p) / data_length) * self.calc_ent(p)
                        for p in feature_sets.values()])
        return cond_ent

    # 信息增益
    @staticmethod
    def info_gain(ent, cond_ent):
        return ent - cond_ent

    def info_gain_train(self, datasets):
        count = len(datasets[0]) - 1
        ent = self.calc_ent(datasets)
        best_feature = []
        for c in range(count):
            c_info_gain = self.info_gain(ent, self.cond_ent(datasets, axis=c))
            best_feature.append((c, c_info_gain))
        # 比较大小
        best_ = max(best_feature, key=lambda x: x[-1])
        return best_

    def train(self, train_data):
        """
        input:数据集D(DataFrame格式),特征集A,阈值eta
        output:决策树T
        """
        _, y_train, features = train_data.iloc[:, :
                                               -1], train_data.iloc[:,
                                                                    -1], train_data.columns[:
                                                                                            -1]
        # 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T
        if len(y_train.value_counts()) == 1:
            return Node(root=True, label=y_train.iloc[0])

        # 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T
        if len(features) == 0:
            return Node(
                root=True,
                label=y_train.value_counts().sort_values(
                    ascending=False).index[0])

        # 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征
        max_feature, max_info_gain = self.info_gain_train(np.array(train_data))
        max_feature_name = features[max_feature]

        # 4,Ag的信息增益小于阈值eta,则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T
        if max_info_gain < self.epsilon:
            return Node(
                root=True,
                label=y_train.value_counts().sort_values(
                    ascending=False).index[0])

        # 5,构建Ag子集
        node_tree = Node(
            root=False, feature_name=max_feature_name, feature=max_feature)

        feature_list = train_data[max_feature_name].value_counts().index
        for f in feature_list:
            sub_train_df = train_data.loc[train_data[max_feature_name] ==
                                          f].drop([max_feature_name], axis=1)

            # 6, 递归生成树
            sub_tree = self.train(sub_train_df)
            node_tree.add_node(f, sub_tree)

        # pprint.pprint(node_tree.tree)
        return node_tree

    def fit(self, train_data):
        self._tree = self.train(train_data)
        return self._tree

    def predict(self, X_test):
        return self._tree.predict(X_test)

验证

import numpy as np
import pandas as pd
import math
from math import log
import pandas as pd


def create_data():
    datasets = [['青年', '否', '否', '一般', '否'],
               ['青年', '否', '否', '好', '否'],
               ['青年', '是', '否', '好', '是'],
               ['青年', '是', '是', '一般', '是'],
               ['青年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '一般', '否'],
               ['中年', '否', '否', '好', '否'],
               ['中年', '是', '是', '好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['中年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '非常好', '是'],
               ['老年', '否', '是', '好', '是'],
               ['老年', '是', '否', '好', '是'],
               ['老年', '是', '否', '非常好', '是'],
               ['老年', '否', '否', '一般', '否'],
               ]
    labels = [u'年龄', u'有工作', u'有自己的房子', u'信贷情况', u'类别']
    # 返回数据集和每个维度的名称
    return datasets, labels


datasets, labels = create_data()

train_data = pd.DataFrame(datasets, columns=labels)

datasets, labels = create_data()
data_df = pd.DataFrame(datasets, columns=labels)
dt = DTree()
tree = dt.fit(data_df)

tree

在这里插入图片描述

dt.predict(['老年', '否', '否', '一般'])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/225450.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS学习--TypeScript语言学习(一)

注意&#xff1a;这只是我学习的笔记&#xff01;&#xff01;&#xff01; 注意&#xff1a;这只是我学习的笔记&#xff01;&#xff01;&#xff01; 注意&#xff1a;这只是我学习的笔记&#xff01;&#xff01;&#xff01; 本章目录如下&#xff1a; 一、TypeScript语言…

汽车防爆膜行业研究:中国发展前景及市场投资分析

随着汽车保有量的不断增长&#xff0c;汽车的维修和保养等服务市场规模也会快速提升。业内人士表示&#xff0c;今年以来&#xff0c;越来越多的企业开始发力这一市场&#xff0c;汽车后市场的竞争区域也从大中城市向县域城市下沉。 防爆膜就是在车的玻璃上安装一层保护膜&…

各大期刊网址

1.NeurIPS&#xff0c;全称Annual Conference on Neural Information Processing Systems&#xff0c; 是机器学习领域的顶级会议&#xff0c;与ICML&#xff0c;ICLR并称为机器学习领域难度最大&#xff0c;水平最高&#xff0c;影响力最强的会议&#xff01; NeurIPS是CCF 推…

Android 背景边框集合

效果图 代码 <?xml version"1.0" encoding"utf-8"?> <shape xmlns:android"http://schemas.android.com/apk/res/android" android:shape"rectangle"><solid android:color"#ffffff" /><stroke and…

Doris 编译报错 Error: flex version (2.5.37) must be greater than or equal to 2.6.0

Doris 编译过程报错 Error: flex version (2.5.37) must be greater than or equal to 2.6.0yum update flex 不生效 下载flex 安装包 https://github.com/westes/flex/releases解压 tar -xvf flex-2.6.4.tar

java优雅写法01 -业务判断并断言-Assert工具类

断言的理解: 断言他满足条件不抛异常; 不满足断言时产生异常&#xff1b; 业务需求&#xff1a;判断某个查询对象为空&#xff0c;直接返回错误信息 常见写法&#xff1a; if (ObjectUtil.isNull(object)){ return XXXX }Assert 写法 这句代码的意思是 如果对象为空就提示并返…

你了解架构图吗?

技术界的扛把子架构图&#xff0c;你了解吗&#xff1f;今天我们就来聊聊它 什么是架构图呢&#xff1f; 系统架构是概念的体现&#xff0c;是对物/信息的功能与形式元素之间的对应情况所做的分配&#xff0c;是对元素之间的关系以及元素同周边环境之间的关系所做的定义。 架构…

微信小程序访问不了阿里云oss图片链接解决办法

以下都有可能导致访问不了oss图片 1.小程序没有加访问白名单 这个需要前端搞,加上白名单,如果是域名加域名白名单,ip的话加ip白名单 2.阿里云设置域名白名单 打开bucket列表,选择对应的bucket 配置这个白名单,配置好以后,开发者工具可以预览了,手机端预览不了,查看自己的路…

[AbutionGraph开发文档]时序图谱数据库-流式图计算

文档地址&#xff1a;https://thutmose.gitee.io/abution-graph AbutionGraph是一款端到端数据实时分析的图谱数据库&#xff0c;实时(写入实时、决策分析实时、流式图计算实时)&#xff1a; 基于历史数据构建的指标模型实时查询&#xff1b;接入流式数据并实时更新业务指标&a…

[足式机器人]Part2 Dr. CAN学习笔记-数学基础Ch0-6复数Complex Number

本文仅供学习使用 本文参考&#xff1a; B站&#xff1a;DR_CAN Dr. CAN学习笔记-数学基础Ch0-6复数Complex Number x 2 − 2 x 2 0 ⇒ x 1 i x^2-2x20\Rightarrow x1\pm i x2−2x20⇒x1i 代数表达&#xff1a; z a b i , R e ( z ) a , I m ( z ) b zabi,\mathrm{Re}…

【ARM Coresight 系列 2 文章 -- Trace32 对 APBIC 地址的配置 介绍】

文章目录 APBIC RomtableAPBIC Romtable 图 1 APBIC 网络图 如上图所示,如果想通过Trace32/DS-5 去访问 AP, 这个时候需要怎么做呢?可以看到 APBIC 中ROMTABLE 中 APB-AP 的偏移是0x00200000,所以 APB-AP的基地址为 0x2b000000 + 0x00200000,又从 APB-AP 的romtable 中可以…

基于ssm在线云音乐系统的设计与实现论文

摘 要 随着移动互联网时代的发展&#xff0c;网络的使用越来越普及&#xff0c;用户在获取和存储信息方面也会有激动人心的时刻。音乐也将慢慢融入人们的生活中。影响和改变我们的生活。随着当今各种流行音乐的流行&#xff0c;人们在日常生活中经常会用到的就是在线云音乐系统…

item_get_app_pro-根据ID取商品详情原数据接入参数和返回值说明

参数说明 pinduoduo.item_get_app_pro 公共参数 名称类型必须描述keyString是调用key&#xff08;申请调用免费测试&#xff09;secretString是调用密钥api_nameString是API接口名称&#xff08;包括在请求地址中&#xff09;[item_search,item_get,item_search_shop等]cache…

python中dir函数和help函数的用法和区别

背景 在python的学习和使用中&#xff0c;常常会涉及到很多第三方库的安装和使用&#xff0c;有时候我们需要借助某些第三方库的文档帮助我们更好更快速的学习第三方库中的函数或者类的使用方法&#xff0c;除了直接阅读说明文档外&#xff0c;借助于python的dir()函数和help()…

Pandas教程09:DataFrame数据可视化绘制折线图、柱状图、散点图、直方图等

pandas.plot() 是 pandas 库中的一个非常方便的函数&#xff0c;用于绘制各种图形&#xff0c;例如线图、柱状图、散点图等。以下是一些示例用法&#xff1a; 1.绘制一个简单的线图&#xff1a; # Author : 小红牛 # 微信公众号&#xff1a;wdPython import pandas as pd impo…

安防芯片步进驱动选型GC系列,低成本,大电流

GLOBALCHIP新一代的安防产品用的芯片&#xff0c;GC3901 GC3909 GC3910 GC3911 GC8548 GC8549 GC6609 GC6610 GC2003 GC2803 GC2804步进驱动 达林顿选型型号&#xff0c;采用12V H 桥驱动器。其中最大持续电流(A)可达2.5A。其中GC6610&#xff0c;GC6609​​​​​​​ 是36V&…

销售技巧培训之如何提升销售沟通技巧

销售技巧培训之如何提升销售沟通技巧 现在市场环境竞争越来越激烈&#xff0c;产品越来越过剩&#xff0c;如何把产品卖出去是摆在企业面前的难题。所以打造一致所向披靡的销售团队&#xff0c;提升销售人员的系统化销售能力就显得非常重要。在销售系统培训模块中&#xff0c;…

聚观早报 |JFrog发布新功能;中科百孚减持龙芯中科股票

【聚观365】12月7日消息 JFrog发布新功能 中科百孚减持龙芯中科股票 商汤集团再回应做空报告 xAI融资新进展 苹果市值再次突破 JFrog发布新功能 流式软件公司、企业软件供应链平台提供商JFrog发布新功能&#xff0c;推出业界首款致力于加速安全软件建构与发布的端到端平台…

Kubernetes(K8s 1.27.x) 快速上手+实践,无废话纯享版(视频笔记)

视频源&#xff1a;1.03-k8s是什么&#xff1f;_哔哩哔哩_bilibili 1 基础知识 1.1 K8s 有用么&#xff1f; K8s有没有用 K8s要不要学&#xff1f; 参考资料: https://www.infoq.com/articles/devops-and-cloud-trends-2022/?itm_sourcearticles_about_InfoQ-trends-report…

thinkphp 多表连接 子查询 group by 分组以最新的一条数据为组

用这个 $subQuery Db::name(wms_orderitems)->distinct(true)->field(kw_id,orders_id,product_id)->order(items_id desc)->buildSql();$list Db::name(wms_orders)->alias(order)->join($subQuery. item,item.orders_idorder.orders_id)->field(order…