机器学习(原理+实现)|(Datawhale)基于决策树的分类预测


文章目录

    • 1 学习目标
    • 2 决策树简介
    • 3 决策树原理
      • 3.1 构造
      • 3.2 剪枝
      • 3.3 信息熵
        • 3.3.1 信息增益(ID3 算法)
        • 3.3.2 信息增益率(C4.5 算法)
        • 3.3.3 基尼指数(Cart 算法)
        • 3.3.4 三种算法简单比较
    • 4 基于企鹅数据分类预测

1 学习目标 了解 决策树 的理论知识。
掌握 决策树 的 sklearn 函数调用使用并将其运用到企鹅数据集预测。
2 决策树简介 【机器学习(原理+实现)|(Datawhale)基于决策树的分类预测】决策树是一种常见的分类模型,在金融分控、医疗辅助诊断等诸多行业具有较为广泛的应用。决策树的核心思想是基于树结构对数据进行划分,这种思想是人类处理问题时的本能方法。例如在婚恋市场中,女方通常会先看男方是否有房产,如果有房产再看是否有车产,如果有车产再看是否有稳定工作……最后得出是否要深入了解的判断。
决策树的主要优点:
具有很好的解释性,模型可以生成可以理解的规则。
可以发现特征的重要程度。
模型的计算复杂度较低。
决策树的主要缺点:
模型容易过拟合,需要采用减枝技术处理。
不能很好利用连续型特征。
预测能力有限,无法达到其他强监督模型效果。
方差较高,数据分布的轻微改变很容易造成树结构完全不同。
3 决策树原理 3.1 构造
构造就是生成一棵完整的决策树。简单来说,构造的过程就是选择什么属性作为节点的过程,那么在构造过程中,会存在三种节点:
根节点:就是树的最顶端,最开始的那个节点。
内部节点:就是树中间的那些节点。
叶节点:就是树最底部的节点,也就是决策结果。
节点之间存在父子关系。比如根节点会有子节点,子节点会有子子节点,但是到了叶节点就停止了,叶节点不存在子节点。那么在构造过程中,你要解决三个重要的问题:
选择哪个属性作为根节点;
选择哪些属性作为子节点;
什么时候停止并得到目标状态,即叶节点。
3.2 剪枝
过拟合” 指的就是模型的训练结果“太好了”,以至于在实际应用的过程中,会存在“死板”的情况,导致分类错误。
造成过拟合的原因之一就是因为训练集中样本量较小。如果决策树选择的属性过多,构造出来的决策树一定能够“完美”地把训练集中的样本分类,但是这样就会把训练集中一些数据的特点当成所有数据的特点,但这个特点不一定是全部数据的特点,这就使得这个决策树在真实的数据分类中出现错误,也就是模型的“泛化能力”差。
预剪枝是在决策树构造时就进行剪枝。方法是在构造的过程中对节点进行评估,如果对某个节点进行划分,在验证集中不能带来准确性的提升,那么对这个节点进行划分就没有意义,这时就会把当前节点作为叶节点,不对其进行划分。
后剪枝就是在生成决策树之后再进行剪枝,通常会从决策树的叶节点开始,逐层向上对每个节点进行评估。如果剪掉这个节点子树,与保留该节点子树在分类准确性上差别不大,或者剪掉该节点子树,能在验证集中带来准确性的提升,那么就可以把该节点子树进行剪枝。方法是:用这个节点子树的叶子节点来替代该节点,类标记为这个节点子树中最频繁的那个类。
3.3 信息熵
纯度,就是把决策树的构造过程理解成为寻找纯净划分的过程。数学上,我们可以用纯度来表示,纯度换一种方式来解释就是让目标变量的分歧最小。
信息熵(entropy):它表示了信息的不确定度。
机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

