1. 题目背景与核心需求
这道算法题来自2026年美团春招的算法岗笔试第三题,考察的是无向树相关的图论知识。题目描述了一棵包含n个节点的无向树,每条边都有一个权值。我们需要找到树上所有简单路径中,满足路径上边权乘积不超过给定阈值k的路径数量。
无向树是一种特殊的无向图,具有n-1条边且连通无环。这类问题在实际业务场景中非常常见,比如社交网络中的好友关系分析、物流配送路线规划、网络拓扑结构优化等。美团作为生活服务电商平台,在骑手路径规划、商家推荐系统等场景都会用到这类算法。
2. 解题思路分析
2.1 暴力解法及其局限性
最直观的解法是枚举所有可能的节点对,然后检查它们之间路径的边权乘积。对于n个节点的树,共有n(n-1)/2个节点对,每次检查路径需要O(n)时间,总时间复杂度为O(n³)。这在n较大时(比如n=1e5)完全不可行。
2.2 分治算法选择
更优的解法是采用基于重心的树分治算法,时间复杂度可以降到O(n log² n)。这种算法的核心思想是:
- 找到当前树的重心(使删除该节点后最大子树最小的节点)
- 计算经过重心的合法路径数
- 递归处理各子树
重心分解能保证递归深度为O(log n),每层处理时间为O(n log n)。
2.3 具体实现步骤
2.3.1 寻找树的重心
使用DFS遍历树,记录每个节点的子树大小,同时维护最大子树的最小值:
python复制def find_centroid(u, parent, total):
size = 1
max_sub = 0
centroid = None
for v in adj[u]:
if v != parent and not deleted[v]:
centroid_v, size_v = find_centroid(v, u, total)
if centroid_v is not None:
return centroid_v, 0
size += size_v
max_sub = max(max_sub, size_v)
max_sub = max(max_sub, total - size)
if max_sub <= total // 2:
centroid = u
return centroid, size
2.3.2 计算经过重心的路径
对每个子节点进行DFS,记录从重心出发到各节点的路径乘积,然后使用双指针或前缀和统计合法路径数:
python复制def get_paths(u, parent, current_product, paths):
if current_product > k:
return
paths.append(current_product)
for v, w in adj[u]:
if v != parent and not deleted[v]:
get_paths(v, u, current_product * w, paths)
2.3.3 合并子树结果
使用排序+双指针技术高效统计满足条件的路径组合:
python复制def count_pairs(paths1, paths2):
paths2.sort()
count = 0
for p in paths1:
max_q = k // p
count += bisect.bisect_right(paths2, max_q)
return count
3. 完整代码实现
3.1 Python实现
python复制import bisect
from collections import defaultdict
def solve():
n, k = map(int, input().split())
adj = defaultdict(list)
for _ in range(n-1):
u, v, w = map(int, input().split())
adj[u].append((v, w))
adj[v].append((u, w))
deleted = [False] * (n + 1)
total = 0
def dfs_size(u, parent):
size = 1
for v, w in adj[u]:
if v != parent and not deleted[v]:
size += dfs_size(v, u)
return size
def find_centroid(u, parent, total_size):
size = 1
max_sub = 0
centroid = None
for v, w in adj[u]:
if v != parent and not deleted[v]:
centroid_candidate, child_size = find_centroid(v, u, total_size)
if centroid_candidate is not None:
return centroid_candidate, 0
size += child_size
max_sub = max(max_sub, child_size)
max_sub = max(max_sub, total_size - size)
if max_sub <= total_size // 2:
centroid = u
return centroid, size
def get_paths(u, parent, current_product, paths):
if current_product > k:
return
paths.append(current_product)
for v, w in adj[u]:
if v != parent and not deleted[v]:
get_paths(v, u, current_product * w, paths)
def count_pairs(paths1, paths2):
paths2_sorted = sorted(paths2)
count = 0
for p in paths1:
max_q = k // p
count += bisect.bisect_right(paths2_sorted, max_q)
return count
def decompose(u):
total_size = dfs_size(u, -1)
centroid, _ = find_centroid(u, -1, total_size)
deleted[centroid] = True
total = 0
all_paths = []
# 处理各子树
for v, w in adj[centroid]:
if not deleted[v]:
paths = []
get_paths(v, centroid, w, paths)
total += count_pairs(paths, paths) // 2 # 同一子树内的路径
total += count_pairs(paths, all_paths) # 不同子树间的路径
all_paths.extend(paths)
# 单独处理经过重心的单边路径
total += bisect.bisect_right(all_paths, k)
# 递归处理子树
for v, w in adj[centroid]:
if not deleted[v]:
total += decompose(v)
deleted[centroid] = False
return total
result = decompose(1)
print(result + n) # 加上单节点路径
solve()
3.2 Java实现
java复制import java.util.*;
import java.io.*;
public class Main {
static List<List<int[]>> adj;
static boolean[] deleted;
static int k;
static int n;
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String[] firstLine = br.readLine().split(" ");
n = Integer.parseInt(firstLine[0]);
k = Integer.parseInt(firstLine[1]);
adj = new ArrayList<>();
for (int i = 0; i <= n; i++) {
adj.add(new ArrayList<>());
}
for (int i = 0; i < n-1; i++) {
String[] line = br.readLine().split(" ");
int u = Integer.parseInt(line[0]);
int v = Integer.parseInt(line[1]);
int w = Integer.parseInt(line[2]);
adj.get(u).add(new int[]{v, w});
adj.get(v).add(new int[]{u, w});
}
deleted = new boolean[n+1];
int result = decompose(1) + n;
System.out.println(result);
}
static int dfsSize(int u, int parent) {
int size = 1;
for (int[] edge : adj.get(u)) {
int v = edge[0];
if (v != parent && !deleted[v]) {
size += dfsSize(v, u);
}
}
return size;
}
static int findCentroid(int u, int parent, int totalSize) {
int size = 1;
int maxSub = 0;
for (int[] edge : adj.get(u)) {
int v = edge[0];
if (v != parent && !deleted[v]) {
int res = findCentroid(v, u, totalSize);
if (res >= 0) return res;
int childSize = -res;
size += childSize;
maxSub = Math.max(maxSub, childSize);
}
}
maxSub = Math.max(maxSub, totalSize - size);
if (maxSub <= totalSize / 2) {
return u;
}
return -size;
}
static void getPaths(int u, int parent, long currentProduct, List<Long> paths) {
if (currentProduct > k) return;
paths.add(currentProduct);
for (int[] edge : adj.get(u)) {
int v = edge[0];
int w = edge[1];
if (v != parent && !deleted[v]) {
getPaths(v, u, currentProduct * w, paths);
}
}
}
static int countPairs(List<Long> paths1, List<Long> paths2) {
Collections.sort(paths2);
int count = 0;
for (long p : paths1) {
if (p == 0) continue;
long maxQ = k / p;
int idx = Collections.binarySearch(paths2, maxQ);
if (idx < 0) {
idx = -idx - 2;
}
count += idx + 1;
}
return count;
}
static int decompose(int u) {
int totalSize = dfsSize(u, -1);
int centroid = findCentroid(u, -1, totalSize);
deleted[centroid] = true;
int total = 0;
List<Long> allPaths = new ArrayList<>();
for (int[] edge : adj.get(centroid)) {
int v = edge[0];
int w = edge[1];
if (!deleted[v]) {
List<Long> paths = new ArrayList<>();
getPaths(v, centroid, w, paths);
// 同一子树内的路径
total += countPairs(paths, paths) / 2;
// 不同子树间的路径
total += countPairs(paths, allPaths);
allPaths.addAll(paths);
}
}
// 处理单边路径
for (long p : allPaths) {
if (p <= k) total++;
}
// 递归处理子树
for (int[] edge : adj.get(centroid)) {
int v = edge[0];
if (!deleted[v]) {
total += decompose(v);
}
}
deleted[centroid] = false;
return total;
}
}
3.3 C++实现
cpp复制#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
vector<vector<pair<int, int>>> adj;
vector<bool> deleted;
int k, n;
int dfs_size(int u, int parent) {
int size = 1;
for (auto &[v, w] : adj[u]) {
if (v != parent && !deleted[v]) {
size += dfs_size(v, u);
}
}
return size;
}
int find_centroid(int u, int parent, int total_size) {
int size = 1;
int max_sub = 0;
for (auto &[v, w] : adj[u]) {
if (v != parent && !deleted[v]) {
int res = find_centroid(v, u, total_size);
if (res >= 0) return res;
int child_size = -res;
size += child_size;
max_sub = max(max_sub, child_size);
}
}
max_sub = max(max_sub, total_size - size);
if (max_sub <= total_size / 2) {
return u;
}
return -size;
}
void get_paths(int u, int parent, long long current_product, vector<long long> &paths) {
if (current_product > k) return;
paths.push_back(current_product);
for (auto &[v, w] : adj[u]) {
if (v != parent && !deleted[v]) {
get_paths(v, u, current_product * w, paths);
}
}
}
int count_pairs(vector<long long> &paths1, vector<long long> &paths2) {
sort(paths2.begin(), paths2.end());
int count = 0;
for (long long p : paths1) {
if (p == 0) continue;
long long max_q = k / p;
auto it = upper_bound(paths2.begin(), paths2.end(), max_q);
count += it - paths2.begin();
}
return count;
}
int decompose(int u) {
int total_size = dfs_size(u, -1);
int centroid = find_centroid(u, -1, total_size);
deleted[centroid] = true;
int total = 0;
vector<long long> all_paths;
for (auto &[v, w] : adj[centroid]) {
if (!deleted[v]) {
vector<long long> paths;
get_paths(v, centroid, w, paths);
// 同一子树内的路径
total += count_pairs(paths, paths) / 2;
// 不同子树间的路径
total += count_pairs(paths, all_paths);
all_paths.insert(all_paths.end(), paths.begin(), paths.end());
}
}
// 处理单边路径
for (long long p : all_paths) {
if (p <= k) total++;
}
// 递归处理子树
for (auto &[v, w] : adj[centroid]) {
if (!deleted[v]) {
total += decompose(v);
}
}
deleted[centroid] = false;
return total;
}
int main() {
cin >> n >> k;
adj.resize(n+1);
deleted.resize(n+1, false);
for (int i = 0; i < n-1; i++) {
int u, v, w;
cin >> u >> v >> w;
adj[u].emplace_back(v, w);
adj[v].emplace_back(u, w);
}
int result = decompose(1) + n;
cout << result << endl;
return 0;
}
4. 算法复杂度分析
4.1 时间复杂度
- 重心分解:每次递归都将问题规模至少减半,递归深度为O(log n)
- 每层处理:
- 寻找重心:O(n)
- 收集路径:O(n)
- 排序和双指针:O(n log n)
- 总时间复杂度:O(n log² n)
4.2 空间复杂度
- 邻接表存储:O(n)
- 递归栈:O(log n)
- 路径存储:O(n)
- 总空间复杂度:O(n)
5. 边界条件与注意事项
5.1 特殊输入处理
- 空树或单节点树:直接返回节点数
- 所有边权为1:转化为统计路径长度≤k的路径数
- k=0:只有当边权为0时才可能有解
5.2 数值溢出问题
边权乘积可能非常大,需要注意:
- 使用64位整数存储乘积(long in Java, long long in C++)
- 当乘积超过k时提前终止DFS
5.3 实现细节
- 重心标记:在递归处理子树前标记已处理的重心
- 路径收集:从重心出发到各子树的路径
- 双指针优化:先排序再使用二分查找统计合法路径数
6. 测试用例设计
6.1 基础测试用例
code复制输入:
3 6
1 2 2
2 3 3
输出:
5
解释:
路径:(1), (2), (3), (1-2:2), (2-3:6)
其中2×3=6不满足≤6,所以共5条
6.2 较大规模测试用例
code复制输入:
5 10
1 2 2
1 3 3
2 4 1
3 5 2
输出:
12
6.3 极端测试用例
code复制输入:
1 100
输出:
1
7. 实际应用场景
这类算法在美团的业务中有多种应用:
- 骑手路径规划:寻找满足时间/距离约束的配送路线
- 商家推荐系统:基于用户-商家关系树寻找关联度高的推荐
- 网络拓扑优化:在数据中心网络中选择满足带宽要求的路径
8. 算法优化方向
- 并行化处理:子树分解可以并行计算
- 预处理技术:对常见查询进行预处理
- 近似算法:对大规模数据使用近似统计
9. 常见错误与调试技巧
- 忘记重置deleted数组:会导致无限递归
- 乘积溢出:没有使用足够大的整数类型
- 双指针实现错误:注意二分查找的边界条件
- 路径重复计数:同一子树内的路径会被重复计算
调试时可以:
- 打印中间结果验证重心选择
- 小规模测试手动验证
- 对比暴力解法的结果