Problem

In this problem, a tree is an undirected graph that is connected and has no cycles.

You are given a graph that started as a tree with n nodes labeled from 1 to n, with one additional edge added. The added edge has two different vertices chosen from 1 to n, and was not an edge that already existed. The graph is represented as an array edges of length n where edges[i] = [aᵢ, bᵢ] indicates that there is an edge between nodes aᵢ and bᵢ in the graph.

Return an edge that can be removed so that the resulting graph is a tree of n nodes. If there are multiple answers, return the answer that occurs last in the input.

https://leetcode.com/problems/redundant-connection/

Example 1:

case1

Input: edges = [[1,2],[1,3],[2,3]]
Output: [2,3]

Example 2:

case2

Input: edges = [[1,2],[2,3],[3,4],[1,4],[1,5]]
Output: [1,4]

Constraints:

  • n == edges.length
  • 3 <= n <= 1000
  • edges[i].length == 2
  • 1 <= aᵢ < bᵢ <= edges.length
  • aᵢ != bᵢ
  • There are no repeated edges.
  • The given graph is connected.

Test Cases

1
2
class Solution:
def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
solution_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import pytest

from solution import Solution



@pytest.mark.parametrize('edges, expected', [
([[1,2],[1,3],[2,3]], [2,3]),
([[1,2],[2,3],[3,4],[1,4],[1,5]], [1,4]),

([[9,10],[5,8],[2,6],[1,5],[3,8],[4,9],[8,10],[4,10],[6,8],[7,9]], [4,10])
])
@pytest.mark.parametrize('sol', [Solution()])
def test_solution(sol, edges, expected):
result = sol.findRedundantConnection(edges)
assert result == expected

Thoughts

开始就直接用类似 1591. Strange Printer II207. Course Schedule 中提到的方法判断给定的图中是否有环。当然一定有环,而发现环的那条边就是环上的一条边,可以删掉。

结果发现题目要求如果有多个可行解,需要返回给定的边中最后出现的那条。有点儿麻烦,而且还要再花不少额外的处理时间。

另一个直观的想法是维护能连通到一起的点集。按顺序扫描所有的边,如果某条边的两个顶点本来就已经是连通的,说明这条边是冗余的。显然再之后的边都不可能是冗余的。

开始想自己用多个顶点的集合来维护各个连通的点集,但复杂度大约在 O(n²)

实际上这个就是「并查集」结构的典型应用场景(Disjoint-set data structure),支持查询(find)和添加(union)操作。神奇的是平均情况下,find 和 union 方法的时间复杂度都接近 O(1)(最坏情况 O(log n))。

整体平均时间复杂度 O(n),空间复杂度 O(1)

Code

solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DisjointSet:
def __init__(self, n: int) -> None:
self.parent = list(range(n)) # parent[u]: u's parent node.
self.depth = [0] * n # depth[u]: the max depth starting from u.

def find(self, u: int) -> int:
while self.parent[u] != u: u = self.parent[u]
return u

def union(self, u: int, v: int) -> bool:
ur = self.find(u)
vr = self.find(v)
if ur == vr: return False

if (diff := self.depth[ur] - self.depth[vr]) == 0:
self.depth[ur] += 1
elif diff < 0:
ur, vr = vr, ur # Make sure that depth[ur] >= depth[vr]

self.parent[vr] = ur
return True


class Solution:
def findRedundantConnection(self, edges: list[list[int]]) -> list[int]:
disjoint = DisjointSet(len(edges))
for u, v in edges:
if not disjoint.union(u - 1, v - 1):
return [u, v]