给定一个二叉搜索树(BST),找到树中第 K 小的节点

题目:给定一个二叉搜索树(BST),找到树中第 K 小的节点。
参考答案:
* 考察点

  1. 基础数据结构的理解和编码能力
  2. 递归使用
* 示例
5 / \ 36 / \ 24 / 1

说明:保证输入的 K 满足 1< =K< =(节点数目)
解法1:树相关的题目,第一眼就想到递归求解,左右子树分别遍历。联想到二叉搜索树的性质,root 大于左子树,小于右子树,如果左子树的节点数目等于 K-1,那么 root 就是结果,否则如果左子树节点数目小于 K-1,那么结果必然在右子树,否则就在左子树。因此在搜索的时候同时返回节点数目,跟 K 做对比,就能得出结果了。
/** * Definition for a binary tree node. **/public class TreeNode { int val; TreeNode left; TreeNode right; TreeNode(int x) { val = x; } }class Solution { private class ResultType {boolean found; // 是否找到int val; // 节点数目 ResultType(boolean found, int val) { this.found = found; this.val = val; } }public int kthSmallest(TreeNode root, int k) { return kthSmallestHelper(root, k).val; }private ResultType kthSmallestHelper(TreeNode root, int k) { if (root == null) { return new ResultType(false, 0); }ResultType left = kthSmallestHelper(root.left, k); // 左子树找到,直接返回 if (left.found) { return new ResultType(true, left.val); }// 左子树的节点数目 = K-1,结果为 root 的值 if (k - left.val == 1) { return new ResultType(true, root.val); }// 右子树寻找 ResultType right = kthSmallestHelper(root.right, k - left.val - 1); if (right.found) { return new ResultType(true, right.val); }// 没找到,返回节点总数 return new ResultType(false, left.val + 1 + right.val); } }

【给定一个二叉搜索树(BST),找到树中第 K 小的节点】解法2:基于二叉搜索树的特性,在中序遍历的结果中,第k个元素就是本题的解。 最差的情况是k节点是bst的最右叶子节点,不过每个节点的遍历次数最多是1次。 遍历并不是需要全部做完,使用计数的方式,找到第k个元素就可以退出。 下面是go的一个简单实现。
// BST is binary search tree type BST struct { key, valueint left, right *BST }func (bst *BST) setLeft(b *BST) { bst.left = b }func (bst *BST) setRight(b *BST) { bst.right = b }// count 查找bst第k个节点的值,未找到就返回0 func count(bst *BST, k int) int { if k < 1 { return 0 } c := 0 ok, value := countRecursive(bst, & c, k) if ok { return value } return 0 }// countRecurisive 对bst使用中序遍历 // 用计数方式控制退出遍历,参数c就是已遍历节点数 func countRecursive(bst *BST, c *int, k int) (bool, int) { if bst.left != nil { ok, value := countRecursive(bst.left, c, k) if ok { return ok, value } } if *c == k-1 { return true, bst.value } *c++ if bst.right != nil { ok, value := countRecursive(bst.right, c, k) if ok { return ok, value } } return false, 0 }// 下面是测试代码,覆盖了退化的情况和普通bstfunc createBST1() *BST { b1 := & BST{key: 1, value: 10} b2 := & BST{key: 2, value: 20} b3 := & BST{key: 3, value: 30} b4 := & BST{key: 4, value: 40} b5 := & BST{key: 5, value: 50} b6 := & BST{key: 6, value: 60} b7 := & BST{key: 7, value: 70} b8 := & BST{key: 8, value: 80} b9 := & BST{key: 9, value: 90} b9.setLeft(b8) b8.setLeft(b7) b7.setLeft(b6) b6.setLeft(b5) b5.setLeft(b4) b4.setLeft(b3) b3.setLeft(b2) b2.setLeft(b1) return b9 }func createBST2() *BST { b1 := & BST{key: 1, value: 10} b2 := & BST{key: 2, value: 20} b3 := & BST{key: 3, value: 30} b4 := & BST{key: 4, value: 40} b5 := & BST{key: 5, value: 50} b6 := & BST{key: 6, value: 60} b7 := & BST{key: 7, value: 70} b8 := & BST{key: 8, value: 80} b9 := & BST{key: 9, value: 90} b1.setRight(b2) b2.setRight(b3) b3.setRight(b4) b4.setRight(b5) b5.setRight(b6) b6.setRight(b7) b7.setRight(b8) b8.setRight(b9) return b1 }func createBST3() *BST { b1 := & BST{key: 1, value: 10} b2 := & BST{key: 2, value: 20} b3 := & BST{key: 3, value: 30} b4 := & BST{key: 4, value: 40} b5 := & BST{key: 5, value: 50} b6 := & BST{key: 6, value: 60} b7 := & BST{key: 7, value: 70} b8 := & BST{key: 8, value: 80} b9 := & BST{key: 9, value: 90} b5.setLeft(b3) b5.setRight(b7) b3.setLeft(b2) b3.setRight(b4) b2.setLeft(b1) b7.setLeft(b6) b7.setRight(b8) b8.setRight(b9) return b5 }func createBST4() *BST { b := & BST{key: 1, value: 10} last := b for i := 2; i < 100000; i++ { n := & BST{key: i, value: i * 10} last.setRight(n)last = n } return b }func createBST5() *BST { b := & BST{key: 99999, value: 999990} last := b for i := 99998; i > 0; i-- { n := & BST{key: i, value: i * 10} last.setLeft(n)last = n } return b }func createBST6() *BST { b := & BST{key: 50000, value: 500000} last := b for i := 49999; i > 0; i-- { n := & BST{key: i, value: i * 10} last.setLeft(n)last = n } last = b for i := 50001; i < 100000; i++ { n := & BST{key: i, value: i * 10} last.setRight(n)last = n } return b }func TestK(t *testing.T) { bst1 := createBST1() bst2 := createBST2() bst3 := createBST3() bst4 := createBST4() check(t, bst1, 1, 10) check(t, bst1, 2, 20) check(t, bst1, 3, 30) check(t, bst1, 4, 40) check(t, bst1, 5, 50) check(t, bst1, 6, 60) check(t, bst1, 7, 70) check(t, bst1, 8, 80) check(t, bst1, 9, 90) check(t, bst2, 1, 10) check(t, bst2, 2, 20) check(t, bst2, 3, 30) check(t, bst2, 4, 40) check(t, bst2, 5, 50) check(t, bst2, 6, 60) check(t, bst2, 7, 70) check(t, bst2, 8, 80) check(t, bst2, 9, 90) check(t, bst3, 1, 10) check(t, bst3, 2, 20) check(t, bst3, 3, 30) check(t, bst3, 4, 40) check(t, bst3, 5, 50) check(t, bst3, 6, 60) check(t, bst3, 7, 70) check(t, bst3, 8, 80) check(t, bst3, 9, 90) check(t, bst4, 1, 10) check(t, bst4, 2, 20) check(t, bst4, 3, 30) check(t, bst4, 4, 40) check(t, bst4, 5, 50) check(t, bst4, 6, 60) check(t, bst4, 7, 70) check(t, bst4, 8, 80) check(t, bst4, 9, 90) check(t, bst4, 99991, 999910) check(t, bst4, 99992, 999920) check(t, bst4, 99993, 999930) check(t, bst4, 99994, 999940) check(t, bst4, 99995, 999950) check(t, bst4, 99996, 999960) check(t, bst4, 99997, 999970) check(t, bst4, 99998, 999980) check(t, bst4, 99999, 999990) }func check(t *testing.T, b *BST, k, value int) { t.Helper() checkCall(t, b, k, value, count) // 此处可添加其他解法的实现 }func checkCall(t *testing.T, b *BST, k, value int, find func(bst *BST, kth int) int) { t.Helper() got := find(b, k) if got != value { t.Fatalf("want:%d, got:%d", value, got) } }

    推荐阅读