Problem

You are given an array of non-overlapping intervals intervals where intervals[i] = [startᵢ, endᵢ] represents the start and the end of the iᵗʰ interval and intervals is sorted in ascending order by startᵢ. You are also given an interval newInterval = [start, end] that represents the start and end of another interval.

Insert newInterval into intervals such that intervals is still sorted in ascending order by startᵢ and intervals still does not have any overlapping intervals (merge overlapping intervals if necessary).

Return intervals after the insertion.

Note that you don’t need to modify intervals in-place. You can make a new array and return it.

https://leetcode.com/problems/insert-interval/

Example 1:

Input: intervals = [[1,3],[6,9]], newInterval = [2,5]
Output: [[1,5],[6,9]]

Example 2:

Input: intervals = [[1,2],[3,5],[6,7],[8,10],[12,16]], newInterval = [4,8]
Output: [[1,2],[3,10],[12,16]]
Explanation: Because the new interval [4,8] overlaps with [3,5],[6,7],[8,10].

Constraints:

  • 0 <= intervals.length <= 10⁴
  • intervals[i].length == 2
  • 0 <= startᵢ <= endᵢ <= 10⁵
  • intervals is sorted by startᵢ in ascending order.
  • newInterval.length == 2
  • 0 <= start <= end <= 10⁵

Test Cases

from typing import List


class Solution:
    def insert(self, intervals: List[List[int]], newInterval: List[int]) -> List[List[int]]:
        ...
solution_test.py
import pytest

from solution import Solution


@pytest.mark.parametrize('intervals, newInterval, expected', [
    ([[1,3],[6,9]], [2,5], [[1,5],[6,9]]),
    ([[1,2],[3,5],[6,7],[8,10],[12,16]], [4,8], [[1,2],[3,10],[12,16]]),

    ([[5,7],[8,10]], [2,3], [[2,3],[5,7],[8,10]]),
    ([[5,7],[8,10]], [2,5], [[2,7],[8,10]]),
    ([[5,7],[8,10]], [2,12], [[2,12]]),
    ([[5,6],[9,10]], [7,8], [[5,6],[7,8],[9,10]]),

    ([[4,4],[5,7],[8,10]], [10,18], [[4,4],[5,7],[8,18]])
])
class Test:
    def test_solution(self, intervals, newInterval, expected):
        sol = Solution()
        assert sol.insert(intervals, newInterval) == expected

Thoughts

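We can use binary search on the original interval array to find the last interval whose right endpoint is strictly less than start (equality does not count), and call its index after. If no interval's right endpoint is less than start, after is -1.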

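Then compare the intervals after index after with newInterval one by one, merging each one that overlaps into newInterval. Because the intervals are sorted and disjoint, we can stop at the first one that does not overlap. Finally, replace the merged intervals with newInterval.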

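Although the search takes O(log n) time, checking the following intervals for overlap and splicing newInterval into the array (which can shift all of the later elements) still take O(n), so the overall time complexity is O(n).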

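In fact, the problem does not require in-place insertion, so we can simply build a new array. Binary search is then unnecessary: just compare the intervals one by one from the beginning.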
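
A minimal sketch of that single pass, which builds a new list. The name insert_linear is hypothetical, and this is the alternative described here rather than the code used in the solution. It works in three phases: copy the intervals that end before newInterval starts, merge the ones that overlap, then copy the rest.

```python
def insert_linear(intervals: list[list[int]], new_interval: list[int]) -> list[list[int]]:
    """Hypothetical helper: one pass from the start, building a new list."""
    start, end = new_interval
    result: list[list[int]] = []
    i, n = 0, len(intervals)
    # 1. Intervals ending strictly before the new one starts are copied unchanged.
    while i < n and intervals[i][1] < start:
        result.append(intervals[i])
        i += 1
    # 2. Intervals starting at or before the (growing) merged end overlap it: merge them.
    while i < n and intervals[i][0] <= end:
        start = min(start, intervals[i][0])
        end = max(end, intervals[i][1])
        i += 1
    result.append([start, end])
    # 3. Everything after the merged interval is copied unchanged.
    result.extend(intervals[i:])
    return result
```

It runs in O(n) time and uses O(n) extra space for the output list, and it leaves the input list unmodified.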

Code

solution.py
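class Solution:
    def insert(self, intervals: list[list[int]], newInterval: list[int]) -> list[list[int]]:
        # Index of the last interval that ends strictly before newInterval starts (-1 if none).
        after = self._bin_find_max_lt(intervals, newInterval[0])
        i = after + 1
        if i < len(intervals):
            # If intervals[i] starts before newInterval, it must overlap it
            # (its end is >= newInterval's start), so take the smaller start.
            newInterval[0] = min(newInterval[0], intervals[i][0])
        # Every interval that starts at or before newInterval's end overlaps it. The end
        # need not be updated inside the loop: the intervals are sorted and disjoint,
        # so each one starts after the previous one ends.
        while i < len(intervals) and intervals[i][0] <= newInterval[1]:
            i += 1
        if i > after + 1:
            # The last overlapped interval may extend past newInterval's end.
            newInterval[1] = max(newInterval[1], intervals[i-1][1])

        # Replace the overlapped slice (possibly empty) with the merged interval.
        intervals[after+1:i] = [newInterval]
        return intervals

    def _bin_find_max_lt(self, intervals: list[list[int]], val: int) -> int:
        """Finds the maximal index i where `intervals[i][1] < val`, using binary search.

        The right endpoints `intervals[i][1]` are strictly increasing, because the
        intervals are sorted and non-overlapping.
        Returns `-1` if no interval ends before `val` (including when `intervals` is empty).
        """
        l = 0
        r = len(intervals) - 1
        while l <= r:
            m = (l + r) >> 1
            if intervals[m][1] >= val:
                r = m - 1
            elif val > intervals[r][1]:
                # Early exit: intervals[m..r] all end before val, and r is either the
                # last index or intervals[r+1][1] >= val, so r is the answer.
                return r
            else:
                l = m + 1

        return l - 1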

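I still used binary search here anyway, as practice for using binary search to find the largest element smaller than a target value.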
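
For that practice goal, here is a generic sketch of the same search on a plain sorted list, written in the half-open [lo, hi) style without an early-exit branch. The name find_max_lt is hypothetical. Its result should always equal bisect.bisect_left(nums, target) - 1 from the standard library.

```python
def find_max_lt(nums: list[int], target: int) -> int:
    """Return the largest index i with nums[i] < target, or -1 if there is none.

    Assumes nums is sorted in non-decreasing order; duplicates are allowed.
    """
    lo, hi = 0, len(nums)  # search the half-open range [lo, hi)
    while lo < hi:
        mid = (lo + hi) // 2
        if nums[mid] < target:
            lo = mid + 1  # mid qualifies, so the answer is mid or to its right
        else:
            hi = mid  # nums[mid] >= target, so the answer is left of mid
    # lo is now the first index with nums[lo] >= target (or len(nums)).
    return lo - 1
```

The loop keeps the invariant that everything left of lo is < target and everything from hi on is >= target, which is why lo - 1 is the answer when the range closes.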