2|2 Kd树的构造与搜索

1 KD-Tree 实现kNN算法时,最简单的实现方法就是线性扫描,正如我们上一章节内容介绍的一样->K近邻算法,需要计算输入实例与每一个训练样本的距离。当训练集很大时,会非常耗时。
为了提高kNN搜索的效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离的次数,KD-Tree就是其中的一种方法。

kd树是一个二叉树结构,相当于不断的用垂线将k维空间进行切分,构成一系列的k维超矩形区域。
2 如何构造KD-Tree 2.1 KD-Tree算法如下: K维空间数据集
其中
  1. 构造根节点
    选择为坐标轴,将T中所有实例以坐标为中位数,垂直轴切成两个矩形,由根节点生成深度为1的左、右两个子节点:左子节点对应的坐标都小于切分点,右子节点坐标都大于切分点坐标。
  2. 重复:对深度为j的节点,选择为切分的坐标轴, ,以该节点再次将矩形区域切分为两个子区域。
  3. 直到两个子区域没有实力存在时停止,从而形成KD-Tree的区域划分。
2.2 举例说明KD-Tree构造 随机生成 13 个点作为我们的数据集
2|2 Kd树的构造与搜索
文章图片
13个随机点分布 首先先沿 x 坐标进行切分,我们选出 x 坐标的中位点,获取最根部节点的坐标
2|2 Kd树的构造与搜索
文章图片
根结点 并且按照该点的x坐标将空间进行切分,所有 x 坐标小于 6.27 的数据用于构建左分支,x坐标大于 6.27 的点用于构建右分支。
2|2 Kd树的构造与搜索
文章图片
在下一步中 ,对应 y 轴,左右两边再按照 y 轴的排序进行切分,中位点记载于左右枝的节点。得到下面的树,左边的 x 是指这该层的节点都是沿 x 轴进行分割的。

2|2 Kd树的构造与搜索
文章图片

空间的切分如下
2|2 Kd树的构造与搜索
文章图片
下一步中,对应 x 轴,所以下面再按照 x 坐标进行排序和切分,有

2|2 Kd树的构造与搜索
文章图片

2|2 Kd树的构造与搜索
文章图片
最后只剩下了叶子结点,就此完成了 kd 树的构造。
2|2 Kd树的构造与搜索
文章图片
2|2 Kd树的构造与搜索
文章图片
2.3 构造代码
class Node: def __init__(self, data, depth=0, lchild=None, rchild=None): self.data = https://www.it610.com/article/data# 此结点 self.depth = depth# 树的深度 self.lchild = lchild # 左子结点 self.rchild = rchild # 右子节点class KdTree: def __init__(self): self.KdTree = None self.n = 0 self.nearest = Nonedef create(self, dataSet, depth=0):"""KD-Tree创建过程""" if len(dataSet) > 0: m, n = np.shape(dataSet) self.n = n - 1 # 按照哪个维度进行分割,比如0:x轴,1:y轴 axis = depth % self.n # 中位数 mid = int(m / 2) # 按照第几个维度(列)进行排序 dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) # KD结点为中位数的结点,树深度为depth node = Node(dataSetcopy[mid], depth) if depth == 0: self.KdTree = node # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度 node.lchild = self.create(dataSetcopy[:mid], depth+1) node.rchild = self.create(dataSetcopy[mid+1:], depth+1) return node return None

3 搜索KD-Tree 输入:已构造的kd树,目标点x
输出:x的k个最近邻集合L
3.1 KD-Tree的最近邻搜索算法
  1. 从根结点出发,递归向下访问KD-Tree,如果目标点x当前维小于切分点坐标,移动到左子节点,否则右子节点,直到子节点为叶子结点为止。
  2. 以此叶子结点为最近邻的点,插入到集合L中
  3. 递归向上回退,在这个节点进行以下操作:
  • a 如果该节点比L里的点更近,则替换集合L中距离最大的点。
  • b 目标点到此节点的分割线垂直的距离为d,判断集合L中距离最大的点与 d 相比较,如果比d大,说明d的另一侧区域中有可能有比集合L中距离要小,因此需要查看d的左右两个子节点的距离。
    如果集合L中距离最大的点比 d小,那说明另一侧区域的点距离目标点的距离都比d大,因此不用查找了,继续向上回退。
  1. 当回退到根结点时,搜索结束,最后的集合L里的k个点,就是x的最近邻点。
