Problem

Given an integer array nums, return all the triplets [nums[i], nums[j], nums[k]] such that i != j, i != k, and j != k, and nums[i] + nums[j] + nums[k] == 0.

Notice that the solution set must not contain duplicate triplets.

https://leetcode.com/problems/3sum/

Example 1:

Input: nums = [-1,0,1,2,-1,-4]
Output: [[-1,-1,2],[-1,0,1]]
Explanation:
nums[0] + nums[1] + nums[2] = (-1) + 0 + 1 = 0.
nums[1] + nums[2] + nums[4] = 0 + 1 + (-1) = 0.
nums[0] + nums[3] + nums[4] = (-1) + 2 + (-1) = 0.
The distinct triplets are [-1,0,1] and [-1,-1,2].
Notice that the order of the output and the order of the triplets does not matter.

Example 2:

Input: nums = [0,1,1]
Output: []
Explanation: The only possible triplet does not sum up to 0.

Example 3:

Input: nums = [0,0,0]
Output: [[0,0,0]]
Explanation: The only possible triplet sums up to 0.

Constraints:

  • 3 <= nums.length <= 3000
  • -10⁵ <= nums[i] <= 10⁵

Test Cases

1
2
class Solution:
def threeSum(self, nums: List[int]) -> List[List[int]]:
solution_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import pytest

from solution import Solution
from solution2 import Solution as Solution2
from solution3 import Solution as Solution3


@pytest.mark.parametrize('nums, expected', [
([-1,0,1,2,-1,-4], [[-1,-1,2],[-1,0,1]]),
([0,1,1], []),
([0,0,0], [[0,0,0]]),

([0,0,0,0,0,0,0], [[0,0,0]]),
([-1,-1,-1,-1,-1,0,0,0,0,1,1,1,1,1], [[-1,0,1],[0,0,0]]),

([-1,0,1,0], [[-1,0,1]]),
])
class Test:
def test_solution(self, nums, expected):
sol = Solution()
res = sol.threeSum(nums)
self._sort(res)
self._sort(expected)
assert res == expected

def test_solution2(self, nums, expected):
sol = Solution2()
res = sol.threeSum(nums)
self._sort(res)
self._sort(expected)
assert res == expected

def test_solution3(self, nums, expected):
sol = Solution3()
res = sol.threeSum(nums)
self._sort(res)
self._sort(expected)
assert res == expected

def _sort(self, triplets: list[list[int]]):
for triplet in triplets:
triplet.sort()
triplets.sort()

Thoughts

先把数组排序,因为输出的每组三个数值而非数组下标,所以不用记录原始的数组下标。

为了避免产生重复的三元组,额外限制 nums[i] <= nums[j] <= nums[k]

两重遍历所有的 i、j 组合 0i<j<n10\le i<j<n-1,用二分法在 nums[j+1:n] 中查找 0 - i - j。时间复杂度 O(n² log n),空间复杂度 O(n)

简单直接,但是不高效,没有充分利用到 和为零 的特点。

三个数之和为零,有几种可能:

  1. 三个数全是 0(要求 nums 至少有三个 0)。
  2. 一个数是 0,另外两个数一正一负(nums 至少有一个 0)。
  3. 两个负数,一个正数。
  4. 一个负数,两个正数。

由此,整个处理过程是:

  1. nums 分成正数和负数两个子数组,同时记录 0 的个数,O(n) 时间,O(n) 空间。
  2. 对两个子数组分别排序,O(n log n) 时间。
  3. 如果 0 的个数大于 2,记录 [0, 0, 0] 为可行解。
  4. 如果 0 的个数大于 0,用类似归并排序的方式,同时遍历负数和正数子数组,找出所有绝对值相等的正负数对,O(n) 时间。
  5. 用两重循环遍历所有的负数对组合,内循环的时候也用归并法同步遍历正数子数组,找到与两个负数之和绝对值相等的正数,O(n²) 时间。
  6. 同理,找出一负两正的组合,O(n²) 时间。

整体时间复杂度为 O(n²)

Code

solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Generator, List


class Solution:
def threeSum(self, nums: List[int]) -> List[List[int]]:
negatives = []
positives = []
zero = 0
for v in nums:
if v < 0:
negatives.append(-v)
elif v > 0:
positives.append(v)
else:
zero += 1

negatives.sort()
positives.sort()
triplets = []

if zero >= 3:
triplets.append([0, 0, 0])