p(i|t) 代表了节点 t 为分类 i 的概率,其中 log2 为取以 2 为底的对数。是说存在一种度量,它能帮我们反映出来这个信息的不确定度。当不确定性越大时,它所包含的信息量也就越大,信息熵也就越高。
信息熵越大,纯度越低。当集合中的所有样本均匀混合时,信息熵最大,纯度最低。
经典的 “不纯度”的指标有三种,分别是信息增益(ID3 算法)、信息增益率(C4.5 算法)以及基尼指数(Cart 算法)。
3.3.1 信息增益(ID3 算法) ID3 算法计算的是信息增益,信息增益指的就是划分可以带来纯度的提高,信息熵的下降。它的计算公式,是父亲节点的信息熵减去所有子节点的信息熵。在计算的过程中,我们会计算每个子节点的归一化信息熵,即按照每个子节点在父节点中出现的概率,来计算这些子节点的信息熵。所以信息增益的公式可以表示为:
机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

公式中 D 是父亲节点,Di 是子节点,Gain(D,a) 中的 a 作为 D 节点的属性选择。
3.3.2 信息增益率(C4.5 算法) C4.5 采用信息增益率的方式来选择属性。信息增益率 = 信息增益 / 属性熵。
当属性有很多值的时候,相当于被划分成了许多份,虽然信息增益变大了,但是对于 C4.5 来说,属性熵也会变大,所以整体的信息增益率并不大。
在 C4.5 中,会在决策树构造之后采用悲观剪枝(PEP),这样可以提升决策树的泛化能力。悲观剪枝是后剪枝技术中的一种,通过递归估算每个内部节点的分类错误率,比较剪枝前后这个节点的分类错误率来决定是否对其进行剪枝。这种剪枝方法不再需要一个单独的测试数据集。
C4.5 可以处理连续属性的情况,对连续的属性进行离散化的处理。
ID3 算法的优点是方法简单,缺点是对噪声敏感。训练数据如果有少量错误,可能会产生决策树分类错误。
C4.5 在 ID3 的基础上,用信息增益率代替了信息增益,解决了噪声敏感的问题,并且可以对构造树进行剪枝、处理连续数值以及数值缺失等情况,但是由于 C4.5 需要对数据集进行多次扫描,算法效率相对较低。
3.3.3 基尼指数(Cart 算法) 假设 t 为节点,那么该节点的 GINI 系数的计算公式为:
机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

这里 p(Ck|t) 表示节点 t 属于类别 Ck 的概率,节点 t 的基尼系数为 1 减去各类别 Ck 概率平方和。
CART 采用基尼系数作为节点划分的依据,得到的是离散的结果,也就是分类结果。
3.3.4 三种算法简单比较 ID3:以信息增益作为判断标准,计算每个特征的信息增益,选取信息增益最大的特征,但是容易选取到取值较多的特征。
C4.5:以信息增益比作为判断标准,计算每个特征的信息增益比,选取信息增益比最大的特征。
CART:分类树以基尼系数为标准,选取基尼系数小的的特征。
回归树以均方误差或绝对值误差为标准,选取均方误差或绝对值误差最小的特征。
4 基于企鹅数据分类预测 本次我们选择企鹅数据(palmerpenguins)进行方法的尝试训练,该数据集一共包含8个变量,其中7个特征变量,1个目标分类变量。共有150个样本,目标变量为 企鹅的类别 其都属于企鹅类的三个亚属,分别是(Adélie, Chinstrap and Gentoo)。包含的三种种企鹅的七个特征,分别是所在岛屿,嘴巴长度,嘴巴深度,脚蹼长度,身体体积,性别以及年龄。
##基础函数库 import numpy as np import pandas as pd## 绘图函数库 import matplotlib.pyplot as plt import seaborn as sns

data = https://www.it610.com/article/pd.read_csv('penguins_raw.csv') data = https://www.it610.com/article/data[['Species','Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']]

data.info()

RangeIndex: 344 entries, 0 to 343 Data columns (total 5 columns): Species344 non-null object Culmen Length (mm)342 non-null float64 Culmen Depth (mm)342 non-null float64 Flipper Length (mm)342 non-null float64 Body Mass (g)342 non-null float64 dtypes: float64(4), object(1) memory usage: 13.5+ KB

