Problem

Given an integer n, return all the structurally unique BSTs (binary search trees) that have exactly n nodes with unique values from 1 to n. Return the answer in any order.

https://leetcode.com/problems/unique-binary-search-trees-ii/

Example 1:

case1

Input: n = 3
Output: [[1,null,2,null,3],[1,null,3,2],[2,1,3],[3,1,null,null,2],[3,2,null,1]]

Example 2:

Input: n = 1
Output: [[1]]

Constraints:

  • 1 <= n <= 8

Test Cases

1
2
3
4
5
6
7
8
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def generateTrees(self, n: int) -> List[Optional[TreeNode]]:
        """Starter template: return every structurally unique BST whose nodes
        hold exactly the values 1..n, in any order.
        """
solution_test.py
import pytest

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from _utils.binary_tree import print_tree
from solution import Solution
from solution2 import Solution as Solution2

# Alias so the expected outputs below can be written in LeetCode's
# JSON-style notation, where a missing child is spelled `null`.
null = None


@pytest.mark.parametrize('n, expected', [
    (3, [[1, null, 2, null, 3], [1, null, 3, 2], [2, 1, 3], [3, 1, null, null, 2], [3, 2, null, 1]]),
    (1, [[1]]),
])
@pytest.mark.parametrize('sol', [Solution(), Solution2()])
def test_solution(sol, n, expected):
    """Each implementation must return exactly the expected collection of trees.

    Every returned root is converted with ``print_tree`` into the same list
    notation used by ``expected``. Both sides are then sorted, because the
    problem accepts the trees in any order.
    """
    serialized = sorted(print_tree(root) for root in sol.generateTrees(n))
    assert serialized == sorted(expected)

Thoughts

和 96. Unique Binary Search Trees 一样,只不过这里是要列举出所有可能的 BST 来(这回不能直接用卡塔兰数的数学公式了)。

用递归来实现吧,借助 Python 标准库中的 functools.cache 缓存中间结果来加速(顺带还能复用子树节点,因此返回的不同 BST 之间会共享部分子树)。

当然也可以用循环来做,自己维护缓存 dp[i][j] 记录从 i 到 j 的所有 BST。循环的时候,第一重对子问题的规模循环,可以避免引用到尚未计算的 dp。

Code

Recursively

solution.py
from functools import cache
from itertools import product
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    """Binary tree node holding ``val`` and optional ``left``/``right`` children."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def generateTrees(self, n: int) -> list[Optional[TreeNode]]:
        """Return every structurally unique BST whose nodes hold the values 1..n.

        For a value range ``[i, j]``, every candidate root ``k`` is combined
        with every left subtree over ``[i, k - 1]`` and every right subtree
        over ``[k + 1, j]``. Trees are produced in order of increasing root
        value, and then left-subtree-major order.

        Subtree lists are memoized per range, so the returned trees share
        subtree nodes with one another. Mutating one returned tree can
        therefore change others.

        Args:
            n: Number of nodes (the problem guarantees 1 <= n <= 8).

        Returns:
            One root per distinct BST. For ``n == 0`` (outside the stated
            constraints) the result is ``[None]``, the single empty tree.
        """

        @cache
        def gen(i: int, j: int) -> list[Optional[TreeNode]]:
            # An empty range has exactly one tree: the empty one. Returning
            # [None] rather than [] keeps the product below non-empty, so a
            # root can still get "no child" on that side.
            if i > j:
                return [None]

            trees = []
            for root in range(i, j + 1):
                lefts = gen(i, root - 1)
                rights = gen(root + 1, j)
                trees.extend(
                    TreeNode(root, left, right) for left, right in product(lefts, rights)
                )
            return trees

        return gen(1, n)

Iteratively

solution2.py
from itertools import product
from typing import Optional


# Definition for a binary tree node.
class TreeNode:
    """Binary tree node: a value plus optional left and right children."""

    def __init__(self, val=0, left=None, right=None):
        self.val, self.left, self.right = val, left, right


class Solution:
    def generateTrees(self, n: int) -> list[Optional[TreeNode]]:
        """Build every structurally unique BST over the values 1..n, bottom-up.

        ``table[lo][hi]`` collects all trees over the value range ``lo..hi``.
        Every cell starts as ``[None]``, which also serves as the answer for
        empty ranges (``hi < lo``). Ranges are filled in order of increasing
        length, so each cell reads only shorter ranges that are already
        complete. Returned trees share subtree objects taken from the table.
        """
        width = n + 2
        table = [[[None] for _ in range(width)] for _ in range(width)]

        for span in range(1, n + 1):
            for lo in range(1, n - span + 2):
                hi = lo + span - 1
                # Root-major, then left-subtree-major, order of construction.
                table[lo][hi] = [
                    TreeNode(root, left, right)
                    for root in range(lo, hi + 1)
                    for left, right in product(table[lo][root - 1], table[root + 1][hi])
                ]

        return table[1][n]