一、决策树模型

决策树(Decision Tree)是一种常见的机器学习算法,而其核心便是“分而治之”的划分策略。例如,以西瓜为例,需要判断一个西瓜是否是一个好瓜,那么就可以根据经验考虑“西瓜是什么颜色?”,如果是“青绿色”,那么接着考虑“它的根蒂是什么形态?”,如果是“蜷缩”,那么再接着考虑“敲打声音如何?”,如果是“浊响”,那么就可以判定这个西瓜可能是一个“好瓜”。

上图便是一个决策树模型基本决策流程。由上图可见,决策树模型通常包含一个根结点,若干中间节点和若干叶结点。其中,每个叶结点保存了当前特征分支\(x_i\)的标记\(y_i\),而其余的每个结点则保存了当前结点划分的特征和划分的依据(特征值)。
下面是决策树模型训练的伪代码:

生成决策树(D,F):# D为数据集,F为特征向量
生成一个结点node
1)如果无法继续划分,标记为D中样本数最多的类:
a. D中标记相同,无需划分
b. F为空,无法划分
c. D在F上取值相同,无需划分
2)如果可以继续划分:
选择最优划分属性f;
遍历D在f上的每一种取值f‘:
为node生成一个新的分支;
D‘为D在f上取值为f’的子集;
(1)如果D‘为空,将分支节点标记为D中样本最多的类;
(2)如果D’不为空,生成决策树(D',F \ {f})

二、选择划分

由前面的伪代码不难看出,生成决策树的关键在于选择最优的划分属性。为了衡量各个特征划分的优劣,这里引入了纯度purity)的概念。即,希望随着划分的进行,决策树的分支节点所包含的样本尽可能属于同一类别。现在已经有了几种用来度量样本集纯度的指标:

2.1 信息熵和信息增益

信息熵(information entropy)是度量样本集纯度的常用指标,其定义如下:

\[Ent(D)=-\sum^{|y|}_{k=1}{p_k\log_2p_k} \]

其中,\(p_k\)\(D\)中第\(k\)类样本所占比例,而\(Ent(D)\)越小,则说明样本集的纯度越高。
以一个选定的特征将样本划分为\(v\)个子集,则可以分别计算出这\(v\)个子集的信息熵\(Ent(D^i)\),赋予每个子集权重\(\frac{|D^i|}{|D|}\)求出各个子集信息熵的加权平均值,便可以得到划分后的平均信息熵。定义信息增益information gain):

\[Gain(D, f)=Ent(D)-\sum^{v}_{i=1}{\frac{|D^i|}{|D|}Ent(D^i)} \]

那么,信息增益越大,说明样本集以特征\(f\)进行划分获得的纯度提升越大(划分后子节点的平均信息熵越小)。著名的ID3决策树便是采用信息增益作为标准来选择划分属性。

2.2 增益率

通过对信息熵公式进行分析不难发现,信息增益对于可取值数目较多的特征有所偏好。因为如果D在f上可取值较多,那么划分后各个子集的纯度往往较高,信息熵也会较小,最后产生大量的分支,严重影响决策树的泛化性能。所以C4.5决策树在信息增益的基础上引入了增益率gain rate)的概念:

\[Gain\_ratio(D,f)=\frac{Gain(D,f)}{IV(f)} \]
\[IV(f)=-\sum^{v}_{i=1}{\frac{|D^i|}{|D|}\log_2\frac{|D^i|}{|D|}} \]

其中\(IV\)称为属性f的固有值intrinsic value),属性f的可取值数目越多,\(IV\)越大。不难看出,增益率对于属性f可取值较少的属性有所偏好,所以通常的策略是:先选择信息增益高于平均水平的属性,然后从中选择增益率最小的属性进行划分。

2.3 基尼指数

CART决策树(分类回归树)采用基尼指数(Gini index)来选择划分属性。样本集\(D\)的纯度可用基尼值来度量,其定义如下:

\[Gini(D)=\sum^{|y|}_{k=1}\sum_{k'≠k}{p_kp_{k'}}=\sum^{|y|}_{k=1}p_k\sum_{k'≠k}p_{k'}=\sum^{|y|}_{k=1}{p_k(1-p_k)}=1-\sum^{|y|}_{k=1}p_k^2 \]

不难看出,\(Gini(D)\)反映了\(D\)中任意抽取两个样本,其类标不一致的概率。因此,基尼值越小,样本集纯度越高。那么属性f的基尼指数则定义为:

