2|2 Kd树的构造与搜索
1 KD-Tree
实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。
为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。
kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。2 如何构造KD-Tree 2.1 KD-Tree算法如下: K维空间数据集
其中
- 构造根节点
选择为坐标轴,将T中所有实例以坐标为中位数,垂直轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。
- 重复:对深度为j的节点,选择为切分的坐标轴, ,以该节点再次将矩形区域切分为两个子区域。
- 直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。
文章图片
13个随机点分布 首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
文章图片
根结点 并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。
文章图片
在下一步中 ,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。
文章图片
空间的切分如下
文章图片
下一步中,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有
文章图片
文章图片
最后只剩下了叶子结点,就此完成了 kd 树的构造。
文章图片
文章图片
2.3 构造代码
class Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = https://www.it610.com/article/data# 此结点
self.depth = depth# 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = Nonedef create(self, dataSet, depth=0):"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return None
3 搜索KD-Tree 输入:已构造的kd树,目标点x
输出:x的k个最近邻集合L
3.1 KD-Tree的最近邻搜索算法
- 从根结点出发,递归向下访问KD-Tree,如果目标点x当前维小于切分点坐标,移动到左子节点,否则右子节点,直到子节点为叶子结点为止。
- 以此叶子结点为最近邻的点,插入到集合L中
- 递归向上回退,在这个节点进行以下操作:
- a 如果该节点比L里的点更近,则替换集合L中距离最大的点。
- b 目标点到此节点的分割线垂直的距离为d,判断集合L中距离最大的点与 d 相比较,如果比d大,说明d的另一侧区域中有可能有比集合L中距离要小,因此需要查看d的左右两个子节点的距离。
如果集合L中距离最大的点比 d小,那说明另一侧区域的点距离目标点的距离都比d大,因此不用查找了,继续向上回退。
- 当回退到根结点时,搜索结束,最后的集合L里的k个点,就是x的最近邻点。
KD-Tree试用于训练样本数远大于空间维度的k近邻搜索。当空间维数接近训练样本数时,他的效率会迅速下降,几乎接近线性扫描。
3.3 实例说明 设我们想查询的点为 p=(?1,?5),设距离函数是普通的距离,我们想找距离目标点最近的 k=3 个点。如下:
文章图片
首先我们按照构造好的KD-Tree,从根结点开始查找
文章图片
和这个节点的 x 轴比较一下,p 的 x 轴更小。因此我们向左枝进行搜索:
文章图片
文章图片
接下来需要对比 y 轴
文章图片
p 的 y 值更小,因此向左枝进行搜索:
文章图片
这个节点只有一个子枝,就不需要对比了。由此找到了叶子节点 (?4.6,?10.55)。
文章图片
在二维图上是蓝色的点
文章图片
此时我们要执行第二步,将当前结点插入到集合L中,并记录下 L=[(?4.6,?10.55)]。访问过的节点就在二叉树上显示为被划掉的好了。
然后执行第三步,不是最顶端节点。我回退。上面的结点是 (?6.88,?5.4)。
文章图片
文章图片
执行 3a,因为我们记录下的点只有一个,小于 k=3,所以也将当前节点记录下,插入到集合L中,有 L=[(?4.6,?10.55),(?6.88,?5.4)].。 因为当前节点的左枝是空的,所以直接跳过,继续回退,判断不是顶部根节点
文章图片
文章图片
由于还是不够三个点,于是将当前点也插入到集合L中,有 L=[(?4.6,?10.55),(?6.88,?5.4),(1.24,?2.86)]。
此时发现,当前节点有其他的分枝,执行3b,计算得出 p 点和 L 中的三个点的距离分别是 6.62, 5.89, 3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:
文章图片
到垂线距离小于L中最大的距离,说明垂线的另一侧可能有更近的点 因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行步骤1。好,我们在红线这里:
文章图片
此时处于x轴切分,因此要用 p 和这个节点比较 x 坐标:
文章图片
p 的 x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,执行步骤2与3a。
文章图片
经计算,(1.75,12.26) 与 p 的距离是 17.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。
文章图片
然后 回退,判断出不是顶端节点,往上爬。
文章图片
执行3a,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。
文章图片
因此,我们用这个新的节点替代 L 中离 p 最远的 (?4.6,?10.55)。
文章图片
然后3b,我们比对 p 和当前节点的分割线的距离
文章图片
image 这个距离小于 L 与 p 的最大距离,因此我们要到当前节点的另一个枝执行步骤1。当然,那个枝只有一个点。
文章图片
计算距离发现这个点离 p 比 L 更远,因此不进行替代。
文章图片
然后回退,不是根结点,我们向上爬
文章图片
image 这个是已经访问过的了,所以再向上爬
文章图片
再爬
文章图片
【2|2 Kd树的构造与搜索】此时到顶点了。所以完了吗?当然不,还要执行3b呢。现在是步骤1的回合。
我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。
文章图片
然后计算 p 和分割线的距离发现也是更远。
文章图片
因此也不需要检查另一个分枝。
判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(?6.88,?5.4),(1.24,?2.86),(?2.96,?2.5)]。
3.3 代码
def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)recurve(self.KdTree)# 调用根节点,开始查找
knn = self.nearest[:, 1]# knn为k个近邻结点
belong = []# 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类return self.nearest, b
4 整体代码
import numpy as np
from math import sqrt
import pandas as pd
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_splitclass Node:
def __init__(self, data, depth=0, lchild=None, rchild=None):
self.data = https://www.it610.com/article/data# 此结点
self.depth = depth# 树的深度
self.lchild = lchild # 左子结点
self.rchild = rchild # 右子节点class KdTree:
def __init__(self):
self.KdTree = None
self.n = 0
self.nearest = Nonedef create(self, dataSet, depth=0):"""KD-Tree创建过程"""
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n - 1
# 按照哪个维度进行分割,比如0:x轴,1:y轴
axis = depth % self.n
# 中位数
mid = int(m / 2)
# 按照第几个维度(列)进行排序
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
# KD结点为中位数的结点,树深度为depth
node = Node(dataSetcopy[mid], depth)
if depth == 0:
self.KdTree = node
# 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度
node.lchild = self.create(dataSetcopy[:mid], depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], depth+1)
return node
return Nonedef preOrder(self, node):
"""遍历KD-Tree"""
if node is not None:
print(node.depth, node.data)
self.preOrder(node.lchild)
self.preOrder(node.rchild)def search(self, x, count=1):
"""KD-Tree的搜索"""
nearest = [] # 记录近邻点的集合
for i in range(count):
nearest.append([-1, None])
self.nearest = np.array(nearest)def recurve(node):
"""内方法,负责查找count个近邻点"""
if node is not None:
# 步骤1:怎么找叶子节点
# 在哪个维度的分割线,0,1,0,1表示x,y,x,y
axis = node.depth % self.n
# 判断往左走or右走,递归,找到叶子结点
daxis = x[axis] - node.data[axis]
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)# 步骤2:满足的就插入到近邻点集合中
# 求test点与此点的距离
dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data)))
# 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的
for i, d in enumerate(self.nearest):
if d[0] < 0 or dist < d[0]:
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
self.nearest = self.nearest[:-1]
break# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧
n = list(self.nearest[:, 0]).count(-1)
# -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离)
# 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点
if self.nearest[-n-1, 0] > abs(daxis):
# 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)recurve(self.KdTree)# 调用根节点,开始查找
knn = self.nearest[:, 1]# knn为k个近邻结点
belong = []# 记录k个近邻结点的分类
for i in knn:
belong.append(i.data[-1])
b = max(set(belong), key=belong.count) # 找到测试点所属的分类return self.nearest, bdef show_train():
plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]')
plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]')
plt.xlabel('sepal length')
plt.ylabel('sepal width')if __name__ == "__main__":
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']data = https://www.it610.com/article/np.array(df.iloc[:100, [0, 1, -1]])
train, test = train_test_split(data, test_size=0.1)
x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0])
x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])kdt = KdTree()
kdt.create(train)
kdt.preOrder(kdt.KdTree)score = 0
for x in test:
show_train()
plt.scatter(x[0], x[1], c='red', marker='x')# 测试点
near, belong = kdt.search(x[:-1], 5)# 设置临近点的个数
if belong == x[-1]:
score += 1
print(x, "predict:", belong)
print("nearest:")
for n in near:
print(n[1].data, "dist:", n[0])
plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+')# k个最近邻点
plt.legend()
plt.show()score /= len(test)
print("score:", score)
声明:此文章为本人学习笔记,参考于:https://zhuanlan.zhihu.com/p/23966698
如果您觉得有用,欢迎关注我的公众号,我会不定期发布自己的学习笔记、AI资料、以及感悟,欢迎留言,与大家一起探索AI之路。
文章图片
AI探索之路
推荐阅读
- 热闹中的孤独
- JAVA(抽象类与接口的区别&重载与重写&内存泄漏)
- 放屁有这三个特征的,请注意啦!这说明你的身体毒素太多
- 一个人的旅行,三亚
- 布丽吉特,人生绝对的赢家
- 慢慢的美丽
- 尽力
- 一个小故事,我的思考。
- 家乡的那条小河
- 《真与假的困惑》???|《真与假的困惑》??? ——致良知是一种伟大的力量