Problem

There exists an undirected tree rooted at node 0 with n nodes labeled from 0 to n - 1. You are given a 2D integer array edges of length n - 1, where edges[i] = [aᵢ, bᵢ] indicates that there is an edge between nodes aᵢ and bᵢ in the tree. You are also given a 0-indexed array coins of size n where coins[i] indicates the number of coins in the vertex i, and an integer k.

Starting from the root, you have to collect all the coins such that the coins at a node can only be collected if the coins of its ancestors have been already collected.

Coins at nodeᵢ can be collected in one of the following ways:

  • Collect all the coins, but you will get coins[i] - k points. If coins[i] - k is negative then you will lose abs(coins[i] - k) points.
  • Collect all the coins, but you will get floor(coins[i] / 2) points. If this way is used, then for all the nodeⱼ present in the subtree of nodeᵢ, coins[j] will get reduced to floor(coins[j] / 2).

Return the maximum points you can get after collecting the coins from all the tree nodes.

https://leetcode.cn/problems/maximum-points-after-collecting-coins-from-all-nodes/

Example 1:

Input: edges = [[0,1],[1,2],[2,3]], coins = [10,10,3,3], k = 5
Output: 11
Explanation:
Collect all the coins from node 0 using the first way. Total points = 10 - 5 = 5.
Collect all the coins from node 1 using the first way. Total points = 5 + (10 - 5) = 10.
Collect all the coins from node 2 using the second way so coins left at node 3 will be floor(3 / 2) = 1. Total points = 10 + floor(3 / 2) = 11.
Collect all the coins from node 3 using the second way. Total points = 11 + floor(1 / 2) = 11.
It can be shown that the maximum points we can get after collecting coins from all the nodes is 11.

Example 2:

Input: edges = [[0,1],[0,2]], coins = [8,4,4], k = 0
Output: 16
Explanation:
Coins will be collected from all the nodes using the first way. Therefore, total points = (8 - 0) + (4 - 0) + (4 - 0) = 16.
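As a sanity check on the two collection rules, here is a minimal brute-force sketch that tries both ways at every node and keeps the best total. The function name `brute_force_max_points` is hypothetical and not part of the original solution. It enumerates all 2ⁿ assignments, so it is only usable on tiny trees such as the two examples.

```python
from itertools import product


def brute_force_max_points(edges: list[list[int]], coins: list[int], k: int) -> int:
    """Try every assignment of way 1 / way 2 to the nodes (tiny trees only)."""
    n = len(coins)
    adj: list[list[int]] = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Visit order in which every parent appears before its children (root is 0).
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    best = None
    for choice in product((1, 2), repeat=n):  # choice[u] is the way used at node u
        halvings = [0] * n  # how many ancestors of u used way 2
        total = 0
        for u in order:
            p = parent[u]
            if p != -1:
                halvings[u] = halvings[p] + (1 if choice[p] == 2 else 0)
            current = coins[u] >> halvings[u]  # repeated floor-halving is a right shift
            total += current - k if choice[u] == 1 else current >> 1
        best = total if best is None else max(best, total)
    return best
```

Calling it with the inputs of Example 1 and Example 2 reproduces the stated outputs 11 and 16.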

Constraints:

  • n == coins.length
  • 2 <= n <= 10⁵
  • 0 <= coins[i] <= 10⁴
  • edges.length == n - 1
  • 0 <= edges[i][0], edges[i][1] < n
  • 0 <= k <= 10⁴

Test Cases

class Solution:
    def maximumPoints(self, edges: List[List[int]], coins: List[int], k: int) -> int:
solution_test.py
import pytest

from solution import Solution


@pytest.mark.parametrize('edges, coins, k, expected', [
    ([[0,1],[1,2],[2,3]], [10,10,3,3], 5, 11),
    ([[0,1],[0,2]], [8,4,4], 0, 16),
])
@pytest.mark.parametrize('sol', [Solution()])
def test_solution(sol, edges, coins, k, expected):
    assert sol.maximumPoints(edges, coins, k) == expected

Thoughts

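Define dp(u, i) as the maximum points obtainable from the subtree rooted at u, given that the ancestors of u have already applied the second way i times. Then, with v ranging over the children of u: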

$$dp(u,i)=\max\begin{cases} (coins[u]\gg i)-k+\sum_{v}{dp(v,i)} \\ (coins[u]\gg(i+1))+\sum_{v}{dp(v,i+1)} \end{cases}$$

That is, apply either the first way or the second way to node u and take the larger result.
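The recurrence can be transcribed almost literally as a memoized top-down function. The following is a minimal sketch, not the solution used later in this post: the names `max_points_top_down`, `dp`, and `max_shift` are hypothetical, and it assumes the tree is shallow enough for Python's default recursion limit.

```python
from functools import lru_cache


def max_points_top_down(edges: list[list[int]], coins: list[int], k: int) -> int:
    n = len(coins)
    adj: list[list[int]] = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # After this many halvings every coin count in the tree is 0.
    max_shift = max(coins).bit_length()

    @lru_cache(maxsize=None)
    def dp(u: int, parent: int, i: int) -> int:
        if i >= max_shift:
            # All coins below are 0: way 2 scores 0 everywhere, way 1 scores -k <= 0.
            return 0
        way1 = (coins[u] >> i) - k    # children keep the same i
        way2 = coins[u] >> (i + 1)    # children see one more halving
        for v in adj[u]:
            if v != parent:
                way1 += dp(v, u, i)
                way2 += dp(v, u, i + 1)
        return max(way1, way2)

    return dp(0, -1, 0)
```

Besides the two examples, a useful check is a path of 15 nodes with every `coins[i] = 9999` and `k = 10000`. The first way always loses points there, so the best plan halves at every node and collects 4999 + 2499 + ... + 2 + 1 = 9991, with the last two nodes contributing 0.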

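Since coins[u] is at most 10⁴ < 2¹⁴, coins[u] >> i is always 0 once i ≥ 14, so i only needs to range over 0 to 13. i also never exceeds the height of the tree, but in this problem the height can reach O(n), so that bound does not help.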
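The bound of 14 can be checked directly from the constraint on coins[i]. This is a small illustrative sketch; the names `MAX_COIN`, `max_shift`, and `remaining` are made up for the example.

```python
MAX_COIN = 10**4  # upper bound on coins[i] from the constraints

# 2**13 = 8192 <= 10**4 < 16384 = 2**14, so the bit length is 14.
max_shift = MAX_COIN.bit_length()

# Value left after i halvings, for i = 0 .. 14.
remaining = [MAX_COIN >> i for i in range(max_shift + 1)]
```

Here `remaining[13]` is still 1 while `remaining[14]` is 0, which is why the indices 0 to 13 are enough.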

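Traverse the tree in post-order: once all children of a node have been processed, combine their results into the result for that node.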

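The implementation here uses plain recursion. The recursion can also be simulated with a stack and a loop (similar to 3249. Count the Number of Good Nodes), though that is not necessarily faster than recursing directly. Note that the tree height can reach 10⁵, which exceeds CPython's default recursion limit of 1000, so running the recursive version outside a judge on a very deep tree requires raising the limit.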
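For reference, a stack-based variant could look like the sketch below. It is one possible way to remove the recursion and not the code used in this post: the name `max_points_iterative` and the `sums` layout are assumptions made for illustration, and it keeps one array per node, so it uses O(n * log m) memory.

```python
def max_points_iterative(edges: list[list[int]], coins: list[int], k: int) -> int:
    n = len(coins)
    adj: list[list[int]] = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Preorder with an explicit stack; reversing it puts children before parents.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    width = max(coins).bit_length()  # coins >> width is 0 for every node
    # sums[u][i] first holds the sum of dp(child, i), then dp(u, i) itself.
    # The extra slot at index `width` stays 0 and stands for dp(., width) = 0.
    sums = [[0] * (width + 1) for _ in range(n)]

    for u in reversed(order):
        acc = sums[u]
        coin = coins[u]
        for i in range(width):
            way1 = (coin >> i) - k + acc[i]
            way2 = (coin >> (i + 1)) + acc[i + 1]  # acc[i + 1] is still the child sum
            acc[i] = max(way1, way2)
        p = parent[u]
        if p != -1:
            parent_acc = sums[p]
            for i in range(width):
                parent_acc[i] += acc[i]

    return sums[0][0]
```

Because nothing recurses, a path of several thousand nodes is handled without touching the recursion limit.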

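Time complexity is O(n * log m), where m is 10⁴ or the maximum value in coins. Space complexity is O(n * log m) in the worst case, or more precisely O(h * log m) for the recursion plus O(n) for the adjacency list, where h is the height of the tree (h = O(log n) for a balanced tree, h = O(n) for a path).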

Code

solution.py
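class Solution:
    def maximumPoints(self, edges: list[list[int]], coins: list[int], k: int) -> int:
        # Undirected adjacency list; dfs skips the parent to stay rooted at node 0.
        tree: list[list[int]] = [[] for _ in range(len(coins))]
        for u, v in edges:
            tree[u].append(v)
            tree[v].append(u)

        MAX = 14  # (10**4).bit_length() or max(coins).bit_length()
        max2 = lambda a, b: a if a >= b else b

        def dfs(root: int, parent: int) -> list[int]:
            # points[i] first accumulates sum of dp(child, i), then becomes dp(root, i).
            points = [0] * MAX
            for child in tree[root]:
                if child != parent:
                    child_points = dfs(child, root)
                    for i, point in enumerate(child_points):
                        # dp values are non-negative and non-increasing in i,
                        # so everything after the first 0 is also 0.
                        if point == 0:
                            break
                        points[i] += point

            coin = coins[root]
            for i in range(MAX - 1):
                p1 = coin - k + points[i]  # first way on coins[root] >> i
                coin >>= 1
                p2 = coin + points[i + 1]  # second way
                points[i] = max2(p1, p2)
            # i = MAX - 1: the second way scores 0 here and in the whole subtree,
            # so compare the first way (plus the children's sum) against 0.
            points[-1] = max2(coin - k + points[-1], 0)
            return points

        return dfs(0, -1)[0]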