Problem

Given two strings s and part, perform the following operation on s until all occurrences of the substring part are removed:

  • Find the leftmost occurrence of the substring part and remove it from s.

Return s after removing all occurrences of part.

A substring is a contiguous sequence of characters in a string.

https://leetcode.com/problems/remove-all-occurrences-of-a-substring/
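As a baseline, the operation can be simulated literally. This is a minimal sketch, not the stack-based solution from the Code section, and `remove_occurrences_naive` is a hypothetical name. It relies on `str.replace` with a count of 1, which replaces only the leftmost match. Every round rescans and copies the whole string, so it does more work than necessary.

```python
def remove_occurrences_naive(s: str, part: str) -> str:
    """Literally apply the operation: delete the leftmost occurrence until none remain."""
    while part in s:
        # count=1 makes str.replace remove only the first (leftmost) match.
        s = s.replace(part, "", 1)
    return s


print(remove_occurrences_naive("daabcbaabcbc", "abc"))  # dab
```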

Example 1:

Input: s = "daabcbaabcbc", part = "abc"
Output: "dab"
Explanation: The following operations are done:

  • s = "daabcbaabcbc", remove "abc" starting at index 2, so s = "dabaabcbc".
  • s = "dabaabcbc", remove "abc" starting at index 4, so s = "dababc".
  • s = "dababc", remove "abc" starting at index 3, so s = "dab".

Now s has no occurrences of "abc".

Example 2:

Input: s = "axxxxyyyyb", part = "xy"
Output: "ab"
Explanation: The following operations are done:

  • s = "axxxxyyyyb", remove "xy" starting at index 4, so s = "axxxyyyb".
  • s = "axxxyyyb", remove "xy" starting at index 3, so s = "axxyyb".
  • s = "axxyyb", remove "xy" starting at index 2, so s = "axyb".
  • s = "axyb", remove "xy" starting at index 1, so s = "ab".

Now s has no occurrences of "xy".
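To check the step-by-step explanations, the following sketch (hypothetical helper `trace_removals`) repeats the leftmost removal using `str.find` and records each start index together with the resulting string. Running it on Example 1 reproduces the three steps listed there; on Example 2 it gives the four steps above.

```python
def trace_removals(s: str, part: str) -> list[tuple[int, str]]:
    """Return (start index, string after removal) for each leftmost removal."""
    steps = []
    i = s.find(part)  # index of the leftmost occurrence, or -1 if none
    while i != -1:
        s = s[:i] + s[i + len(part):]
        steps.append((i, s))
        i = s.find(part)
    return steps


for start, rest in trace_removals("daabcbaabcbc", "abc"):
    print(start, rest)
# 2 dabaabcbc
# 4 dababc
# 3 dab
```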

Constraints:

  • 1 <= s.length <= 1000
  • 1 <= part.length <= 1000
  • s and part consist of lowercase English letters.

Test Cases

class Solution:
    def removeOccurrences(self, s: str, part: str) -> str:
        ...
solution_test.py
import pytest

from solution import Solution


@pytest.mark.parametrize('s, part, expected', [
    ("daabcbaabcbc", "abc", "dab"),
    ("axxxxyyyyb", "xy", "ab"),
])
@pytest.mark.parametrize('sol', [Solution()])
def test_solution(sol, s, part, expected):
    assert sol.removeOccurrences(s, part) == expected

Thoughts

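This can be seen as a harder version of 3174. Clear Digits: the pattern to match grows from two characters (a letter followed by a digit) to a string of arbitrary length.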
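For comparison, here is a sketch of the stack idea in 3174. Clear Digits, assuming its statement as paraphrased here (not taken from the original text): repeatedly delete the first digit together with the closest non-digit character to its left, with inputs guaranteed to make this always possible. `clear_digits` is a hypothetical name. Each digit acts as the second character of a two-character pattern, so instead of being pushed it pops one letter.

```python
def clear_digits(s: str) -> str:
    """Stack sketch for 3174. Clear Digits: each digit removes the closest letter to its left."""
    stack: list[str] = []
    for c in s:
        if c.isdigit():
            # The digit and its nearest non-digit on the left cancel out.
            # The problem is assumed to guarantee such a letter exists, so pop() is safe.
            stack.pop()
        else:
            stack.append(c)
    return "".join(stack)


print(clear_digits("cb34"))  # prints an empty line: "b" and "c" are both removed
```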

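The approach is similar as well: use a stack to hold the characters of the result string. Iterate over s; for each character, first push it onto the stack, then compare the top m characters of the stack with part (where m is the length of part), and if they match, pop all of them. This respects the leftmost-first rule: the stack never holds a full copy of part, and since every occurrence has length m, any occurrence starting further left would have to end inside the stack, so the match just found at the top is the leftmost one.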
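The following sketch (hypothetical helper `stack_snapshots`, using a slice comparison instead of the index loop in the Code section) records the stack after each character of Example 1 is processed. The three pops happen right after the 5th, 10th and 12th characters, and the stack never contains "abc".

```python
def stack_snapshots(s: str, part: str) -> list[str]:
    """Return the stack contents (as a string) after each character of s is processed."""
    m = len(part)
    stack: list[str] = []
    snapshots = []
    for c in s:
        stack.append(c)
        # If the top m characters spell out part, drop them all at once.
        if len(stack) >= m and "".join(stack[-m:]) == part:
            del stack[-m:]
        snapshots.append("".join(stack))
    return snapshots


print(stack_snapshots("daabcbaabcbc", "abc"))
# ['d', 'da', 'daa', 'daab', 'da', 'dab', 'daba', 'dabaa', 'dabaab', 'daba', 'dabab', 'dab']
```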

Time complexity is O(n * m) and space complexity is O(n), where n is the length of s and m is the length of part.
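The O(n * m) bound can actually be reached by the backward comparison used in the Code section. The sketch below (hypothetical `count_comparisons`, the same loop as the solution plus a counter) uses s = "a" * n and part = "b" followed by m - 1 "a"s: after the first m - 1 characters, every push matches m - 1 characters before failing on part[0], for (n - m + 1) * m comparisons in total.

```python
def count_comparisons(s: str, part: str) -> int:
    """Count character comparisons made by the backward top-of-stack check."""
    m = len(part)
    stack: list[str] = []
    comparisons = 0
    for c in s:
        stack.append(c)
        if len(stack) < m:
            continue
        # Same backward scan as the solution, with a counter added.
        for i in range(-1, -m - 1, -1):
            comparisons += 1
            if stack[i] != part[i]:
                break
        else:
            del stack[-m:]
    return comparisons


n, m = 1000, 10
print(count_comparisons("a" * n, "b" + "a" * (m - 1)))  # 9910, i.e. (1000 - 10 + 1) * 10
```

Moving the "b" to the end (part = "a" * 9 + "b") makes every check fail on the first comparison, so the count drops to 991.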

Code

solution.py
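class Solution:
    def removeOccurrences(self, s: str, part: str) -> str:
        stack: list[str] = []  # characters of the result built so far
        m = len(part)
        for c in s:
            stack.append(c)
            if len(stack) < m:
                continue  # too few characters on the stack to contain part

            # Compare the top m characters with part, from the last character backwards.
            for i in range(-1, -m - 1, -1):
                if stack[i] != part[i]:
                    break
            else:
                # The loop finished without a mismatch: the top m characters equal part.
                for _ in range(m):
                    stack.pop()

        return ''.join(stack)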
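The stack solution and a literal simulation of the problem statement should always produce the same string. As a quick, non-exhaustive sanity check, this sketch compares them on random strings over a two-letter alphabet, where overlapping and nested occurrences are common. `remove_by_stack`, `remove_by_simulation` and `cross_check` are hypothetical names, and the stack version uses a slice comparison rather than the index loop in solution.py.

```python
import random


def remove_by_stack(s: str, part: str) -> str:
    """Stack approach as a plain function (slice comparison instead of an index loop)."""
    m = len(part)
    stack: list[str] = []
    for c in s:
        stack.append(c)
        if len(stack) >= m and "".join(stack[-m:]) == part:
            del stack[-m:]
    return "".join(stack)


def remove_by_simulation(s: str, part: str) -> str:
    """Literal reading of the problem: delete the leftmost occurrence until none remain."""
    while part in s:
        s = s.replace(part, "", 1)
    return s


def cross_check(trials: int = 2000, seed: int = 0) -> bool:
    """Return True if both functions agree on every randomly generated case."""
    rng = random.Random(seed)
    for _ in range(trials):
        s = "".join(rng.choice("ab") for _ in range(rng.randint(1, 30)))
        part = "".join(rng.choice("ab") for _ in range(rng.randint(1, 4)))
        if remove_by_stack(s, part) != remove_by_simulation(s, part):
            return False
    return True


print(cross_check())  # True
```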