Problem

Given the root of a binary search tree, and an integer k, return the k-th smallest value (1-indexed) of all the values of the nodes in the tree.

https://leetcode.com/problems/kth-smallest-element-in-a-bst/
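As a point of comparison, here is a minimal baseline sketch that ignores the BST ordering. It gathers all n values, sorts them, and indexes with `k - 1` because k is 1-indexed. It costs O(n log n) time and O(n) extra space. `TreeNode` here is a local stand-in for LeetCode's node type, and `kth_smallest_by_sorting` is a hypothetical name.

```python
from typing import List, Optional


class TreeNode:
    # Minimal node type with the same fields as LeetCode's TreeNode.
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_by_sorting(root: Optional[TreeNode], k: int) -> int:
    # Collect every value with an explicit stack (visit order does not matter here).
    values: List[int] = []
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        values.append(node.val)
        if node.left:
            stack.append(node.left)
        if node.right:
            stack.append(node.right)
    values.sort()
    # k is 1-indexed, while Python lists are 0-indexed.
    return values[k - 1]


# Example 1: root = [3,1,4,null,2], k = 1
example1 = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_by_sorting(example1, 1))  # 1
```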

Example 1:

Input: root = [3,1,4,null,2], k = 1
Output: 1

Example 2:

Input: root = [5,3,6,2,4,null,null,1], k = 3
Output: 3

Constraints:

  • The number of nodes in the tree is n.
  • 1 <= k <= n <= 10^4
  • 0 <= Node.val <= 10^4

Follow up: If the BST is modified often (i.e., we can do insert and delete operations) and you need to find the kth smallest frequently, how would you optimize?

Test Cases

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        ...
solution_test.py
import pytest

import sys
sys.path.append('..')
from _utils.binary_tree import build_tree
from solution import Solution

null = None


@pytest.mark.parametrize('root, k, expected', [
    ([3,1,4,null,2], 1, 1),
    ([5,3,6,2,4,null,null,1], 3, 3),
])
class Test:
    def test_solution(self, root, k, expected):
        sol = Solution()
        assert sol.kthSmallest(build_tree(root), k) == expected
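The tests rely on `build_tree` from a local `_utils.binary_tree` module that this note does not include. Here is a hypothetical stand-in, based on an assumption about what that helper does. It decodes LeetCode's level-order list format: `null`/`None` marks a missing child, and children are assigned only to non-null nodes, in queue order. The real helper may differ.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    # Hypothetical stand-in for _utils.binary_tree.build_tree.
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])  # nodes still waiting for their children
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next slot is this node's left child.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        # The slot after that (if present) is its right child.
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


tree = build_tree([5, 3, 6, 2, 4, None, None, 1])
print(tree.left.left.left.val)  # 1
```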

Thoughts

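Simply traverse the BST in order (in-order, LNR: left subtree, node, right subtree). An in-order traversal of a BST visits values in ascending order, so the value of the k-th node visited is the k-th smallest value.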
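The same idea can be written recursively. In this sketch, a generator yields values in in-order order, and `itertools.islice` stops after k of them, so the rest of the tree is never visited. `inorder` and `kth_smallest_recursive` are hypothetical names. Each level of recursion adds a nested generator frame. A degenerate tree with close to 10^4 nodes could therefore exceed CPython's default recursion limit (1000), which an explicit stack avoids.

```python
from itertools import islice
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right


def inorder(node: Optional[TreeNode]) -> Iterator[int]:
    # Yield values in L -> N -> R order; for a BST this is ascending order.
    if node is None:
        return
    yield from inorder(node.left)
    yield node.val
    yield from inorder(node.right)


def kth_smallest_recursive(root: Optional[TreeNode], k: int) -> int:
    # Skip the first k - 1 values and take the next one; the generator is
    # never advanced past the k-th node.
    return next(islice(inorder(root), k - 1, None))


# Example 2: root = [5,3,6,2,4,null,null,1], k = 3
example2 = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(kth_smallest_recursive(example2, 3))  # 3
```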

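Time complexity: O(H + k), where H is the height of the tree, which is O(n) in the worst case. The traversal stops as soon as the k-th node is visited. However, an in-order traversal must first walk down the entire left spine before visiting anything. So even finding the smallest value (k = 1) can mean passing through every other node first when the tree is degenerate, e.g. when every node has only a left child. Space is O(H) for the stack.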
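The worst case can be made concrete with a sketch. It uses the same stack-based walk as the Code section, instrumented with a hypothetical `pushed` counter, and compares k = 1 on a left-skewed chain against a balanced tree over the same 1023 keys. `left_skewed` and `balanced` are hypothetical helpers that build the two shapes.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_counting(root: Optional[TreeNode], k: int) -> Tuple[int, int]:
    # Iterative in-order walk that also counts stack pushes before returning.
    pushed = 0
    visited = 0
    stack = []
    node = root
    while node or stack:
        if node:
            stack.append(node)
            pushed += 1
            node = node.left
        else:
            node = stack.pop()
            visited += 1
            if visited == k:
                return node.val, pushed
            node = node.right
    raise ValueError("k is larger than the number of nodes")


def left_skewed(n: int) -> Optional[TreeNode]:
    # Chain n -> n-1 -> ... -> 1 using only left children: the shape produced
    # by inserting n, n-1, ..., 1 into a BST that is never rebalanced.
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    # Balanced BST over the keys lo..hi, rooted at the middle key.
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


print(kth_smallest_counting(left_skewed(1023), 1))  # (1, 1023)
print(kth_smallest_counting(balanced(1, 1023), 1))  # (1, 10)
```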

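So for the follow-up: a BST that is modified often may become unbalanced, and finding the k-th smallest node can then take a long time. To address this, the BST can be rebalanced (e.g. kept as an AVL or red-black tree), which keeps H at O(log n) so each query costs O(log n + k).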

Code

solution.py
from typing import Optional

import sys
sys.path.append('..')
from _utils.binary_tree import TreeNode


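# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        i = 0       # number of nodes visited so far, in ascending (in-order) order
        stack = []  # ancestors whose left subtree is still being explored
        while root or stack:
            if root:
                # Go as far left as possible, remembering the path.
                stack.append(root)
                root = root.left
            else:
                # Left subtree finished: visit the node on top of the stack.
                root = stack.pop()
                if (i := i + 1) == k:  # walrus operator, Python 3.8+
                    return root.val
                # Then traverse its right subtree.
                root = root.right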