3.2 时间复杂度 KD-Tree的平均时间复杂度为,N为训练样本的数量。
KD-Tree试用于训练样本数远大于空间维度的k近邻搜索。当空间维数接近训练样本数时,他的效率会迅速下降,几乎接近线性扫描。
3.3 实例说明 设我们想查询的点为 p=(?1,?5),设距离函数是普通的距离,我们想找距离目标点最近的 k=3 个点。如下:
2|2 Kd树的构造与搜索
文章图片
首先我们按照构造好的KD-Tree,从根结点开始查找
2|2 Kd树的构造与搜索
文章图片
和这个节点的 x 轴比较一下,p 的 x 轴更小。因此我们向左枝进行搜索:

2|2 Kd树的构造与搜索
文章图片
2|2 Kd树的构造与搜索
文章图片
接下来需要对比 y 轴

2|2 Kd树的构造与搜索
文章图片
p 的 y 值更小,因此向左枝进行搜索:
2|2 Kd树的构造与搜索
文章图片
这个节点只有一个子枝,就不需要对比了。由此找到了叶子节点 (?4.6,?10.55)。

2|2 Kd树的构造与搜索
文章图片
在二维图上是蓝色的点

2|2 Kd树的构造与搜索
文章图片
此时我们要执行第二步,将当前结点插入到集合L中,并记录下 L=[(?4.6,?10.55)]。访问过的节点就在二叉树上显示为被划掉的好了。
然后执行第三步,不是最顶端节点。我回退。上面的结点是 (?6.88,?5.4)。
2|2 Kd树的构造与搜索
文章图片
2|2 Kd树的构造与搜索
文章图片
执行 3a,因为我们记录下的点只有一个,小于 k=3,所以也将当前节点记录下,插入到集合L中,有 L=[(?4.6,?10.55),(?6.88,?5.4)].。 因为当前节点的左枝是空的,所以直接跳过,继续回退,判断不是顶部根节点
2|2 Kd树的构造与搜索
文章图片
2|2 Kd树的构造与搜索
文章图片
由于还是不够三个点,于是将当前点也插入到集合L中,有 L=[(?4.6,?10.55),(?6.88,?5.4),(1.24,?2.86)]。
此时发现,当前节点有其他的分枝,执行3b,计算得出 p 点和 L 中的三个点的距离分别是 6.62, 5.89, 3.10,但是 p 和当前节点的分割线的距离只有 2.14,小于与 L 的最大距离:
2|2 Kd树的构造与搜索
文章图片
到垂线距离小于L中最大的距离,说明垂线的另一侧可能有更近的点 因此,在分割线的另一端可能有更近的点。于是我们在当前结点的另一个分枝从头执行步骤1。好,我们在红线这里:
2|2 Kd树的构造与搜索
文章图片
此时处于x轴切分,因此要用 p 和这个节点比较 x 坐标:
2|2 Kd树的构造与搜索
文章图片
p 的 x 坐标更大,因此探索右枝 (1.75,12.26),并且发现右枝已经是最底部节点,执行步骤2与3a。

2|2 Kd树的构造与搜索
文章图片
经计算,(1.75,12.26) 与 p 的距离是 17.48,要大于 p 与 L 的距离,因此我们不将其放入记录中。

2|2 Kd树的构造与搜索
文章图片
然后 回退,判断出不是顶端节点,往上爬。
2|2 Kd树的构造与搜索
文章图片
执行3a,这个节点与 p 的距离是 4.91,要小于 p 与 L 的最大距离 6.62。