\[Gini\_index(D,f)=\sum^{v}_{i=1}{\frac{|D^i|}{|D|}Gini(D^i)} \]

因此,以基尼指数为评价标准的最优划分属性为\(f^*=\arg\min_{f\in F}Gini\_index(D, f)\)

三、剪枝

在训练决策树模型的时候,有时决策树会将训练集的一些特有性质当作一般性质进行了学习,从而产生过多的分支,不仅效率下降还可能导致过拟合over fitting)从而降低泛化性能。剪枝pruning)就是通过主动去掉决策树的一些分支从而防止过拟合的一种手段。

3.1 预剪枝

预剪枝prepruning)是指在生成决策树的过程中,对每个结点划分前进行模拟,如果划分后不能带来决策树泛化性能的提升,则停止划分并将当前结点标记为叶结点。

3.2 后剪枝

后剪枝post-pruning)则是指在生成一棵决策树后,自下而上地对非叶结点进行考察,如果将该结点对应的子树替换为叶结点能带来泛化性能的提升,则进行替换。

3.3 剪枝示例

那么如何判断模型的泛化性能是否提升?举个栗子来说明(周志华《机器学习》P80-P83):假设生成的决策树如下,

预留的验证集如下:

以预剪枝为例,首先看根结点。假设不进行划分,将结点1标记为\(D\)中样本最多的类别“是”,用验证集进行评估,易得准确率为\(\frac{3}{7}=42.9\%\)。如果进行划分,则结点2、3、4将分别被标记为“是”、“是”、“否”,用验证集评估则易得准确率为\(\frac{5}{7}=71.4\%>42.9\%\),所以继续划分。再看结点2,划分前验证集准确率为\(71.4\%\),划分后却下降到了\(57.1\%\),所以不进行划分。以此为例,最终可以将前面的决策树剪枝为如下形式:

而后剪枝则只是在生成决策树后,从下往上开始判断泛化性能,这里不再赘述(详情可见周志华《机器学习》P82-P83)。后剪枝后决策树形式如下:

3.4 预剪枝和后剪枝对比
项目 预剪枝 后剪枝 不剪枝
时间 生成决策树时 生成决策树后
方向 自上而下 自下而上
效率
拟合度 欠拟合风险 拟合较好 过拟合风险

四、Python实现

4.1 基尼值和基尼指数

基尼值:

def _gini(self, y): # 基尼值
    y_ps = []
    y_unque = np.unique(y)
    for y_u in y_unque:
        y_ps.append(np.sum(y == y_u) / len(y))
    return 1 - sum(np.array(y_ps) ** 2) 

基尼指数:

def _gini_index(self, X, y, feature): # 特征feature的基尼指数
    X_y = np.hstack([X, y.reshape(-1, 1)])
    unique_feature = np.unique(X_y[:, feature])
    gini_index = []
    for uf in unique_feature:
        sub_y = X_y[X_y[:, feature] == uf][:, X_y.shape[1] - 1]
        gini_index.append(len(sub_y) / len(y) * self._gini(sub_y))
    return sum(gini_index), feature 
4.2 选择划分特征

划分特征的选择依赖于前面的基尼指数函数:

def _best_feature(self, X, y, features): # 选择基尼指数最低的特征
    return min([self._gini_index(X, y, feature) for feature in features], key=lambda x:x[0])[1] 
4.3 后剪枝算法
def _post_pruning(self, X, y):
    nodes_mid = [] # 栈,存储所有中间结点
    nodes = [self.root] # 队列,用于辅助广度优先遍历
    while nodes: # 通过广度优先遍历找到所有中间结点
        node = nodes.pop(0)
        if node.sub_node:
            nodes_mid.append(node)
            for sub in node.sub_node:
            nodes.append(sub)
    while nodes_mid: # 开始剪枝
    		node = nodes_mid.pop(len(nodes_mid) - 1)
        y_pred = self.predict(X)
        from sklearn.metrics import accuracy_score
        score = accuracy_score(y, y_pred)
        temp = node.sub_node
        node.sub_node = None
        if accuracy_score(y, self.predict(X)) <= score: node.sub_node = temp 
4.4 训练算法

首先需要将数据集划分为训练集和验证集,训练集用于训练决策树,验证集用于后剪枝。训练算法按照伪代码编写即可。

