最近在看西瓜书中有关决策树的部分,就想用R语言建立简单的决策树模型,因为Python实在还不太会,哈。
这里为了方便,我就直接使用自带的数据集鸢尾花iris,用的R包有rpart和rpart.plot。rpart是一个专门用于做决策树模型的包,rpart.plot则用于绘制rpart模型。
为了方便理解和记忆,此处将模型的完整建立分成导入数据包/设置建模参数/数据切分/建模/调整模型参数并计算训练误差和测试误差这几个步骤。
以下是代码的具体实现部分。
首先加载需要的R包:
install.packages('rpart')
install.packages('rpart.plot')
library(rpart)
library(rpart.plot)
查看数据集,
iris
str(iris)
文章图片
主要要查看数据集的标签列的位置在哪一列,
还要注意标签列的数据类型必须为factor因子型,不然数据类型不对不好分类。
我们可以看到鸢尾花数据集的标签Species,类型为factor,因此不需要再转换数据类型。
对数据进行切分,随机分为训练集和测试集,
index <- sample(nrow(iris), 0.7*nrow(iris))
train <- iris[index, ]
test <- iris[-index, ]
设置建模控制参数,参数的设置在一定程度上可以防止模型过拟合。
其中rpart.control 参数minbucket 表示叶节点至少包含的样本数,少于这个数量就进行剪枝;参数maxdepth设置树的最大深度;xval是交叉验证次数;cp是树生长的最低增长指标,也就是每生长一步,对整体纯度提升的的最低指标,低于这个指标就进行剪枝。
tc <- rpart.control(minbucket=5,maxdepth=10,xval=5,cp=0.005)
接下来就可以用训练集建立模型啦,
fit <- rpart(Species ~ ., data=https://www.it610.com/article/train, control="tc")
然后用建立好的模型分别对训练集和测试集进行预测,并计算准确率。
其中table函数可以统计每个类别的频数,通过公式:预测正确的个数除以总数 可以很好的计算出准确率。
train.pred <- predict(fit, train[,-5], type="class")
table(train$Species == train.pred)['TRUE'] / length(train.pred)
test.pred <- predict(fit, test[,-5], type="class")
table(test$Species == test.pred)['TRUE'] / length(test.pred)
得到的结果如下,
文章图片
接下来就可以画一棵决策树了!
rpart.plot(fit, main="Decision Tree")
文章图片
从上面计算的模型的准确率来看,模型的泛化能力还是挺好的。
【用R语言实现决策树分类】如果想要对树进行剪枝的话,我们可以根据设置相应的cp值来进行剪枝,
fit$cptable
查看模型各层的cp值,一共有三层,
文章图片
如果要对最下面一层进行剪枝的话,我们要设置最低cp值略大于倒数第二层的cp值,这样的话倒数第二层就不会继续生长了。
prune(fit, 0.43077)
rpart.plot(prune(fit, 0.43077)
文章图片
我们还可以确定最佳cp值。
对控制参数重新设置,初始cp值设置为0,然后用同样的方法建立模型fit2,这里就不再重复了
fit2$cptable
文章图片
通过cptable图可以看出cp值在0.081时模型的准确率已经很好,大于这个值模型容易欠拟合,小于这个容易过拟合。
推荐阅读
- 决策树|决策树,随机森林,集成学习的算法实现
- 机器学习与数据挖掘|一文读懂常用机器学习解释性算法(特征权重,feature_importance, lime,shap)
- python|决策树——id3算法
- 机器学习|决策树算法(DecisionTree)
- 机器学习(原理+实现)|(Datawhale)基于决策树的分类预测
- 决策树|LightGBM算法详解(教你一文掌握LightGBM所有知识点)