Problem

You are given an array of intervals, where intervals[i] = [startᵢ, endᵢ] and each startᵢ is unique.

The right interval for an interval i is an interval j such that startⱼ >= endᵢ and startⱼ is minimized. Note that i may equal j.

Return an array of right interval indices for each interval i. If no right interval exists for interval i, then put -1 at index i.
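The definition above can be checked directly with a brute-force scan. For every interval, it looks at every start, including the interval's own, because i may equal j. This is a minimal O(n²) sketch meant only as a reference for testing. The function name `find_right_interval_brute` is hypothetical and is not part of the problem's required interface.

```python
def find_right_interval_brute(intervals: list[list[int]]) -> list[int]:
    """O(n^2) reference (hypothetical helper): scan every start for each interval."""
    ans = []
    for _, end_i in intervals:
        best = -1  # index of the best candidate found so far
        for j, (start_j, _) in enumerate(intervals):
            # A candidate must start at or after end_i; keep the smallest such start.
            if start_j >= end_i and (best == -1 or start_j < intervals[best][0]):
                best = j
        ans.append(best)
    return ans


print(find_right_interval_brute([[3, 4], [2, 3], [1, 2]]))  # [-1, 0, 1]
```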

https://leetcode.com/problems/find-right-interval/

Example 1:

Input: intervals = [[1,2]]
Output: [-1]
Explanation: There is only one interval in the collection, so it outputs -1.

Example 2:

Input: intervals = [[3,4],[2,3],[1,2]]
Output: [-1,0,1]
Explanation: There is no right interval for [3,4].
The right interval for [2,3] is [3,4] since start₀ = 3 is the smallest start that is >= end₁ = 3.
The right interval for [1,2] is [2,3] since start₁ = 2 is the smallest start that is >= end₂ = 2.

Example 3:

Input: intervals = [[1,4],[2,3],[3,4]]
Output: [-1,2,-1]
Explanation: There is no right interval for [1,4] and [3,4].
The right interval for [2,3] is [3,4] since start₂ = 3 is the smallest start that is >= end₁ = 3.

Constraints:

  • 1 <= intervals.length <= 2 * 10⁴
  • intervals[i].length == 2
  • -10⁶ <= startᵢ <= endᵢ <= 10⁶
  • The start point of each interval is unique.

Test Cases

1
2
from typing import List


class Solution:
    def findRightInterval(self, intervals: List[List[int]]) -> List[int]:
        ...
solution_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
import pytest

from solution import Solution


@pytest.mark.parametrize('intervals, expected', [
    ([[1,2]], [-1]),
    ([[3,4],[2,3],[1,2]], [-1,0,1]),
    ([[1,4],[2,3],[3,4]], [-1,2,-1]),
])
@pytest.mark.parametrize('sol', [Solution()])
def test_solution(sol, intervals, expected):
    assert sol.findRightInterval(intervals) == expected

Thoughts

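Sort all intervals by start point. Then, for each interval, binary-search the sorted starts for the insertion position of that interval's end point. bisect_left returns the first position whose start is >= the end, and the interval at that position is the right interval. Every interval satisfies startᵢ <= endᵢ, so the right interval can never lie to the left of the current interval in sorted order. The binary search therefore only needs to cover the sorted array from the current interval onward (the lo=idx argument in the code).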
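The `key=` parameter of `bisect_left` only exists in Python 3.10 and later. On older versions, the same idea can be sketched by keeping a parallel list of sorted starts and bisecting that list directly. The `lo=pos` argument plays the same role as `lo=idx` above. This is a minimal sketch, and the name `find_right_interval_starts` is hypothetical.

```python
from bisect import bisect_left
from typing import List


def find_right_interval_starts(intervals: List[List[int]]) -> List[int]:
    """Hypothetical variant that avoids bisect's key= (works before Python 3.10)."""
    n = len(intervals)
    # Original indices, ordered by start point
    order = sorted(range(n), key=lambda i: intervals[i][0])
    # starts[k] is the start of the k-th interval in sorted order
    starts = [intervals[i][0] for i in order]
    ans = [-1] * n
    for pos, i in enumerate(order):
        # start_i <= end_i, so no position before pos can hold a start >= end_i
        k = bisect_left(starts, intervals[i][1], lo=pos)
        if k < n:
            ans[i] = order[k]
    return ans


print(find_right_interval_starts([[1, 4], [2, 3], [3, 4]]))  # [-1, 2, -1]
```

The extra `starts` list costs O(n) memory, which leaves the overall O(n) space bound unchanged.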

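Time complexity is O(n log n): sorting costs O(n log n), and each of the n binary searches costs O(log n). Space complexity is O(n) for the sorted index array and the answer.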
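The same O(n log n) bound can also be reached with a different technique: a two-pointer sweep instead of binary search. This sketch sorts the indices once by start and once by end. Ends are then visited in increasing order, and a single pointer over the sorted starts only ever moves forward, so after sorting the sweep is O(n). This is an alternative offered for comparison, not the approach used in the solution below, and the function name is hypothetical.

```python
from typing import List


def find_right_interval_two_pointer(intervals: List[List[int]]) -> List[int]:
    """Hypothetical alternative: sort by start and by end, then sweep once."""
    n = len(intervals)
    by_start = sorted(range(n), key=lambda i: intervals[i][0])
    by_end = sorted(range(n), key=lambda i: intervals[i][1])
    ans = [-1] * n
    p = 0  # pointer into by_start; never moves backwards
    for i in by_end:
        end_i = intervals[i][1]
        # Starts below end_i are also below every later (larger or equal) end.
        while p < n and intervals[by_start[p]][0] < end_i:
            p += 1
        if p == n:
            break  # no start >= this end, so none for the remaining ends either
        ans[i] = by_start[p]
    return ans


print(find_right_interval_two_pointer([[3, 4], [2, 3], [1, 2]]))  # [-1, 0, 1]
```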

Code

solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
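from bisect import bisect_left


class Solution:
    def findRightInterval(self, intervals: list[list[int]]) -> list[int]:
        n = len(intervals)
        # Original indices, ordered by each interval's start point
        indices = sorted(range(n), key=lambda i: intervals[i][0])
        ans = [-1] * n
        for idx, i in enumerate(indices):
            # First position (searching from idx onward) whose start is >= end_i.
            # bisect_left's key= parameter requires Python 3.10+; the key is
            # applied to the elements of indices, not to the search value.
            j = bisect_left(indices, intervals[i][1], lo=idx, key=lambda k: intervals[k][0])
            if j < n:
                ans[i] = indices[j]

        return ans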