# Find all [-, 0, +] triplets.
if zero > 0:
for v in self._find_all(negatives, positives):
triplets.append([-v, 0, v])

# Find all [-, -, +] triplets.
for i, v_i in enumerate(negatives):
if i > 0 and v_i == negatives[i - 1]:
continue # Prevent duplication.

for v_j in self._find_all(negatives, positives, i + 1, v_i):
triplets.append([-v_i, -v_j, v_i + v_j])

# Find all [-, +, +] triplets.
for j, v_j in enumerate(positives):
if j > 0 and v_j == positives[j - 1]:
continue # Prevent duplication.

for v_k in self._find_all(positives, negatives, j + 1, v_j):
triplets.append([-v_j - v_k, v_j, v_k])

return triplets

def _find_all(self, nums1: list[int], nums2: list[int], start1: int = 0, gap: int = 0) -> Generator[int, None, None]:
"""Finds all unique v1 from nums1[start1:], where v1 + gap exists in nums2.
Both nums1 and nums2 are sorted.
"""
i2 = 0
count2 = len(nums2)
prev_v1 = None
for i1 in range(start1, len(nums1)):
v1 = nums1[i1]
if v1 == prev_v1:
continue # Prevent duplication.

prev_v1 = v1
while i2 < count2 and nums2[i2] < v1 + gap:
i2 += 1

if i2 == count2:
return

if nums2[i2] == v1 + gap:
yield v1

Simpler

上边处理得点儿复杂。其实也不用特别考虑正负数,这不是加法特点引起的,只是和为 的特例。

先只看两数之和为 0(或任意目标值)的问题。只需要对数组排序,然后两个下标分别指向数组的首尾。

如果和小于目标值,需要把左边的下标往右移;反之把右边的下标往左移。如果等于目标值,那就找到一组解,然后同时向中间移动两个下标。直到两个下标相遇或者在目标值的同侧(跟上边归并处理的逻辑一致)。

而扩展到三个数,只需要先固定第一个数,然后对其右边的子数组,按上述方法找到与第一个数字之和为 0 的每一对数即可。

遍历一遍第一个数,每次遍历需要 O(n) 时间,总的时间复杂度是 O(n²)

代码简洁不少,但速度却慢了不少。可能因为这次是整个数组排序,遍历的时候比较次数也更多一些。

solution2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import List


class Solution:
def threeSum(self, nums: List[int]) -> List[List[int]]:
n = len(nums)
nums.sort()
triplets = []
for i in range(n - 2):
if (v_i := nums[i]) > 0:
break
if i > 0 and v_i == nums[i - 1]:
continue # Prevent duplication.

v_j_max = -v_i >> 1
v_k_min = -(v_i >> 1)
j = i + 1
k = n - 1
while j < k and nums[j] <= v_j_max and nums[k] >= v_k_min:
if (s := v_i + nums[j] + nums[k]) == 0:
triplets.append([v_i, nums[j], nums[k]])
if s <= 0:
j += 1
while j < k and nums[j] == nums[j - 1]:
j += 1 # Prevent duplication.
if s >= 0:
k -= 1
while k > j and nums[k] == nums[k + 1]:
k -= 1 # Prevent duplication.

return triplets

Faster

还是看两数之和为定值的问题,可以不对数组排序,但把所有的数存入哈希表,遍历每一个数,在哈希表中查另一个数是否存在。

整体时间复杂度还是 O(n²),不过省去了整体排序的时间,其他判断的时间也能减少不少。

solution3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import List


class Solution:
def threeSum(self, nums: List[int]) -> List[List[int]]:
vals: dict[int, int] = {} # {value: count}
for v in nums:
vals[v] = vals.get(v, 0) + 1

# Triplet: [vi, vj, vk], where vi <= vj <= vk and vi + vj + vk == 0
triplets: list[list[int]] = []
for v_i, c_i in vals.items():
if v_i == 0 and c_i > 2:
triplets.append([0, 0, 0])
elif v_i < 0:
if c_i > 1 and (v_i_double := -v_i << 1) in vals:
triplets.append([v_i, v_i, v_i_double])

v_i_half = -(v_i >> 1) # ceil(-v_i / 2)
if v_i & 1 == 0 and vals.get(v_i_half, 0) > 1:
triplets.append([v_i, v_i_half, v_i_half])

for v_j in vals:
if v_i < v_j < v_i_half and (v_k := -v_i - v_j) in vals:
triplets.append([v_i, v_j, v_k])

return triplets