Problem

The median is the middle value in an ordered integer list. If the size of the list is even, there is no middle value, and the median is the mean of the two middle values.

  • For example, for arr = [2,3,4], the median is 3.
  • For example, for arr = [2,3], the median is (2 + 3) / 2 = 2.5.

Implement the MedianFinder class:

  • MedianFinder() initializes the MedianFinder object.
  • void addNum(int num) adds the integer num from the data stream to the data structure.
  • double findMedian() returns the median of all elements so far. Answers within 10⁻⁵ of the actual answer will be accepted.

https://leetcode.com/problems/find-median-from-data-stream/

Example 1:

Input
["MedianFinder", "addNum", "addNum", "findMedian", "addNum", "findMedian"]
[[], [1], [2], [], [3], []]
Output
[null, null, null, 1.5, null, 2.0]

Explanation

1
2
3
4
5
6
MedianFinder medianFinder = new MedianFinder();
medianFinder.addNum(1); // arr = [1]
medianFinder.addNum(2); // arr = [1, 2]
medianFinder.findMedian(); // return 1.5 (i.e., (1 + 2) / 2)
medianFinder.addNum(3); // arr[1, 2, 3]
medianFinder.findMedian(); // return 2.0

Constraints:

  • -10⁵ <= num <= 10⁵
  • There will be at least one element in the data structure before calling findMedian.
  • At most 5 * 10⁴ calls will be made to addNum and findMedian.

Follow up:

  • If all integer numbers from the stream are in the range [0, 100], how would you optimize your solution?
  • If 99% of all integer numbers from the stream are in the range [0, 100], how would you optimize your solution?

Test Cases

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class MedianFinder:

def __init__(self):


def addNum(self, num: int) -> None:


def findMedian(self) -> float:



# Your MedianFinder object will be instantiated and called as such:
# obj = MedianFinder()
# obj.addNum(num)
# param_2 = obj.findMedian()
solution_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pytest

from solution import MedianFinder
from solution_follow1 import MedianFinder as MedianFinderFollowUp1


@pytest.mark.parametrize('actions, params, expects', [
(
["MedianFinder", "addNum", "addNum", "findMedian", "addNum", "findMedian"],
[[], [1], [2], [], [3], []],
[None, None, None, 1.5, None, 2.0]
),
])
class Test:
def test_solution(self, actions, params, expects):
self._run(MedianFinder, actions, params, expects)

def test_solution_follow1(self, actions, params, expects):
self._run(MedianFinderFollowUp1, actions, params, expects)

def _run(self, clazz, actions, params, expects):
finder = None
for action, args, expected in zip(actions, params, expects):
if action == 'MedianFinder':
finder = clazz()
elif action == 'findMedian':
assert finder.findMedian(*args) == pytest.approx(expected, abs=1e-5)
else:
assert getattr(finder, action)(*args) == expected


@pytest.mark.parametrize('nums, expects', [
([1, 2, 3], [1, 1.5, 2]),
([2, 3, 4], [2, 2.5, 3]),
([4, 2, 3], [4, 3, 3]),
(
[42, 37, 38, 50, 71, 5, 65, 12, 93, 71, 25, 55, 95, 4, 67, 18, 72, 36, 25, 17],
[42, 39.5, 38, 40.0, 42, 40.0, 42, 40.0, 42, 46.0, 42, 46.0, 50, 46.0, 50, 46.0, 50, 46.0, 42, 40.0]
),
(
[60, 50, 10, 80, 90, 40, 0, 80, 90, 20, 20, 10, 50, 70, 90, 40, 30, 80, 90, 20, 50, 50, 10, 20, 90, 40, 50, 40, 70, 50],
[60, 55.0, 50, 55.0, 60, 55.0, 50, 55.0, 60, 55.0, 50, 45.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0]
),
(
[0, 50, 50, 100, 100, 0, 0, 100, 100, 0, 50, 0, 0, 100, 0, 0, 100, 0, 50, 0, 100, 50, 0, 50, 0, 50, 100, 0, 100, 100],
[0, 25.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 25.0, 50, 25.0, 50, 25.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0, 50, 50.0]
),
])
class TestAll:
def test_solution(self, nums, expects):
finder = MedianFinder()
self._run(finder, nums, expects)

def test_solution_follow1(self, nums, expects):
finder = MedianFinderFollowUp1()
self._run(finder, nums, expects)

def _run(self, finder, nums, expects):
for (num, expected) in zip(nums, expects):
finder.addNum(num)
assert finder.findMedian() == pytest.approx(expected, abs=1e-5)

Thoughts

因为后续数据的分布未知,曾经见过的任何一个数都有可能成为中位数,从信息量的角度看,所有见过的数字是一定要存下来的。

中位数比它左边的所有数字都不小,且比它右边的所有数字都不大。可以用最大堆存左半个数组,最小堆存右半个数组,并保持两个堆的大小一致或相差不超过 1。

