人工智能中文网
  • 主页
  • 线代考研视频
  • 线性代数
  • Python机器学习与算法
  • 大数据与机器学习
  • Python基础入门教程
  • 人工智能中文网
    教程目录
    阅读:

    CART分类树算法详解(Python代码实现)

    < 上一篇:决策树算法 下一篇:随机森林算法 >
    CART算法(Classification And Regression Tree)是决策树的一种,如上所述,主要的决策树模型有 ID3 算法、C4.5 算法和 CART 算法,与 ID3 算法和 C4.5 算法不同的是,CART 算法既能处理分类问题也可以处理回归问题。

    CART 算法既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Regression Tree),本节主要是利用 CART 算法创建分类树。

    CART分类树的构建

    在 CART 分类树算法中,利用 Gini 指数作为划分数的指标,通过样本中的特征,对样本进行划分,直到所有的叶节点中的所有样本都为同一个类别为止,CART 分类树的构建过程如下所示:
    • 对于当前训练数据集,遍历所有属性及其所有可能的切分点,寻找最佳切分属性及其最佳切分点,使得切分之后的基尼指数最小,利用该最佳属性及其最佳切分点将训练数据集切分成两个子集,分别对应判别结果为左子树和判别结果为右子树。
    • 重复以下的步骤直至满足停止条件:为每一个叶子节点寻找最佳切分属性及其最佳切分点,将其划分为左右子树。
    • 生成 CART 决策树。

    现在,让我们一起利用 Python 实现 CART 决策树。为了能构建 CART 分类树算法,首先,需要为 CART 分类树中节点设置一个结构,并将其保存到 CART 树的文件“tree.py”中,其具体的实现如下所示:
    class node:
        '''树的节点的类
        '''
        def __init__(self, fea=-1, value=None, results=None, right=None, left=None):
            self.fea = fea  # 用于切分数据集的属性的列索引值
            self.value = value  # 设置划分的值
            self.results = results  # 存储叶节点所属的类别
            self.right = right  # 右子树
            self.left = left  # 左子树
    程序中,为树中节点设置 node 类,在 node 类中,属性 fea 表示的是待切分的特征的索引值,属性 value 表示的是待切分的特征的索引处的具体的值,当 node 为叶子节点时,属性 results 表示的是该叶子节点所属的类别,属性 right 表示的是树中节点 node 的右子树,属性 left 表示的是树中节点的左子树。

    当定义好树的节点后,利用训练数据训练 CART 分类树模型,其具体实现如下程序所示:
    def build_tree(data):
        '''构建树
        input:  data(list):训练样本
        output: node:树的根结点
        '''
        # 构建决策树,函数返回该决策树的根节点
        if len(data) == 0:
            return node()
        # 1、计算当前的Gini指数
        currentGini = cal_gini_index(data)
        bestGain = 0.0
        bestCriteria = None  # 存储最佳切分属性以及最佳切分点
        bestSets = None  # 存储切分后的两个数据集
        feature_num = len(data[0]) - 1  # 样本中特征的个数
        # 2、找到最好的划分
        for fea in range(0, feature_num):
            # 2.1、取得fea特征处所有可能的取值
            feature_values = {}  # 在fea位置处可能的取值
            for sample in data:  # 对每一个样本
                feature_values[sample[fea]] = 1  # 存储特征fea处所有可能的取值
            # 2.2、针对每一个可能的取值,尝试将数据集划分,并计算Gini指数
            for value in feature_values.keys():  # 遍历该属性的所有切分点
                # 2.2.1、 根据fea特征中的值value将数据集划分成左右子树
                (set_1, set_2) = split_tree(data, fea, value)
                # 2.2.2、计算当前的Gini指数
                nowGini = float(len(set_1) * cal_gini_index(set_1) + 
                                 len(set_2) * cal_gini_index(set_2)) / len(data)
                # 2.2.3、计算Gini指数的增加量
                gain = currentGini - nowGini
                # 2.2.4、判断此划分是否比当前的划分更好
                if gain > bestGain and len(set_1) > 0 and len(set_2) > 0:
                    bestGain = gain
                    bestCriteria = (fea, value)
                    bestSets = (set_1, set_2)
        # 3、判断划分是否结束
        if bestGain > 0:
            right = build_tree(bestSets[0])
            left = build_tree(bestSets[1])
            return node(fea=bestCriteria[0], value=bestCriteria[1], 
                        right=right, left=left)
        else:
            return node(results=label_uniq_cnt(data))  # 返回当前的类别标签作为最终的类别标签
    程序中,函数 build_tree 用于构建 CART 分类树,构建分类树的过程主要分为如下 3 步:
    1. 计算当前的Gini指数;
    2. 尝试按照数据集中的每一个特征将树划分成左右子树,计算出最好的划分,通过迭代的方式继续对左右子树进行划分;
    3. 判断当前是否还可以继续划分,若不能继续划分则退出。

    在构建 CART 分类树的过程中,首先是计算当前的 Gini 指数,函数 cal_gini_index 的具体实现如下程序所示:
    def cal_gini_index(data):
        '''计算给定数据集的Gini指数
        input:  data(list):树中
        output: gini(float):Gini指数
        '''
        total_sample = len(data)  # 样本的总个数
        if len(data) == 0:
            return 0  
        label_counts = label_uniq_cnt(data)  # 统计数据集中不同标签的个数
        # 计算数据集的Gini指数
        gini = 0
        for label in label_counts:
            gini = gini + pow(label_counts[label], 2) 
        gini = 1 - float(gini) / pow(total_sample, 2)
        return gini
    在划分的过程中,需要按照 Gini 指数找到最好的划分。寻找最好的划分的方法是遍历所有的样本的特征,取得能够使得划分前后 Gini 指数的变化最大的特征,按照该特征的值将树划分成左右子树。

    在寻找最好划分的过程中,首先取得所有样本在 fea 特征处的可能的取值,并将其存储到字典 feature_values 中。对特征 fea 处的每一种可能取值,利用函数 split_tree 尝试将数据集 data 划分成左右子树 set_1 和 set_2。

    函数 split_tree 按照指定的特征 fea 处的值 value 将数据集划分成左右子树,具体实现如下程序所示。
    def split_tree(data, fea, value):
        '''根据特征fea中的值value将数据集data划分成左右子树
        input:  data(list):数据集
                fea(int):待分割特征的索引
                value(float):待分割的特征的具体值
        output: (set1,set2)(tuple):分割后的左右子树
        '''
        set_1 = []
        set_2 = []
        for x in data:
            if x[fea] >= value:
                set_1.append(x)
            else:
                set_2.append(x)
        return (set_1, set_2)
    该函数主要用于特征的值是连续的值时的划分,当特征 fea 处的值是一些连续值的时候,当该处的值大于或等于待划分的值 value 时,将该样本划分到 set_1 中;否则,划分到 set_2 中。

    划分后,计算此时的 Gini 指数,此时的 Gini 指数为左右子树的 Gini 指数之和。判断当前的 Gini 指数与划分前 Gini 指数的变化,找到能够使得 Gini 指数变化最大的特征作为最终的划分标准。

    利用构建好的分类树进行预测

    当整个 CART 分类树构建完成后,利用训练样本对分类树进行训练,最终得到分类树的模型,对于未知的样本,需要用训练好的分类树的模型对其进行预测,对样本进行预测的过程如下程序所示:
    def predict(sample, tree):
        '''对每一个样本sample进行预测
        input:  sample(list):需要预测的样本
                tree(类):构建好的分类树
        output: tree.results:所属的类别
        '''
        # 1、只是树根
        if tree.results != None:
            return tree.results
        else:
        # 2、有左右子树
            val_sample = sample[tree.fea]
            branch = None
            if val_sample >= tree.value:
                branch = tree.right
            else:
                branch = tree.left
            return predict(sample, branch)
    该程序中,函数 predict 利用训练好的 CART 分类树模型 tree 对样本 sample 进行预测,当只有树根时,直接返回树根的类标签。若此时有左右子树,则根据指定的特征 fea 处的值进行比较,选择左右子树,直到找到最终的标签。
    < 上一篇:决策树算法 下一篇:随机森林算法 >