2|2 Kd树的构造与搜索
文章图片
因此,我们用这个新的节点替代 L 中离 p 最远的 (?4.6,?10.55)。
2|2 Kd树的构造与搜索
文章图片
然后3b,我们比对 p 和当前节点的分割线的距离
2|2 Kd树的构造与搜索
文章图片
image 这个距离小于 L 与 p 的最大距离,因此我们要到当前节点的另一个枝执行步骤1。当然,那个枝只有一个点。
2|2 Kd树的构造与搜索
文章图片
计算距离发现这个点离 p 比 L 更远,因此不进行替代。
2|2 Kd树的构造与搜索
文章图片
然后回退,不是根结点,我们向上爬
2|2 Kd树的构造与搜索
文章图片
image 这个是已经访问过的了,所以再向上爬
2|2 Kd树的构造与搜索
文章图片
再爬
2|2 Kd树的构造与搜索
文章图片
【2|2 Kd树的构造与搜索】此时到顶点了。所以完了吗?当然不,还要执行3b呢。现在是步骤1的回合。
我们进行计算比对发现顶端节点与p的距离比L还要更远,因此不进行更新。
2|2 Kd树的构造与搜索
文章图片
然后计算 p 和分割线的距离发现也是更远。
2|2 Kd树的构造与搜索
文章图片
因此也不需要检查另一个分枝。
判断当前节点是顶点,因此计算完成!输出距离 p 最近的三个样本是 L=[(?6.88,?5.4),(1.24,?2.86),(?2.96,?2.5)]。
3.3 代码
def search(self, x, count=1): """KD-Tree的搜索""" nearest = [] # 记录近邻点的集合 for i in range(count): nearest.append([-1, None]) self.nearest = np.array(nearest)def recurve(node): """内方法,负责查找count个近邻点""" if node is not None: # 步骤1:怎么找叶子节点 # 在哪个维度的分割线,0,1,0,1表示x,y,x,y axis = node.depth % self.n # 判断往左走or右走,递归,找到叶子结点 daxis = x[axis] - node.data[axis] if daxis < 0: recurve(node.lchild) else: recurve(node.rchild)# 步骤2:满足的就插入到近邻点集合中 # 求test点与此点的距离 dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data))) # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的 for i, d in enumerate(self.nearest): if d[0] < 0 or dist < d[0]: self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) self.nearest = self.nearest[:-1] break# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧 n = list(self.nearest[:, 0]).count(-1) # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离) # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点 if self.nearest[-n-1, 0] > abs(daxis): # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点 if daxis < 0: recurve(node.rchild) else: recurve(node.lchild)recurve(self.KdTree)# 调用根节点,开始查找 knn = self.nearest[:, 1]# knn为k个近邻结点 belong = []# 记录k个近邻结点的分类 for i in knn: belong.append(i.data[-1]) b = max(set(belong), key=belong.count) # 找到测试点所属的分类return self.nearest, b

