Problem

You are given an array of non-overlapping intervals intervals where intervals[i] = [start_i, end_i] represents the start and the end of the i^th interval and intervals is sorted in ascending order by start_i. You are also given an interval newInterval = [start, end] that represents the start and end of another interval.

Insert newInterval into intervals such that intervals is still sorted in ascending order by start_i and intervals still does not have any overlapping intervals (merge overlapping intervals if necessary).

Return intervals after the insertion.

Note that you don’t need to modify intervals in-place. You can make a new array and return it.

https://leetcode.com/problems/insert-interval/

Example 1:

Input: intervals = [[1,3],[6,9]], newInterval = [2,5]
Output: [[1,5],[6,9]]

Example 2:

Input: intervals = [[1,2],[3,5],[6,7],[8,10],[12,16]], newInterval = [4,8]
Output: [[1,2],[3,10],[12,16]]
Explanation: Because the new interval [4,8] overlaps with [3,5],[6,7],[8,10].

Constraints:

  • 0 <= intervals.length <= 10^4
  • intervals[i].length == 2
  • 0 <= start_i <= end_i <= 10^5
  • intervals is sorted by start_i in ascending order.
  • newInterval.length == 2
  • 0 <= start <= end <= 10^5

Test Cases

```python
from typing import List
class Solution:
    def insert(self, intervals: List[List[int]], newInterval: List[int]) -> List[List[int]]:
        ...
```
solution_test.py
```python
import pytest

from solution import Solution


@pytest.mark.parametrize('intervals, newInterval, expected', [
    ([[1,3],[6,9]], [2,5], [[1,5],[6,9]]),
    ([[1,2],[3,5],[6,7],[8,10],[12,16]], [4,8], [[1,2],[3,10],[12,16]]),

    ([[5,7],[8,10]], [2,3], [[2,3],[5,7],[8,10]]),
    ([[5,7],[8,10]], [2,5], [[2,7],[8,10]]),
    ([[5,7],[8,10]], [2,12], [[2,12]]),
    ([[5,6],[9,10]], [7,8], [[5,6],[7,8],[9,10]]),

    ([[4,4],[5,7],[8,10]], [10,18], [[4,4],[5,7],[8,18]])
])
class Test:
    def test_solution(self, intervals, newInterval, expected):
        sol = Solution()
        assert sol.insert(intervals, newInterval) == expected
```

Thoughts

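We can binary-search the original interval array for the last interval whose right endpoint is strictly less than (not equal to) start, taking -1 if start is less than or equal to every right endpoint. Call this index after.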

Then compare each interval after index after with newInterval in turn, merging every interval that overlaps newInterval into it. Finally, replace those merged intervals with newInterval.
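To make the final replacement step concrete, here it is done by hand on Example 2 with Python slice assignment. The values below are worked out manually for this input, not computed by the algorithm: [1,2] is the last interval ending before 4, so after = 0; [3,5], [6,7] and [8,10] overlap [4,8] and are absorbed, so the scan stops at index 4; the merged interval is [min(4, 3), max(8, 10)] = [3, 10].

```python
intervals = [[1, 2], [3, 5], [6, 7], [8, 10], [12, 16]]
after, i = 0, 4      # hand-computed for newInterval = [4, 8]
merged = [3, 10]     # [4, 8] merged with [3, 5], [6, 7], [8, 10]

# Replace the absorbed slice intervals[after+1:i] with the single merged interval.
intervals[after + 1:i] = [merged]
print(intervals)  # [[1, 2], [3, 10], [12, 16]]
```

The slice assignment shrinks the list in place, which is where the element shifting mentioned in the complexity note comes from.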

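Although the search itself takes O(log n) time, checking the following intervals for overlap and splicing newInterval into the list (which may shift all later elements forward) still make the overall time complexity O(n).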

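In fact, the problem does not require in-place insertion, so we can simply build a new array. Binary search is then unnecessary: just compare the intervals one by one from the beginning.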
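A minimal sketch of that linear-scan variant, which builds a new list in three phases: copy the intervals that end before newInterval starts, merge the ones that overlap it, then copy the rest. The function name `insert_linear` is hypothetical; the post itself keeps the binary-search version below.

```python
def insert_linear(intervals: list[list[int]], newInterval: list[int]) -> list[list[int]]:
    """Insert newInterval into sorted, non-overlapping intervals; returns a new list."""
    result = []
    start, end = newInterval
    i, n = 0, len(intervals)

    # 1. Intervals that end strictly before newInterval starts are copied unchanged.
    while i < n and intervals[i][1] < start:
        result.append(intervals[i])
        i += 1

    # 2. Intervals that start no later than the (growing) end overlap it: merge them.
    while i < n and intervals[i][0] <= end:
        start = min(start, intervals[i][0])
        end = max(end, intervals[i][1])
        i += 1
    result.append([start, end])

    # 3. Everything after the merged interval is copied unchanged.
    result.extend(intervals[i:])
    return result


print(insert_linear([[1, 2], [3, 5], [6, 7], [8, 10], [12, 16]], [4, 8]))
# [[1, 2], [3, 10], [12, 16]]
```

Each interval is visited once, so this is O(n) time, the same bound as the binary-search version, with O(n) extra space for the new list.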

Code

solution.py
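```python
class Solution:
    def insert(self, intervals: list[list[int]], newInterval: list[int]) -> list[list[int]]:
        # Index of the last interval that ends strictly before newInterval starts (-1 if none).
        after = self._bin_find_max_lt(intervals, newInterval[0])
        i = after + 1
        # intervals[after+1] is the first one that may overlap; it can extend the start leftwards.
        if i < len(intervals):
            newInterval[0] = min(newInterval[0], intervals[i][0])
        # Skip every interval that starts no later than newInterval ends: they all overlap it.
        while i < len(intervals) and intervals[i][0] <= newInterval[1]:
            i += 1
        # The last absorbed interval may extend the end rightwards.
        if i > after + 1:
            newInterval[1] = max(newInterval[1], intervals[i-1][1])

        # Replace the absorbed intervals intervals[after+1:i] (possibly none) with the merged one.
        intervals[after+1:i] = [newInterval]
        return intervals

    def _bin_find_max_lt(self, intervals: list[list[int]], val: int) -> int:
        """Finds the maximal index i where `intervals[i][1] < val`, using binary search.
        The right endpoints `intervals[i][1]` are strictly increasing, because the
        intervals are sorted and non-overlapping.
        Returns `-1` if val is less than or equal to every `intervals[i][1]`.
        """
        l = 0
        r = len(intervals) - 1
        while l <= r:
            m = (l + r) >> 1
            if intervals[m][1] >= val:
                r = m - 1
            elif val > intervals[r][1]:
                # Everything up to r ends before val, and everything after r does not.
                return r
            else:
                l = m + 1

        return l - 1
```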

I still used binary search here anyway, as practice in using bisection to find the largest element smaller than a target value.
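For comparison, the same "largest element smaller than the target" search is often written over a half-open range [lo, hi) with no early return: the loop converges on the first index whose value is >= target, and the answer is the index just before it. A minimal sketch on a plain sorted list of integers, with the hypothetical name `max_lt_index`, cross-checked against the standard library's `bisect.bisect_left`:

```python
import bisect


def max_lt_index(a: list[int], target: int) -> int:
    """Index of the largest element of sorted list `a` that is < target, or -1 if none."""
    lo, hi = 0, len(a)  # search space is the half-open range [lo, hi)
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < target:
            lo = mid + 1  # the first index with value >= target lies to the right of mid
        else:
            hi = mid      # a[mid] >= target, so the first such index is mid or to its left
    # lo is now the first index with a[lo] >= target (or len(a)), so step back by one.
    return lo - 1


a = [2, 5, 7, 10, 16]
for t in (1, 2, 4, 10, 17):
    assert max_lt_index(a, t) == bisect.bisect_left(a, t) - 1
print(max_lt_index(a, 4))  # 0
```

This form needs only one comparison per step and has a single exit point, at the cost of always running the full O(log n) iterations instead of returning early.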