data.head()

Species Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
0 Adelie Penguin (Pygoscelis adeliae) 39.1 18.7 181.0 3750.0
1 Adelie Penguin (Pygoscelis adeliae) 39.5 17.4 186.0 3800.0
2 Adelie Penguin (Pygoscelis adeliae) 40.3 18.0 195.0 3250.0
3 Adelie Penguin (Pygoscelis adeliae) NaN NaN NaN NaN
4 Adelie Penguin (Pygoscelis adeliae) 36.7 19.3 193.0 3450.0
'''发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值, 可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补, 还有其他例如“中位数填补、平均数填补”的缺失值处理方法''' data = https://www.it610.com/article/data.fillna(-1) data.tail()

Species Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
339 Chinstrap penguin (Pygoscelis antarctica) 55.8 19.8 207.0 4000.0
340 Chinstrap penguin (Pygoscelis antarctica) 43.5 18.1 202.0 3400.0
341 Chinstrap penguin (Pygoscelis antarctica) 49.6 18.2 193.0 3775.0
342 Chinstrap penguin (Pygoscelis antarctica) 50.8 19.0 210.0 4100.0
343 Chinstrap penguin (Pygoscelis antarctica) 50.2 18.7 198.0 3775.0
data['Species'].unique()

array(['Adelie Penguin (Pygoscelis adeliae)', 'Gentoo penguin (Pygoscelis papua)', 'Chinstrap penguin (Pygoscelis antarctica)'], dtype=object)

#利用value_counts函数查看每个类别数量 pd.Series(data['Species']).value_counts()

Adelie Penguin (Pygoscelis adeliae)152 Gentoo penguin (Pygoscelis papua)124 Chinstrap penguin (Pygoscelis antarctica)68 Name: Species, dtype: int64

data.describe()

Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
count 344.000000 344.000000 344.000000 344.000000
mean 43.660756 17.045640 199.741279 4177.319767
std 6.428957 2.405614 20.806759 861.263227
min -1.000000 -1.000000 -1.000000 -1.000000
25% 39.200000 15.500000 190.000000 3550.000000
50% 44.250000 17.300000 197.000000 4025.000000
75% 48.500000 18.700000 213.000000 4750.000000
max 59.600000 21.500000 231.000000 6300.000000
# 特征与标签组合的散点可视化 sns.pairplot(data=https://www.it610.com/article/data, diag_kind='hist', hue= 'Species') plt.show()

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

#将标签转化为数字 def trans(x): if x == data['Species'].unique()[0]: return 0 if x == data['Species'].unique()[1]: return 1 if x == data['Species'].unique()[2]: return 2data['Species'] = data['Species'].apply(trans) for col in data.columns: if col != 'Species': sns.boxplot(x='Species', y=col, saturation=0.5, palette='pastel', data=https://www.it610.com/article/data) plt.title(col) plt.show()

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

# 选取其前三个特征绘制三维散点图 from mpl_toolkits.mplot3d import Axes3Dfig = plt.figure(figsize=(10,8)) ax = fig.add_subplot(111, projection='3d')data_class0 = data[data['Species']==0].values data_class1 = data[data['Species']==1].values data_class2 = data[data['Species']==2].values # 'setosa'(0), 'versicolor'(1), 'virginica'(2) ax.scatter(data_class0[:,0], data_class0[:,1], data_class0[:,2],label=data['Species'].unique()[0]) ax.scatter(data_class1[:,0], data_class1[:,1], data_class1[:,2],label=data['Species'].unique()[1]) ax.scatter(data_class2[:,0], data_class2[:,1], data_class2[:,2],label=data['Species'].unique()[2]) plt.legend()plt.show()

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

# 为了正确评估模型性能,将数据划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能。 from sklearn.model_selection import train_test_split

