用R语言实现决策树分类

最近在看西瓜书中有关决策树的部分,就想用R语言建立简单的决策树模型,因为Python实在还不太会,哈。
这里为了方便,我就直接使用自带的数据集鸢尾花iris,用的R包有rpart和rpart.plot。rpart是一个专门用于做决策树模型的包,rpart.plot则用于绘制rpart模型。
为了方便理解和记忆,此处将模型的完整建立分成导入数据包/设置建模参数/数据切分/建模/调整模型参数并计算训练误差和测试误差这几个步骤。
以下是代码的具体实现部分。
首先加载需要的R包:

install.packages('rpart') install.packages('rpart.plot') library(rpart) library(rpart.plot)

查看数据集,
iris str(iris)

用R语言实现决策树分类
文章图片

主要要查看数据集的标签列的位置在哪一列,
还要注意标签列的数据类型必须为factor因子型,不然数据类型不对不好分类。
我们可以看到鸢尾花数据集的标签Species,类型为factor,因此不需要再转换数据类型。
对数据进行切分,随机分为训练集和测试集,
index <- sample(nrow(iris), 0.7*nrow(iris)) train <- iris[index, ] test <- iris[-index, ]

设置建模控制参数,参数的设置在一定程度上可以防止模型过拟合。
其中rpart.control 参数minbucket 表示叶节点至少包含的样本数,少于这个数量就进行剪枝;参数maxdepth设置树的最大深度;xval是交叉验证次数;cp是树生长的最低增长指标,也就是每生长一步,对整体纯度提升的的最低指标,低于这个指标就进行剪枝。
tc <- rpart.control(minbucket=5,maxdepth=10,xval=5,cp=0.005)

接下来就可以用训练集建立模型啦,
fit <- rpart(Species ~ ., data=https://www.it610.com/article/train, control="tc")

然后用建立好的模型分别对训练集和测试集进行预测,并计算准确率。
其中table函数可以统计每个类别的频数,通过公式:预测正确的个数除以总数 可以很好的计算出准确率。
train.pred <- predict(fit, train[,-5], type="class") table(train$Species == train.pred)['TRUE'] / length(train.pred) test.pred <- predict(fit, test[,-5], type="class") table(test$Species == test.pred)['TRUE'] / length(test.pred)

得到的结果如下,
用R语言实现决策树分类
文章图片
接下来就可以画一棵决策树了!
rpart.plot(fit, main="Decision Tree")

用R语言实现决策树分类
文章图片

从上面计算的模型的准确率来看,模型的泛化能力还是挺好的。
【用R语言实现决策树分类】如果想要对树进行剪枝的话,我们可以根据设置相应的cp值来进行剪枝,
fit$cptable

查看模型各层的cp值,一共有三层,
用R语言实现决策树分类
文章图片

如果要对最下面一层进行剪枝的话,我们要设置最低cp值略大于倒数第二层的cp值,这样的话倒数第二层就不会继续生长了。
prune(fit, 0.43077) rpart.plot(prune(fit, 0.43077)

用R语言实现决策树分类
文章图片

我们还可以确定最佳cp值。
对控制参数重新设置,初始cp值设置为0,然后用同样的方法建立模型fit2,这里就不再重复了
fit2$cptable

用R语言实现决策树分类
文章图片

通过cptable图可以看出cp值在0.081时模型的准确率已经很好,大于这个值模型容易欠拟合,小于这个容易过拟合。

    推荐阅读