如果数据总数是偶数,那么最中间的两个数就分别是左半边的最大也右半边的最小,取两个堆的堆顶求平均即可。如果总数是奇数,那么较大的堆(比另一个堆多一个数)的堆顶就是中位数。

实现的时候,可以实现比如最大堆,并用存储相反数的方式模拟最小堆(反之亦然)。

没必要始终保持左半边的堆的大小不比右半边的小,虽然代码会简洁一些,但平均会多一组出堆再入堆操作。对于新的数字,先根据两个堆的堆顶数字大小判断应该放入哪边,然后如果出现两边的堆大小过于不平衡(大小相差超过 1),再从较大的堆弹出堆顶,放入另一个堆中。

设已经有了 n 个数字,那么下次增加新数字的时间复杂度是 O(log n),查中位数的时间复杂度是 O(1)。整体需要 O(n) 空间复杂度维持两个堆。

Code

solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class MaxHeap:
def __init__(self) -> None:
self._store = []

@property
def size(self):
return len(self._store)

@property
def top(self):
return self._store[0]

def add(self, value: int) -> None:
n = len(self._store)
self._store.append(value)
self._shift_up(n)

def pop(self) -> int:
if len(self._store) == 1:
return self._store.pop()

top = self._store[0]
self._store[0] = self._store.pop()
self._shift_down(0)
return top

def _shift_up(self, pos: int) -> None:
"""Shifts store[pos] up to proper new position."""
while pos > 0:
parent = (pos - 1) >> 1
if self._store[parent] < self._store[pos]:
self._store[parent], self._store[pos] = self._store[pos], self._store[parent]
pos = parent
else:
return

def _shift_down(self, pos: int) -> None:
"""Shifts store[pos] down to proper new position."""
n = len(self._store)
leaf = n >> 1
while pos < leaf:
left = (pos << 1) + 1
right = left + 1
child = right if right < n and self._store[right] > self._store[left] else left
if self._store[pos] < self._store[child]:
self._store[pos], self._store[child] = self._store[child], self._store[pos]
pos = child
else:
return


class MedianFinder:

def __init__(self):
self._left = MaxHeap()
self._right = MaxHeap()

def addNum(self, num: int) -> None:
if self._left.size == 0 or num < self._left.top:
self._left.add(num)
else:
self._right.add(-num)

if self._left.size > self._right.size + 1:
self._right.add(-self._left.pop())
elif self._right.size > self._left.size + 1:
self._left.add(-self._right.pop())

def findMedian(self) -> float:
if self._left.size > self._right.size:
return self._left.top
elif self._right.size > self._left.size:
return -self._right.top

return (self._left.top - self._right.top) / 2


# Your MedianFinder object will be instantiated and called as such:
# obj = MedianFinder()
# obj.addNum(num)
# param_2 = obj.findMedian()

Follow up 1

如果所有的数据都只在 [1, 100] 范围内,就不需要堆了,只要记录每个整数的次数。

记录当前中位数是哪个整数,以及该整数在左右半边分别有多少个。下次 addNum 时,更新这个信息。

addNumfindMedian 的时间和空间复杂度都是 O(1)

solution_follow1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class MedianFinder:

def __init__(self):
self._store = [0] * 101
self._size = 0
self._mid_num = None
self._mid_offset = 0

def _next(self, num: int) -> int:
return next(v for v in range(num + 1, 101) if self._store[v] > 0)

def _prev(self, num: int) -> int:
return next(v for v in range(num - 1, -1, -1) if self._store[v] > 0)

def addNum(self, num: int) -> None:
self._store[num] += 1

if self._mid_num is None:
self._mid_num = num
elif num >= self._mid_num and self._size & 1 == 0:
self._mid_offset += 1
if self._mid_offset == self._store[self._mid_num]:
self._mid_num = self._next(self._mid_num)
self._mid_offset = 0
elif num < self._mid_num and self._size & 1 == 1:
self._mid_offset -= 1
if self._mid_offset < 0:
self._mid_num = self._prev(self._mid_num)
self._mid_offset = self._store[self._mid_num] - 1

self._size += 1

def findMedian(self) -> float:
if self._size & 1 == 1 or self._mid_offset < self._store[self._mid_num] - 1:
return self._mid_num
else:
return (self._mid_num + self._next(self._mid_num)) / 2


# Your MedianFinder object will be instantiated and called as such:
# obj = MedianFinder()
# obj.addNum(num)
# param_2 = obj.findMedian()

Follow up 2

如果所有的数据有 99% 在 [1, 100] 范围内,那么中位数也有极大的概率出现在 [1, 100] 内,可以在上边的基础上,多记录 < 0> 100 的数字个数。其他的逻辑不受影响。

极端情况,比如给的第一个数字就在 [1, 100] 范围之外,然后立刻要获取中位数,想要依然能正确返回结果,可以记录 < 0 的最大的几个数和 > 100 的最小的几个数,根据输入数据的分布特点,可能最多各存一到两个就行。