用KNN算法预测iris数据集

鸢尾花(iris)数据集包含150条鸢尾花的数据,每条包含四个参数和一个标签,我们需要根据四个参数来预测出当前这朵花属于哪个类别。
KNN(K Nearest Neighbors)算法是经典的懒惰学习算法,也就是说它没有训练过程,直接根据已有数据进行预测。KNN基于这样一个假设:对当前要预测的数据t,找到和t最相似的k个已知标签的训练数据,然后获取这k条训练数据的标签,把出现次数最多的标签作为t的预测标签。从这个思想中我们可以看出,对于KNN来说,如果数据集中标签数量较少则效果会比较好,如果标签过多,很可能出现与t相似的数据只有非常少的几条,这种情况下就很难得到准确的分类。
代码如下:

# -*- coding=utf8 -*- from sklearn import datasets import numpy as np from sklearn.manifold import TSNE from sklearn.decomposition import PCA import matplotlib.pyplot as pltdef get_data(): # 从sklearn的数据集中获取iris数据 iris =datasets.load_iris() # 洗乱数据 indices = np.arange(len(iris.data)) np.random.shuffle(indices)# 查看数据的形状 print(iris.data.shape) print(iris.target.shape) print(type(iris)) print(type(iris.data)) # 分割训练数据集和测试数据集 split_index = int(len(iris.data) * 0.1)data = https://www.it610.com/article/iris.data[indices] target = iris.target[indices]tr_value = data[:-split_index] tr_target = target[:-split_index]te_value = data[-split_index:] te_target = target[-split_index:]return tr_value, tr_target, te_value, te_targetdef knn(tr_value, tr_target, te_value, te_target, k): hit = 0 for i in range(len(te_value)): # 利用numpy提供的大矩阵运算简洁地计算当前数据与所有测试数据之间的欧氏距离 one2n = np.tile(te_value[i], (int(len(tr_value)), 1)) distance = (((tr_value - one2n) ** 2).sum(axis=1)) ** 0.5count = {} # 根据距离从小到大排列训练数据的下标 sorted_distance = distance.argsort() # 统计前k个训练数据中哪个标签出现次数最多 # print(sorted_distance) for j in range(k): # print(distance[sorted_distance[j]]) tmp_tag = tr_target[sorted_distance[j]] # print(tmp_tag) if tmp_tag in count.keys(): count[tmp_tag] += 1 else: count[tmp_tag] = 1# 排序后选取出现次数最多的标签作为当前测试数据的预测结果 tag_count = sorted(count.items(), key=lambda x: x[1], reverse=True) print(te_target[i], tag_count[0][0]) if te_target[i] == tag_count[0][0]: hit += 1accuracy = hit / len(te_target) print('accuracy:%f\n' % accuracy)def visualize(): iris = datasets.load_iris() X_tsne = TSNE(learning_rate=100).fit_transform(iris.data) T = np.arctan2(X_tsne[:, 1], X_tsne[:, 0]) plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=T, label='embedding') for i in range(len(X_tsne)): x = X_tsne[i][0] y = X_tsne[i][1] plt.text(x=x + 0.1, y=y + 0.1, s=iris.target[i]) plt.xticks(()) plt.yticks(())plt.show()def main(): tr_value, tr_target, te_value, te_target = get_data() knn(tr_value, tr_target, te_value, te_target, 5)if __name__ == '__main__': main()

【用KNN算法预测iris数据集】最终可以得到100%的准确率。
当然准确率这么高是因为iris数据集的数据分布本来就比较干净,并且标签数量非常有限,我们可以使用上述代码中的visualize函数来看一下数据分布,数据分布如下:
用KNN算法预测iris数据集
文章图片

正如前文提到的,iris数据集本身就具备较好的性质,因此容易预测,但在真实场景中,数据分布往往都是比较复杂的,因此在真实应用场景中使用KNN时需要更加谨慎地操作每一步,尤其是数据预处理获取数据的特征以及距离度量方法的选择等。

    推荐阅读