还是一道比较明显的求
LCA
(最近公共祖先)模型的题目,我们可以使用多种方法来解决该问题,这里我们使用更好写的离线的tarjan
算法来解决该问题。
除去tarjan
算法必用的基础数组,我们还有一个数组d[]
,d[i]
记录的是每个点的出度,也就是它的延迟时间,以及数组w[]
,w[i]
的含义是点i
到根节点的延迟时间。在通过dfs
求出每个点i
的w[i]
以后,在tarjan
中我们该如何求出两点的延迟时间呢?
我们设点i
到j
的延迟时间为f(x)
,当我们求得i
与j
的最近公共祖先为anc
,我们首先让f(x)=w[i]+w[j]
但很明显,我们多加了两w[anc]
,所以我们需要减去两倍的w[anc]
但延迟时间还包括经过anc
的时间,所以还得加上一个d[anc]
。此处请结合w[]
和d[]
的含义理解。
最后能得出式子:f(x)=w[i]+w[h]−w[anc]2+d[anc]
我们利用这个式子在tarjan
函数中就能得出每个询问的答案,当然对于起始和结束都在同一个节点的情况下,它的答案就是当前节点的出度,我们可以进行特判一下。输入输出较多,建议使用scanf
和printf
进行输入输出。
时间复杂度:dfs
:每个点遍历一次,复杂度级别O(n)
,tarjan
算法复杂度接近 O(n+m)
。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int,int> PII;
const int N=100010;
unordered_map<int,vector<int>> gra;
int n,m;
//单个点的出度
int d[N];
//记录点i到根节点的延迟
int w[N];
//并查集数组
int q[N];
//记录答案
int res[N];
int st[N];
//存下查询
vector<PII> query[N];
//并查集查询
int find(int x){
if(x!=q[x]) q[x]=find(q[x]);
return q[x];
}
void dfs(int u,int fa)
{
w[u]+=d[u];
for(auto g:gra[u]){
if(g==fa) continue;
w[g]+=w[u];
dfs(g,u);
}
}
void tarjan(int u)
{
st[u]=1;
for(auto j:gra[u]){
if(!st[j])
{
tarjan(j);
q[j]=u;
}
}
for(auto item: query[u]){
int y=item.first,id=item.second;
if(st[y]==2){
int anc=find(y);
res[id]=w[y]+w[u]-w[anc]*2+d[anc];
}
}
st[u]=2;
}
int main()
{
cin>>n>>m;
for(int i=0;i<n-1;++i){
int a,b;
scanf("%d%d",&a,&b);
gra[a].push_back(b);
gra[b].push_back(a);
d[a]++,d[b]++;
}
for(int i=0;i<m;++i){
int a,b;
scanf("%d%d",&a,&b);
if(a!=b){
query[a].push_back({b,i});
query[b].push_back({a,i});
}else{
res[i]=d[a];
}
}
dfs(1,-1);
for(int i=1;i<=n;++i) q[i]=i;
tarjan(1);
for(int i=0;i<m;++i) printf("%d\n",res[i]);
return 0;
}
错误答案:用floyd直接爆炸
错误答案
#include<bits/stdc++.h>
using namespace std;
const int N=1005,M=1005;
int deg[N];//度
int dis[N][N];
int main(){
ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
memset(dis,0x7f,sizeof(dis));
int n,m;cin>>n>>m;int v1,v2;
for(int i=1;i<n;++i){
cin>>v1>>v2;
++deg[v1];++deg[v2];
}
for(int i=1;i<n;++i){
dis[v1][v2]=deg[v1];
dis[v2][v1]=deg[v2];
}
for(int k=1;k<=n;k++)for(int v1=1;v1<=n;v1++)for(int v2=1;v2<=n;v2++)//枚举点
if((v1!=k)&&(v2!=k)&&(v1!=v2))
dis[v1][v2]=min(dis[v1][v2],dis[v1][k]+dis[k][v2]);
int start,end;
while(m--){
cin>>start>>end;
cout<<dis[start][end]+deg[end];
}
return 0;
}
/*
4 3
1 2
1 3
2 4
2 3
3 4
3 3
*/