4 整体代码
import numpy as np from math import sqrt import pandas as pd from sklearn.datasets import load_iris import matplotlib.pyplot as plt from sklearn.model_selection import train_test_splitclass Node: def __init__(self, data, depth=0, lchild=None, rchild=None): self.data = https://www.it610.com/article/data# 此结点 self.depth = depth# 树的深度 self.lchild = lchild # 左子结点 self.rchild = rchild # 右子节点class KdTree: def __init__(self): self.KdTree = None self.n = 0 self.nearest = Nonedef create(self, dataSet, depth=0):"""KD-Tree创建过程""" if len(dataSet) > 0: m, n = np.shape(dataSet) self.n = n - 1 # 按照哪个维度进行分割,比如0:x轴,1:y轴 axis = depth % self.n # 中位数 mid = int(m / 2) # 按照第几个维度(列)进行排序 dataSetcopy = sorted(dataSet, key=lambda x: x[axis]) # KD结点为中位数的结点,树深度为depth node = Node(dataSetcopy[mid], depth) if depth == 0: self.KdTree = node # 前mid行为左子结点,此时行数m改变,深度depth+1,axis会换个维度 node.lchild = self.create(dataSetcopy[:mid], depth+1) node.rchild = self.create(dataSetcopy[mid+1:], depth+1) return node return Nonedef preOrder(self, node): """遍历KD-Tree""" if node is not None: print(node.depth, node.data) self.preOrder(node.lchild) self.preOrder(node.rchild)def search(self, x, count=1): """KD-Tree的搜索""" nearest = [] # 记录近邻点的集合 for i in range(count): nearest.append([-1, None]) self.nearest = np.array(nearest)def recurve(node): """内方法,负责查找count个近邻点""" if node is not None: # 步骤1:怎么找叶子节点 # 在哪个维度的分割线,0,1,0,1表示x,y,x,y axis = node.depth % self.n # 判断往左走or右走,递归,找到叶子结点 daxis = x[axis] - node.data[axis] if daxis < 0: recurve(node.lchild) else: recurve(node.rchild)# 步骤2:满足的就插入到近邻点集合中 # 求test点与此点的距离 dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.data))) # 遍历k个近邻点,如果不满k个,直接加入,如果距离比已有的近邻点距离小,替换掉,距离是从小到大排序的 for i, d in enumerate(self.nearest): if d[0] < 0 or dist < d[0]: self.nearest = np.insert(self.nearest, i, [dist, node], axis=0) self.nearest = self.nearest[:-1] break# 步骤3:判断与垂线的距离,如果比这大,要查找垂线的另一侧 n = list(self.nearest[:, 0]).count(-1) # -n-1表示不为-1的最后一行,就是记录最远的近邻点(也就是最大的距离) # 如果大于到垂线之间的距离,表示垂线的另一侧可能还有比他离的近的点 if self.nearest[-n-1, 0] > abs(daxis): # 如果axis < 0,表示测量点在垂线的左侧,因此要在垂线右侧寻找点 if daxis < 0: recurve(node.rchild) else: recurve(node.lchild)recurve(self.KdTree)# 调用根节点,开始查找 knn = self.nearest[:, 1]# knn为k个近邻结点 belong = []# 记录k个近邻结点的分类 for i in knn: belong.append(i.data[-1]) b = max(set(belong), key=belong.count) # 找到测试点所属的分类return self.nearest, bdef show_train(): plt.scatter(x0[:, 0], x0[:, 1], c='pink', label='[0]') plt.scatter(x1[:, 0], x1[:, 1], c='orange', label='[1]') plt.xlabel('sepal length') plt.ylabel('sepal width')if __name__ == "__main__": iris = load_iris() df = pd.DataFrame(iris.data, columns=iris.feature_names) df['label'] = iris.target df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']data = https://www.it610.com/article/np.array(df.iloc[:100, [0, 1, -1]]) train, test = train_test_split(data, test_size=0.1) x0 = np.array([x0 for i, x0 in enumerate(train) if train[i][-1] == 0]) x1 = np.array([x1 for i, x1 in enumerate(train) if train[i][-1] == 1])kdt = KdTree() kdt.create(train) kdt.preOrder(kdt.KdTree)score = 0 for x in test: show_train() plt.scatter(x[0], x[1], c='red', marker='x')# 测试点 near, belong = kdt.search(x[:-1], 5)# 设置临近点的个数 if belong == x[-1]: score += 1 print(x, "predict:", belong) print("nearest:") for n in near: print(n[1].data, "dist:", n[0]) plt.scatter(n[1].data[0], n[1].data[1], c='green', marker='+')# k个最近邻点 plt.legend() plt.show()score /= len(test) print("score:", score)

声明:此文章为本人学习笔记,参考于:https://zhuanlan.zhihu.com/p/23966698
如果您觉得有用,欢迎关注我的公众号,我会不定期发布自己的学习笔记、AI资料、以及感悟,欢迎留言,与大家一起探索AI之路。
2|2 Kd树的构造与搜索
文章图片
AI探索之路

    推荐阅读