原题链接:2867. 统计树中的合法路径数目
题目描述:
给你一棵 n
个节点的无向树,节点编号为 1
到 n
。给你一个整数 n
和一个长度为 n - 1
的二维整数数组 edges
,其中 edges[i] = [ui, vi]
表示节点 ui
和 vi
在树中有一条边。
请你返回树中的 合法路径数目 。
如果在节点 a
到节点 b
之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b)
是 合法的 。
注意:
- 路径
(a, b)
指的是一条从节点a
开始到节点b
结束的一个节点序列,序列中的节点 互不相同 ,且相邻节点之间在树上有一条边。 - 路径
(a, b)
和路径(b, a)
视为 同一条 路径,且只计入答案 一次 。
输入输出描述:
示例 1:
输入:n = 5, edges = [[1,2],[1,3],[2,4],[2,5]] 输出:4 解释:恰好有一个质数编号的节点路径有: - (1, 2) 因为路径 1 到 2 只包含一个质数 2 。 - (1, 3) 因为路径 1 到 3 只包含一个质数 3 。 - (1, 4) 因为路径 1 到 4 只包含一个质数 2 。 - (2, 4) 因为路径 2 到 4 只包含一个质数 2 。 只有 4 条合法路径。
示例 2:
输入:n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]] 输出:6 解释:恰好有一个质数编号的节点路径有: - (1, 2) 因为路径 1 到 2 只包含一个质数 2 。 - (1, 3) 因为路径 1 到 3 只包含一个质数 3 。 - (1, 4) 因为路径 1 到 4 只包含一个质数 2 。 - (1, 6) 因为路径 1 到 6 只包含一个质数 3 。 - (2, 4) 因为路径 2 到 4 只包含一个质数 2 。 - (3, 6) 因为路径 3 到 6 只包含一个质数 3 。 只有 6 条合法路径。
提示:
1 <= n <= 105
edges.length == n - 1
edges[i].length == 2
1 <= ui, vi <= n
- 输入保证
edges
形成一棵合法的树。
解题思路:
首先数据量有1e5,那么我们可以先预处理,线性筛筛出所有质数,便于后续处理,题目说了合法路径指的是路径上包含刚好一个质数点,那么我们dfs的同时枚举每一个质数点,考虑每一个质数点的贡献,下面画个图描述一下:
由于我们需要记录某个点出发在不经过质数点的情况最多经过多少个点,我们可以用sz[x]记录从x出发在不经过质数点的情况下最多会经过几个点,类似记忆化的思想,所以枚举每一个质数贡献的时候,如果遇到某个点的sz[y]已经被计算过了,直接拿来用即可,避免重复搜索。
时间复杂度:不考虑预处理线性筛的时间,由于采取了记忆化思想,每个点只会访问一次,所以时间复杂度为O(n)。
空间复杂度:O(n)。
cpp代码如下:
const int N=1e5+10;
typedef long long LL;
int primes[N],cnt=0;
bool st[N];
int init=[](){ //线性筛预处理所有质数
st[1]=true;
for(int i=2;i<N;i++)
{
if(!st[i])primes[cnt++]=i;
for(int j=0;primes[j]<=(N-1)/i;j++)
{
st[primes[j]*i]=true;
if(i%primes[j]==0)break;
}
}
return 0;
}();
class Solution {
public:
long long countPaths(int n, vector<vector<int>>& edges) {
vector<vector<int>>g(n+1);
vector<int>sz(n+1);
for(auto& t:edges){
int x=t[0],y=t[1];
g[x].push_back(y);
g[y].push_back(x);
}
vector<int>nodes;
//dfs遍历在不经过指数的情况下最多经过多少个点,也就是非质数连通块大小
function<void(int,int)>dfs=[&](int x,int fa){
nodes.push_back(x);
for(auto& y:g[x]){
if(y!=fa && st[y]){
dfs(y,x);
}
}
};
LL ans=0;
for(int x=1;x<=n;x++)
{
if(st[x])continue; //只需要枚举质数,非质数跳过
int sum=0;
for(int y:g[x]){
if(!st[y])continue; //y是质数,不需要在搜索
if(sz[y]==0) //没有被搜索过,搜索一下
{
nodes.clear();
dfs(y,-1);
for(int x:nodes){ //对于这个连通块中所有点都要标记一下连通块大小,避免重复计算
sz[x]=nodes.size();
}
}
ans+=(LL)sz[y]*sum; //计算贡献
sum+=sz[y]; //sum记录左边的子树大小之和
}
ans+=sum; //这个也是贡献的一部分
}
return ans;
}
};