# 选择其类别为0和1的样本 (不包括类别为2的样本) data_target_part = data[data['Species'].isin([0,1])][['Species']] data_features_part = data[data['Species'].isin([0,1])][['Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']]# 测试集大小为20%, 80%/20%分 x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part, test_size = 0.2, random_state = 2020)

# 从sklearn中导入决策树模型 from sklearn.tree import DecisionTreeClassifier from sklearn import tree # 定义 决策树模型 clf = DecisionTreeClassifier(criterion='entropy') ##在训练集上训练决策树模型 clf.fit(x_train, y_train)

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=None, splitter='best')

# 在训练集和测试集上分布利用训练好的模型进行预测 train_predict = clf.predict(x_train) test_predict = clf.predict(x_test) from sklearn import metrics # 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果 print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict)) print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict)) # 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵) confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test) print('The confusion matrix result:\n',confusion_matrix_result) #利用热力图对于结果进行可视化 plt.figure(figsize=(8, 6)) sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues') plt.xlabel('Predicted labels') plt.ylabel('True labels') plt.show()

The accuracy of the Logistic Regression is: 0.9954545454545455 The accuracy of the Logistic Regression is: 1.0 The confusion matrix result: [[310] [ 0 25]]

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

#我们可以发现其准确度为1,代表所有的样本都预测正确了

# 测试集大小为20%, 80%/20%分 x_train, x_test, y_train, y_test = train_test_split(data[['Culmen Length (mm)','Culmen Depth (mm)', 'Flipper Length (mm)','Body Mass (g)']], data[['Species']], test_size = 0.2, random_state = 2020) #定义 决策树模型 clf = DecisionTreeClassifier() # 在训练集上训练决策树模型 clf.fit(x_train, y_train)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=None, splitter='best')

# 在训练集和测试集上分布利用训练好的模型进行预测 train_predict = clf.predict(x_train) test_predict = clf.predict(x_test) #由于逻辑回归模型是概率预测模型(前文介绍的 p = p(y=1|x,\theta)),所有我们可以利用 predict_proba 函数预测其概率 train_predict_proba = clf.predict_proba(x_train) test_predict_proba = clf.predict_proba(x_test) print('The test predict Probability of each class:\n',test_predict_proba) #其中第一列代表预测为0类的概率,第二列代表预测为1类的概率,第三列代表预测为2类的概率。 #利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果 print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict)) print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

The test predict Probability of each class: [[0. 0. 1.] [0. 1. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 0. 1.] [0. 0. 1.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [0. 1. 0.] [0. 1. 0.] [1. 0. 0.] [0. 1. 0.] [0. 1. 0.] [0. 1. 0.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 0. 1.] [1. 0. 0.] [0. 0. 1.] [1. 0. 0.] [1. 0. 0.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 0. 1.] [0. 0. 1.] [0. 1. 0.] [1. 0. 0.] [0. 1. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 1. 0.] [0. 0. 1.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 0. 1.] [0. 0. 1.] [1. 0. 0.] [1. 0. 0.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [0. 1. 0.] [0. 1. 0.] [0. 0. 1.] [0. 0. 1.] [0. 1. 0.] [1. 0. 0.] [1. 0. 0.] [1. 0. 0.] [0. 1. 0.] [0. 1. 0.] [0. 0. 1.] [0. 0. 1.] [1. 0. 0.] [0. 1. 0.] [0. 0. 1.] [1. 0. 0.] [1. 0. 0.]] The accuracy of the Logistic Regression is: 0.9963636363636363 The accuracy of the Logistic Regression is: 0.9565217391304348

# 查看混淆矩阵 confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test) print('The confusion matrix result:\n',confusion_matrix_result)# 利用热力图对于结果进行可视化 plt.figure(figsize=(8, 6)) sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues') plt.xlabel('Predicted labels') plt.ylabel('True labels') plt.show()

The confusion matrix result: [[3010] [ 0 230] [ 20 13]]

机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
文章图片

    推荐阅读