def fit(self, X, y):
    # 将数据集划分为训练集和验证集
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, train_size=0.7, test_size=0.3)
    queue = [[self.root, list(range(X_train.shape[0])), list(range(X_train.shape[1]))]]
    while queue: # 广度优先生成树
        node, indexs, features = queue.pop(0)
        node.y = ss.mode(y_train[indexs])[0][0] # 这里给每一个结点都添加了类标是为了防止测试集出现训练集中没有的特征值
        # 如果样本全部属于同一类别
        unique_y = np.unique(y_train[indexs])
        if len(unique_y) == 1:
            continue
        # 如果无法继续进行划分
        if len(features) < 2: if len(features) == 0 or len(np.unique(X_train[indexs, features[0]])) == 1: continue # 选择最优划分特征 feature = self._best_feature(X_train[indexs], y_train[indexs], features) node.feature = feature features.remove(feature) # 生成子节点 for uf in np.unique(X_train[indexs, feature]): sub_node = Node(value=uf) node.append(sub_node) new_indexs = [] for index in indexs: if X_train[index, feature] == uf: new_indexs.append(index) queue.append([sub_node, new_indexs, features]) self._post_pruning(X_valid, y_valid) return self 
4.6 导入鸢尾花数据集测试

导入鸢尾花数据集测试:

if __name__ == "__main__":
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import KBinsDiscretizer
    from sklearn.metrics import classification_report

    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    X = KBinsDiscretizer(encode="ordinal").fit_transform(X) # 离散化
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, test_size=0.3)
    classifier = DecisionTreeClassifier().fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    print(classification_report(y_test, y_pred)) 

分类报告:

