Problem

Given an integer n, return an array ans of length n + 1 such that for each i (0 <= i <= n), ans[i] is the number of 1’s in the binary representation of i.

https://leetcode.com/problems/counting-bits/

Example 1:

Input: n = 2
Output: [0,1,1]
Explanation:

0 --> 0
1 --> 1
2 --> 10

Example 2:

Input: n = 5
Output: [0,1,1,2,1,2]
Explanation:

0 --> 0
1 --> 1
2 --> 10
3 --> 11
4 --> 100
5 --> 101

Constraints:

  • 0 <= n <= 10⁵

Follow up:

  • It is very easy to come up with a solution with a runtime of O(n log n). Can you do it in linear time O(n) and possibly in a single pass?
  • Can you do it without using any built-in function (i.e., like __builtin_popcount in C++)?

Test Cases

class Solution:
    def countBits(self, n: int) -> List[int]:
solution_test.py
import pytest

from solution import Solution
from solution2 import Solution as Solution2
from solution3 import Solution as Solution3


@pytest.mark.parametrize('n, expected', [
    (2, [0,1,1]),
    (5, [0,1,1,2,1,2]),
    (8, [0,1,1,2,1,2,2,3,1]),
])
class Test:
    def test_solution(self, n, expected):
        sol = Solution()
        assert sol.countBits(n) == expected

    def test_solution2(self, n, expected):
        sol = Solution2()
        assert sol.countBits(n) == expected

    def test_solution3(self, n, expected):
        sol = Solution3()
        assert sol.countBits(n) == expected

Thoughts

To count the 1 bits of a single integer, repeatedly shift it right until it reaches 0, adding 1 to a counter whenever the lowest bit is set; this takes O(log n) time.

Doing that independently for every integer from 0 to n gives a total time of O(n log n).
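
To make the per-number loop concrete, here is a standalone sketch of that bit-counting step (the helper name popcount is just for illustration; the actual solution below does the same thing inline):

def popcount(x: int) -> int:
    # Add the lowest bit, then shift it out, until nothing is left.
    count = 0
    while x > 0:
        count += x & 1
        x >>= 1
    return count

assert popcount(0b0) == 0
assert popcount(0b101) == 2
assert popcount(0b1000111) == 4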

Code

solution.py
class Solution:
    def countBits(self, n: int) -> list[int]:
        counts = [0] * (n + 1)
        for i in range(1, n + 1):
            cnt = 0
            j = i
            while j > 0:
                cnt += j & 1
                j >>= 1
            counts[i] = cnt

        return counts

Follow Up - O(n)

When counting for consecutive integers, we can reuse results that are already known.

One idea is to look at how many 1 bits are removed and how many are added when going from n - 1 to n.

For example, if n = 0b1001000, then n - 1 = 0b1000111. The bits the two numbers have in common can be computed with a bitwise AND: common = n & (n-1) = 0b1000000.

XOR-ing common with each of the two numbers then yields exactly the bits that changed.

xor(common, n-1) = 0b111: these 1 bits are removed. xor(common, n) = 0b1000: this 1 bit is newly added.

Let f(n) denote the number of 1 bits in the binary representation of n. Then:

f(n) = f(n-1) - f(xor(common, n-1)) + f(xor(common, n))

In general both xor(common, n-1) and xor(common, n) are smaller than n, so their counts are already in the table. The only exception is when n is a power of two, where xor(common, n) equals n; that case can be handled specially with f(2ⁱ) = 1.

Each number then takes only constant time for the computation and table lookups, so the total time is O(n).
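
A quick sanity check of this identity on the example above, n = 0b1001000 = 72 (a throwaway verification snippet, not part of the submitted solution; it uses bin(x).count('1') only to stand in for f):

n = 0b1001000                # 72
common = n & (n - 1)         # 0b1000000: bits shared with n - 1
removed = (n - 1) ^ common   # 0b111: 1 bits dropped going from n - 1 to n
added = n ^ common           # 0b1000: 1 bits gained

f = lambda x: bin(x).count('1')
assert f(n) == f(n - 1) - f(removed) + f(added)   # 2 == 4 - 3 + 1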

solution2.py
class Solution:
    def countBits(self, n: int) -> list[int]:
        counts = [0] * (n + 1)
        prev = 0
        for i in range(1, n + 1):
            common = i & prev
            add = i ^ common
            if add == i:
                counts[i] = 1  # i is a power of two.
            else:
                remove = prev ^ common
                counts[i] = counts[prev] - counts[remove] + counts[add]
            prev = i

        return counts

Faster O(n)

The approach above is still rather convoluted (and slower than it needs to be), and the "n is a power of two" special case is easy to overlook and turn into a bug.

In fact, for a d-bit binary number n, the highest bit is 1 and the remaining d - 1 bits form m = n - 2ᵈ⁻¹, with 0 <= m < n. If the number of 1 bits in m is already known, then clearly f(n) = 1 + f(m).

The only remaining question is how to determine d, or rather P = 2ᵈ⁻¹, in constant time. While iterating over the integers, P is easy to track: it starts at P = 1 for n = 1, and whenever n reaches P * 2, update P to P * 2.

The overall time complexity is still O(n), but with a much smaller constant factor.
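
For example, n = 6 = 0b110 has d = 3 bits, so P = 2² = 4 and f(6) = 1 + f(6 - 4) = 1 + f(0b10) = 2. A small check of the recurrence (verification snippet only; it uses int.bit_length to find P directly, whereas the solution below tracks P incrementally):

f = lambda x: bin(x).count('1')
for n in range(1, 64):
    P = 1 << (n.bit_length() - 1)   # highest power of two <= n
    assert f(n) == 1 + f(n - P)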

solution3.py
class Solution:
    def countBits(self, n: int) -> list[int]:
        counts = [0] * (n + 1)
        power = 1
        next_power = 2
        for i in range(1, n + 1):
            if i == next_power:
                power = next_power
                next_power <<= 1
            counts[i] = 1 + counts[i - power]

        return counts

Measured running times of the three solutions:

[1:n log n] n =     10:      4.994 μs
[2: linear] n =     10:      2.907 μs
[3: linear] n =     10:      1.752 μs

[1:n log n] n =    100:     83.195 μs
[2: linear] n =    100:     27.409 μs
[3: linear] n =    100:     10.770 μs

[1:n log n] n =   1000:   1223.218 μs
[2: linear] n =   1000:    280.717 μs
[3: linear] n =   1000:    109.602 μs

[1:n log n] n =  10000:  16359.447 μs
[2: linear] n =  10000:   2837.669 μs
[3: linear] n =  10000:   1194.283 μs

[1:n log n] n = 100000: 209923.616 μs
[2: linear] n = 100000:  28099.642 μs
[3: linear] n = 100000:  12005.032 μs
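
For reference, numbers like these can be produced with a small timeit loop along the following lines (a reconstruction; the original benchmark script isn't shown, so the repetition count and formatting here are assumptions):

import timeit

from solution import Solution
from solution2 import Solution as Solution2
from solution3 import Solution as Solution3

cases = [('1:n log n', Solution()), ('2: linear', Solution2()), ('3: linear', Solution3())]

for n in (10, 100, 1000, 10000, 100000):
    for label, sol in cases:
        # Average a few runs to smooth out noise.
        t = timeit.timeit(lambda: sol.countBits(n), number=10) / 10
        print(f'[{label}] n = {n:6d}: {t * 1e6:10.3f} μs')
    print()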