教程目录
阅读:
CART分类树算法详解(Python代码实现)
CART算法(Classification And Regression Tree)是决策树的一种,如上所述,主要的决策树模型有 ID3 算法、C4.5 算法和 CART 算法,与 ID3 算法和 C4.5 算法不同的是,CART 算法既能处理分类问题也可以处理回归问题。
CART 算法既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Regression Tree),本节主要是利用 CART 算法创建分类树。
现在,让我们一起利用 Python 实现 CART 决策树。为了能构建 CART 分类树算法,首先,需要为 CART 分类树中节点设置一个结构,并将其保存到 CART 树的文件“tree.py”中,其具体的实现如下所示:
当定义好树的节点后,利用训练数据训练 CART 分类树模型,其具体实现如下程序所示:
在构建 CART 分类树的过程中,首先是计算当前的 Gini 指数,函数 cal_gini_index 的具体实现如下程序所示:
在寻找最好划分的过程中,首先取得所有样本在 fea 特征处的可能的取值,并将其存储到字典 feature_values 中。对特征 fea 处的每一种可能取值,利用函数 split_tree 尝试将数据集 data 划分成左右子树 set_1 和 set_2。
函数 split_tree 按照指定的特征 fea 处的值 value 将数据集划分成左右子树,具体实现如下程序所示。
划分后,计算此时的 Gini 指数,此时的 Gini 指数为左右子树的 Gini 指数之和。判断当前的 Gini 指数与划分前 Gini 指数的变化,找到能够使得 Gini 指数变化最大的特征作为最终的划分标准。
CART 算法既可以用于创建分类树(Classification Tree),也可以用于创建回归树(Regression Tree),本节主要是利用 CART 算法创建分类树。
CART分类树的构建
在 CART 分类树算法中,利用 Gini 指数作为划分数的指标,通过样本中的特征,对样本进行划分,直到所有的叶节点中的所有样本都为同一个类别为止,CART 分类树的构建过程如下所示:- 对于当前训练数据集,遍历所有属性及其所有可能的切分点,寻找最佳切分属性及其最佳切分点,使得切分之后的基尼指数最小,利用该最佳属性及其最佳切分点将训练数据集切分成两个子集,分别对应判别结果为左子树和判别结果为右子树。
- 重复以下的步骤直至满足停止条件:为每一个叶子节点寻找最佳切分属性及其最佳切分点,将其划分为左右子树。
- 生成 CART 决策树。
现在,让我们一起利用 Python 实现 CART 决策树。为了能构建 CART 分类树算法,首先,需要为 CART 分类树中节点设置一个结构,并将其保存到 CART 树的文件“tree.py”中,其具体的实现如下所示:
class node: '''树的节点的类 ''' def __init__(self, fea=-1, value=None, results=None, right=None, left=None): self.fea = fea # 用于切分数据集的属性的列索引值 self.value = value # 设置划分的值 self.results = results # 存储叶节点所属的类别 self.right = right # 右子树 self.left = left # 左子树程序中,为树中节点设置 node 类,在 node 类中,属性 fea 表示的是待切分的特征的索引值,属性 value 表示的是待切分的特征的索引处的具体的值,当 node 为叶子节点时,属性 results 表示的是该叶子节点所属的类别,属性 right 表示的是树中节点 node 的右子树,属性 left 表示的是树中节点的左子树。
当定义好树的节点后,利用训练数据训练 CART 分类树模型,其具体实现如下程序所示:
def build_tree(data): '''构建树 input: data(list):训练样本 output: node:树的根结点 ''' # 构建决策树,函数返回该决策树的根节点 if len(data) == 0: return node() # 1、计算当前的Gini指数 currentGini = cal_gini_index(data) bestGain = 0.0 bestCriteria = None # 存储最佳切分属性以及最佳切分点 bestSets = None # 存储切分后的两个数据集 feature_num = len(data[0]) - 1 # 样本中特征的个数 # 2、找到最好的划分 for fea in range(0, feature_num): # 2.1、取得fea特征处所有可能的取值 feature_values = {} # 在fea位置处可能的取值 for sample in data: # 对每一个样本 feature_values[sample[fea]] = 1 # 存储特征fea处所有可能的取值 # 2.2、针对每一个可能的取值,尝试将数据集划分,并计算Gini指数 for value in feature_values.keys(): # 遍历该属性的所有切分点 # 2.2.1、 根据fea特征中的值value将数据集划分成左右子树 (set_1, set_2) = split_tree(data, fea, value) # 2.2.2、计算当前的Gini指数 nowGini = float(len(set_1) * cal_gini_index(set_1) + len(set_2) * cal_gini_index(set_2)) / len(data) # 2.2.3、计算Gini指数的增加量 gain = currentGini - nowGini # 2.2.4、判断此划分是否比当前的划分更好 if gain > bestGain and len(set_1) > 0 and len(set_2) > 0: bestGain = gain bestCriteria = (fea, value) bestSets = (set_1, set_2) # 3、判断划分是否结束 if bestGain > 0: right = build_tree(bestSets[0]) left = build_tree(bestSets[1]) return node(fea=bestCriteria[0], value=bestCriteria[1], right=right, left=left) else: return node(results=label_uniq_cnt(data)) # 返回当前的类别标签作为最终的类别标签程序中,函数 build_tree 用于构建 CART 分类树,构建分类树的过程主要分为如下 3 步:
- 计算当前的Gini指数;
- 尝试按照数据集中的每一个特征将树划分成左右子树,计算出最好的划分,通过迭代的方式继续对左右子树进行划分;
- 判断当前是否还可以继续划分,若不能继续划分则退出。
在构建 CART 分类树的过程中,首先是计算当前的 Gini 指数,函数 cal_gini_index 的具体实现如下程序所示:
def cal_gini_index(data): '''计算给定数据集的Gini指数 input: data(list):树中 output: gini(float):Gini指数 ''' total_sample = len(data) # 样本的总个数 if len(data) == 0: return 0 label_counts = label_uniq_cnt(data) # 统计数据集中不同标签的个数 # 计算数据集的Gini指数 gini = 0 for label in label_counts: gini = gini + pow(label_counts[label], 2) gini = 1 - float(gini) / pow(total_sample, 2) return gini在划分的过程中,需要按照 Gini 指数找到最好的划分。寻找最好的划分的方法是遍历所有的样本的特征,取得能够使得划分前后 Gini 指数的变化最大的特征,按照该特征的值将树划分成左右子树。
在寻找最好划分的过程中,首先取得所有样本在 fea 特征处的可能的取值,并将其存储到字典 feature_values 中。对特征 fea 处的每一种可能取值,利用函数 split_tree 尝试将数据集 data 划分成左右子树 set_1 和 set_2。
函数 split_tree 按照指定的特征 fea 处的值 value 将数据集划分成左右子树,具体实现如下程序所示。
def split_tree(data, fea, value): '''根据特征fea中的值value将数据集data划分成左右子树 input: data(list):数据集 fea(int):待分割特征的索引 value(float):待分割的特征的具体值 output: (set1,set2)(tuple):分割后的左右子树 ''' set_1 = [] set_2 = [] for x in data: if x[fea] >= value: set_1.append(x) else: set_2.append(x) return (set_1, set_2)该函数主要用于特征的值是连续的值时的划分,当特征 fea 处的值是一些连续值的时候,当该处的值大于或等于待划分的值 value 时,将该样本划分到 set_1 中;否则,划分到 set_2 中。
划分后,计算此时的 Gini 指数,此时的 Gini 指数为左右子树的 Gini 指数之和。判断当前的 Gini 指数与划分前 Gini 指数的变化,找到能够使得 Gini 指数变化最大的特征作为最终的划分标准。
利用构建好的分类树进行预测
当整个 CART 分类树构建完成后,利用训练样本对分类树进行训练,最终得到分类树的模型,对于未知的样本,需要用训练好的分类树的模型对其进行预测,对样本进行预测的过程如下程序所示:def predict(sample, tree): '''对每一个样本sample进行预测 input: sample(list):需要预测的样本 tree(类):构建好的分类树 output: tree.results:所属的类别 ''' # 1、只是树根 if tree.results != None: return tree.results else: # 2、有左右子树 val_sample = sample[tree.fea] branch = None if val_sample >= tree.value: branch = tree.right else: branch = tree.left return predict(sample, branch)该程序中,函数 predict 利用训练好的 CART 分类树模型 tree 对样本 sample 进行预测,当只有树根时,直接返回树根的类标签。若此时有左右子树,则根据指定的特征 fea 处的值进行比较,选择左右子树,直到找到最终的标签。