剪枝决策树原理与Python实现的更多相关文章

  1. Python3利用scapy局域网实现自动多线程arp扫描功能

    一、所需Python库from scapy.all import *import threading二、实现ip扫描1.获取c段ip地址在ARP()里面有ip地址,我们可以从里面提取出前3段出来ARP().show()然后通过从后查找最后一个.得到最后一段位数,然后总长度-最后一段长度就能取出前3段......

  2. pandas读取excel,txt,csv,pkl文件等命令的操作

    pandas读取txt文件读取txt文件需要确定txt文件是否符合基本的格式,也就是是否存在\t,,,等特殊的分隔符一般txt文件长成这个样子txt文件举例下面的文件为空格间隔1 2019-03-22 00:06:24.4463094 中文测试 2 2019-03-22 00:06:32.45656......

  3. Python数据可视化分析--豆瓣电影Top250

    Python数据分析–豆瓣电影Top250利用Python爬取豆瓣电影TOP250并进行数据分析,对于众多爬虫爱好者,应该并不陌生。很多人都会以此作为第一个练手的小项目。当然这也多亏了豆瓣的包容,没有加以太多的反爬措施,对新手比较友好。手动声明版权声明:本文为博主原创文章,创作不易本文链接:http......

  4. 利用python为PostgreSQL的表自动添加分区

    PostgreSQL引进“分区”表特性,解放了之前采用“表继承”+“触发器”来实现分区表的繁琐、低效。而添加分区,都是手动执行SQL。 演示目的:利用python来为PostgreSQL的表自动添加分区。python版本:python3+ pip3 install psycopg2 一......

  5. python 爬虫

    学习python就一直想做爬虫的东西,还要继续学 理论上的东西一要加强 #!/usr/bin/python#coding=utf-8import urllibimport redef getHtml(url): page = urllib.urlopen(u......

  6. python3 如何读取python2的npy文件

    python3读取python2打包的npy文件会报错,原因是编码方式不同,所以只要在读取的时候加上编码方式即可。解决方法docs_train = np.load('./data/20news_clean/train.txt.npy', allow_pickle=True, encoding='by......

  7. Python 3的f-Strings:增强的字符串格式语法(指南)

    最近也在一个视频网站的爬虫,项目已经完成,中间有不少需要总结的经验。从Python 3.6开始,f-Strings是格式化字符串的一种很棒的新方法。与其他格式化方式相比,它们不仅更具可读性,更简洁且不易出错,而且速度更快!Python中的“老式”字符串格式化在Python 3.6之前,你有两种主要的......

  8. 执行py文件需要可执行权限吗?

    我们知道可执行权限x在Linux系统中的重要性,那么在执行py文件的过程中,是否一定需要可执行权限呢?本文将会详细的分析几种测试案例。案例解析这个问题描述起来有点违反直觉,要执行一个文件难道不应该需要可执行权限吗?让我们先来看一个例子:# module1.pydef test():print ('h......

  9. Python学习(4)( If 判断语句 、逻辑运算、elif、if嵌套、随机数、石头剪刀布程序)

    Python学习(4)一、python的 if 判断语句二、python的逻辑运算1. and2. or3. not三、python的 elif 判断语句四、python的if 嵌套五、随机数的处理六、石头剪刀布 ---演练一、python的 if 判断语句在python 中,if 语句 就是用来进......

  10. 用python做youtube自动化下载器 代码

    目录项目地址思路流程1. posti. 先把post中的headers格式化ii.然后把参数也格式化iii. 最后再执行requests库的post请求iv. 封装成一个函数2. 调用解密函数i. 分析ii. 先取出js部分iii. 取第一个解密函数作为我们用的解密函数iv. 用execjs执行1.......

随机推荐

  1. C# 使用Socket链接Ftp服务器下载上传代码FTPClient

    C#操作FTP的类,Socket实现,网上找到的,整理了一下,处理了一些BUG,喜欢的拿去用,但不保证全部BUG已捉完。using System;using System.Net;using System.IO;using System.Text;using System.Net.Sockets;n......

  2. 浅谈在Java中JSON的多种使用方式

    1. 常用的JSON转换JSONObject 转 JSON 字符串JSONObject json = new JSONObject();jsonObject.put("name", "test");String str = JSONObject.toJSONS......

  3. 透过现象看本质:Java类动态加载和热替换

    摘要:本文主要介绍类加载器、自定义类加载器及类的加载和卸载等内容,并举例介绍了Java类的热替换。最近,遇到了两个和Java类的加载和卸载相关的问题:1) 是一道关于Java的判断题:一个类被首次加载后,会长期留驻JVM,直到JVM退出。这个说法,是不是正确的?2) 在开发的一个集成平台中,需要集成......

  4. php结合GD库实现中文验证码的简单方法

    前言上一次写了一个常见的验证码,现在玩一下中文的验证码,顺便升级一下写的代码流程基本差不多先看GD库开启了没生成中文5位验证码开始画图画干扰素生成图形完事生成中文验证码?1234567891011//小小心机$hanzi= "如果觉得写得还可以的话互相关注报团取暖交流经验来自合肥的小码农巴......

  5. MYSQL大量写入问题优化详解

    摘要:大家提到Mysql的性能优化都是注重于优化sql以及索引来提升查询性能,大多数产品或者网站面临的更多的高并发数据读取问题。然而在大量写入数据场景该如何优化呢?今天这里主要给大家介绍,在有大量写入的场景,进行优化的方案。总的来说MYSQL数据库写入性能主要受限于数据库自身的配置,以及操作系统的性......

  6. 请谨慎使用 avaliable 方法来申请缓冲区

    问题今天开始尝试用 Java 写 http 服务器,开局就遇到 Bug。我先写了一个多线程的、BIO 的 http 服务器,其中接收请求的部分,会将请求的第一行打印出来。下面是浏览器发出的请求和控制台的输出情况。我们竟然收到了一个空的请求!!这是为什么呢?我解析请求的部分代码如下。// reques......

  7. Node.js 安全指南

    当项目周期快结束时,开发人员会越来越关注应用的“安全性”问题。一个安全的应用程序并不是一种奢侈,而是必要的。你应该在开发的每个阶段都考虑应用程序的安全性,例如系统架构、设计、编码,包括最后的部署。在这篇教程中,我们将一步步来学习如何提高Node.js应用程序安全性的方法。1. 数据验证 - 永远不要......

  8. Python 实现进度条的六种方式

    一、普通进度条示例代码import sysimport timedef progress_bar():for i in range(1, 101):print("\r", end="")print("Download progress: {}%: &......

  9. PHP实现爬虫爬取图片代码实例

    文字信息我们尝试获取表的信息,这里,我们就用某校的课表来代替: 接下来我们就上代码:a.php<?php header( "Content-type:text/html;Charset=utf-8" );$ch = curl_init()......

  10. python里glob模块知识点总结

    之前遇到过一类问题,要求快速做文件搜索,当时小编找了很多内容,但是没有发现实现方法,突然看到glob模块便豁然开朗了,该模块主要就是能够实现类似于windows的文件搜索,旗下的函数都可以实现搜索功能,并且有很多通配符,能够应用在多种场景中,一一对应的选择解决方案。简单介绍:匹配一定的格式文件和文件......