第一次用Python解决背包问题时,我盯着屏幕上那个运行了10分钟还没出结果的程序,突然意识到算法效率的重要性。那次经历让我明白,即使掌握了语法,不懂算法分析就像开车不看油表——代码可能随时在半路抛锚。
算法分析是衡量代码性能的显微镜,它能告诉我们:
假设我们要在电话簿中找人:
用Python实现这两个算法:
python复制# 线性搜索 O(n)
def linear_search(phonebook, name):
for entry in phonebook:
if entry['name'] == name:
return entry['phone']
return None
# 二分搜索 O(log n)
def binary_search(phonebook, name):
low, high = 0, len(phonebook)-1
while low <= high:
mid = (low + high) // 2
if phonebook[mid]['name'] == name:
return phonebook[mid]['phone']
elif phonebook[mid]['name'] < name:
low = mid + 1
else:
high = mid - 1
return None
当电话簿有10,000条记录时:
| 复杂度 | 名称 | n=10时的操作次数 | n=100时的增长倍数 | 典型算法 |
|---|---|---|---|---|
| O(1) | 常数时间 | 1 | 1 | 数组索引、哈希表查找 |
| O(log n) | 对数时间 | ~3 | ~7 | 二分搜索、平衡树操作 |
| O(n) | 线性时间 | 10 | 100 | 线性搜索、遍历链表 |
| O(n log n) | 线性对数时间 | ~30 | ~700 | 快速排序、归并排序 |
| O(n²) | 平方时间 | 100 | 10,000 | 冒泡排序、简单矩阵运算 |
| O(2ⁿ) | 指数时间 | 1,024 | 1.27e+30 | 穷举搜索、汉诺塔问题 |
经验法则:在Python中,当n>1,000时,O(n²)算法就可能出现明显延迟;当n>10,000时,O(n³)算法基本不可用。
python复制# 方法1:递归 O(2ⁿ)
def fib_recursive(n):
if n <= 1:
return n
return fib_recursive(n-1) + fib_recursive(n-2)
# 方法2:带缓存的递归 O(n)
from functools import lru_cache
@lru_cache(maxsize=None)
def fib_memo(n):
if n <= 1:
return n
return fib_memo(n-1) + fib_memo(n-2)
# 方法3:动态规划 O(n)
def fib_dp(n):
if n == 0:
return 0
a, b = 0, 1
for _ in range(2, n+1):
a, b = b, a + b
return b
# 方法4:矩阵快速幂 O(log n)
def fib_matrix(n):
def matrix_pow(mat, power):
result = [[1,0],[0,1]]
while power > 0:
if power % 2 == 1:
result = [[result[0][0]*mat[0][0]+result[0][1]*mat[1][0],
result[0][0]*mat[0][1]+result[0][1]*mat[1][1]],
[result[1][0]*mat[0][0]+result[1][1]*mat[1][0],
result[1][0]*mat[0][1]+result[1][1]*mat[1][1]]]
mat = [[mat[0][0]*mat[0][0]+mat[0][1]*mat[1][0],
mat[0][0]*mat[0][1]+mat[0][1]*mat[1][1]],
[mat[1][0]*mat[0][0]+mat[1][1]*mat[1][0],
mat[1][0]*mat[0][1]+mat[1][1]*mat[1][1]]]
power //= 2
return result
if n == 0:
return 0
mat = [[1,1],[1,0]]
return matrix_pow(mat, n-1)[0][0]
实测性能对比(计算fib(40)):
python复制import random
from timeit import timeit
# 生成测试数据
data_small = random.sample(range(1_000), 100)
data_medium = random.sample(range(10_000), 1_000)
data_large = random.sample(range(100_000), 10_000)
# 冒泡排序 O(n²)
def bubble_sort(arr):
n = len(arr)
for i in range(n):
for j in range(0, n-i-1):
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
# 归并排序 O(n log n)
def merge_sort(arr):
if len(arr) > 1:
mid = len(arr)//2
L, R = arr[:mid], arr[mid:]
merge_sort(L)
merge_sort(R)
i = j = k = 0
while i < len(L) and j < len(R):
if L[i] < R[j]:
arr[k] = L[i]
i += 1
else:
arr[k] = R[j]
j += 1
k += 1
while i < len(L):
arr[k] = L[i]
i += 1
k += 1
while j < len(R):
arr[k] = R[j]
j += 1
k += 1
实测结果(单位:秒):
| 数据规模 | 冒泡排序 | 归并排序 |
|---|---|---|
| 100条 | 0.0012 | 0.0006 |
| 1,000条 | 0.11 | 0.007 |
| 10,000条 | 11.8 | 0.09 |
避坑指南:Python内置的sorted()函数使用Timsort算法(O(n log n)),在大多数情况下都比手写排序更高效。实际开发中应优先使用内置函数。
python复制# 看似O(n)实则O(n²)的操作
def bad_remove_duplicates(lst):
result = []
for item in lst: # O(n)
if item not in result: # O(n) 因为每次都要扫描整个result列表
result.append(item)
return result
# 优化方案:使用集合 O(n)
def good_remove_duplicates(lst):
seen = set()
result = []
for item in lst: # O(n)
if item not in seen: # O(1) 集合查找
seen.add(item) # O(1)
result.append(item) # O(1)* 平摊时间
return result
| 操作 | 平均复杂度 | 最坏情况 | 适用场景 |
|---|---|---|---|
| 查找 | O(1) | O(n) | 快速查找、去重 |
| 插入 | O(1) | O(n) | 动态数据收集 |
| 删除 | O(1) | O(n) | 缓存淘汰策略 |
python复制# 高效统计词频 O(n)
def word_count(text):
count = {}
for word in text.split():
count[word] = count.get(word, 0) + 1
return count
# 低效版本 O(n²)
def bad_word_count(text):
words = text.split()
return {word: words.count(word) for word in words}
python复制# 两数之和问题
def two_sum_naive(nums, target): # O(n²)
for i in range(len(nums)):
for j in range(i+1, len(nums)):
if nums[i] + nums[j] == target:
return [i, j]
return None
def two_sum_optimized(nums, target): # O(n)
seen = {}
for i, num in enumerate(nums):
complement = target - num
if complement in seen:
return [seen[complement], i]
seen[num] = i
return None
python复制# 优化前 O(n³)
def find_triplets_naive(arr):
n = len(arr)
result = []
for i in range(n):
for j in range(i+1, n):
for k in range(j+1, n):
if arr[i] + arr[j] + arr[k] == 0:
result.append((arr[i], arr[j], arr[k]))
return result
# 优化后 O(n²)
def find_triplets_optimized(arr):
arr.sort()
n = len(arr)
result = []
for i in range(n-2):
if i > 0 and arr[i] == arr[i-1]:
continue
left, right = i+1, n-1
while left < right:
total = arr[i] + arr[left] + arr[right]
if total < 0:
left += 1
elif total > 0:
right -= 1
else:
result.append((arr[i], arr[left], arr[right]))
while left < right and arr[left] == arr[left+1]:
left += 1
while left < right and arr[right] == arr[right-1]:
right -= 1
left += 1
right -= 1
return result
| 操作 | 数据结构 | 时间复杂度 | 替代方案 |
|---|---|---|---|
| x in s | 列表 | O(n) | 改用集合O(1) |
| s[i:j] | 列表 | O(k) k=切片长度 | 考虑itertools.islice |
| s.append(x) | 列表 | O(1)* | 预分配空间减少扩容 |
| s.insert(0,x) | 列表 | O(n) | 改用collections.deque |
对于形式为T(n) = aT(n/b) + f(n)的递归:
应用案例:归并排序a=2, b=2, f(n)=Θ(n) → T(n)=Θ(n log n)
动态数组(Python列表)的扩容策略:
python复制import sys
def track_list_growth():
lst = []
last_size = 0
for i in range(1000):
lst.append(i)
if len(lst) != last_size:
print(f"Length: {len(lst)}, Allocated: {sys.getsizeof(lst)}")
last_size = len(lst)
python复制import timeit
def test_algorithm():
setup = '''
from __main__ import fib_dp, fib_matrix
n = 1000
'''
stmt1 = 'fib_dp(n)'
stmt2 = 'fib_matrix(n)'
time1 = timeit.timeit(stmt1, setup, number=100)
time2 = timeit.timeit(stmt2, setup, number=100)
print(f"DP: {time1:.6f} sec")
print(f"Matrix: {time2:.6f} sec")
test_algorithm()
python复制import cProfile
def profile_sorting():
data = random.sample(range(1_000_000), 100_000)
cProfile.run('sorted(data)', sort='cumtime')
profile_sorting()
错误案例:
python复制def misleading(n):
for i in range(n): # O(n)
for j in range(100): # O(100)
print(i+j) # 看似O(100n)即O(n)
实际上:O(100n) = O(n)是正确的,因为常数因子在大O表示法中会被忽略
python复制# 差
for word in document:
processed = process(word.lower()) # 每次循环都调用lower()
# 好
lower_word = word.lower()
for word in document:
processed = process(lower_word)
python复制# 意外的O(n²)
result = []
for num in numbers:
result += str(num) # 每次+=都会创建新字符串
python复制# 差
def get_squares(n):
return [x**2 for x in range(n)] # 预先生成所有结果
# 好
def generate_squares(n):
yield from (x**2 for x in range(n)) # 按需生成