终于搞懂了树链剖分的一些皮毛了……
树链剖分
“树链剖分”,顾名思义,就是把一棵树剖分成一条条的链……
重链剖分
重链剖分的基本概念
重链剖分是树链剖分的一种,它会把树剖分成一条条重链……
什么是重链呢?
重链就是连接每一个树的重儿子所形成的链。
重儿子就是其儿子重以儿子为根的子树大小最大的儿子。
画一个图来理解一下:
对于这样一棵树,它剖成重链应该是这样的(红色是重链,绿色包裹的是重链上的点):
重链剖分的过程
首先,如果我们需要知道重儿子,那得知道子树的大小在进行处理,所以要分两次 dfs 来实现。
第一次 dfs
这一次 dfs 主要是算出每一个点的深度、父亲、子树大小。
就是一个简单的 dfs,很好理解。
局部代码
void dfs1(int x,int father,int deep)
{
dep[x]=deep;
fa[x]=father;
siz[x]=1;
int mxson=-1;
for(auto y:g[x])
{
if(y==father)continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];
if(siz[y]>mxson)
{
mxson=siz[y];
son[x]=y;
}
}
}
第二次 dfs
算出每个点的链头、重儿子,以及其 dfs 序的编号(这个编号后续会有用)。
这一个比较简单的 dfs,应该也比较好理解。
局部代码
void dfs2(int x,int tf)
{
id[x]=++cnt;
top[x]=tf;
if(!son[x])return;
dfs2(son[x],tf);
for(auto y:g[x])
{
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
以上就是重链剖分的基本步骤,时间复杂度 ,也是十分的高级好吧。
重链剖分的应用
学会了重链剖分的基本步骤,那还是得学会怎么用对吧……
先来看一道题目
洛谷 P3384 【模板】重链剖分/树链剖分https://www.luogu.com.cn/problem/P3384
题目大意
解题思路
当然,一看题目的名称,就知道是重链剖分……
回顾重链剖分(这次加上每个点的 dfs 序编号):
也许你会发现,对于在同一个重链里面的点,他们的编号都是连续的。
由此,我们可以把树看成一个个序列,从树上问题转换为序列问题。
那题目中的每个操作都可以看做是在序列上进行的……
那么,很自然地就可以想到线段树可以解决它。
对于第 1 个操作
和倍增法求 LCA 类似地,我们可以依次往上跳,直到 都在同一条链上,每一次跳,都可以看做是在链上这个区间内加上了 ,直接套上线段树即可。
但是,最后也要记得处理最终的 之间的这个区间。
局部代码:
void update(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
change(1,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
change(1,id[x],id[y],z);
}
对于第 2 个操作
和第一个操作类似地,也是往上跳,直到 在同一条链上,只不过每次跳时都是对链这个区间进行一次求和查询,累加即可。
局部代码:
int que(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ans+=query(1,id[top[x]],id[x]);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans+=query(1,id[x],id[y]);
return ans;
}
对于第 3 个操作
可以再次发现,同一颗子树内的点的编号是连续的,所以可以一次修改操作就行了。
修改的区间是 , 是 的编号, 是以 为子树的大小。
局部代码(好像没啥好展示的):
change(1,id[x],id[x]+siz[x]-1,z);
对于第 4 个操作
修改的区间一样,只是查询操作而已。
局部代码:
cout<<query(1,id[x],id[x]+siz[x]-1)<<"\n";
完整代码
记得取模!!!
#include<bits/stdc++.h>
using namespace std;
#define int long long
int n,q,rt,mod;
int a[100001];
vector<int> g[100001];
int dep[100001];
int fa[100001];
int siz[100001];
int son[100001];
int id[100001],top[100001],wt[100001],cnt;
void dfs1(int x,int father,int deep)
{
dep[x]=deep;
fa[x]=father;
siz[x]=1;
int mxson=-1;
for(auto y:g[x])
{
if(y==father)continue;
dfs1(y,x,deep+1);
siz[x]+=siz[y];
if(siz[y]>mxson)
{
mxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int tf)
{
id[x]=++cnt;
wt[id[x]]=a[x];
top[x]=tf;
if(!son[x])return;
dfs2(son[x],tf);
for(auto y:g[x])
{
if(y==fa[x]||y==son[x])continue;
dfs2(y,y);
}
}
struct tree{
int sum,l,r,add;
}tr[400001];
void build(int u,int l,int r)
{
tr[u]={0,l,r,0};
if(l==r)
{
tr[u].sum=wt[l];
tr[u].sum%mod;
return;
}
int mid=l+r>>1;
build(u*2,l,mid);
build(u*2+1,mid+1,r);
tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
tr[u].sum%=mod;
}
void push_down(int u)
{
if(tr[u].add)
{
tr[u*2].sum+=(tr[u*2].r-tr[u*2].l+1)*tr[u].add;
tr[u*2].add+=tr[u].add;
tr[u*2].sum%=mod;
// tr[u*2].add%=mod;
tr[u*2+1].sum+=(tr[u*2+1].r-tr[u*2+1].l+1)*tr[u].add;
tr[u*2+1].add+=tr[u].add;
tr[u*2+1].sum%=mod;
// tr[u*2+1].add%=mod;
tr[u].add=0;
}
}
void push_up(int u)
{
tr[u].sum=tr[u*2].sum+tr[u*2+1].sum;
tr[u].sum%=mod;
}
void change(int u,int l,int r,int d)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
tr[u].sum+=(tr[u].r-tr[u].l+1)*d;
tr[u].sum%=mod;
tr[u].add+=d;
return;
}
push_down(u);
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid)
change(u*2,l,r,d);
if(r>mid)
change(u*2+1,l,r,d);
push_up(u);
}
int query(int u,int l,int r)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
return tr[u].sum;
}
push_down(u);
int mid=tr[u].l+tr[u].r>>1;
int res=0;
if(l<=mid)
res+=query(u*2,l,r);
res%=mod;
if(r>mid)
res+=query(u*2+1,l,r);
res%=mod;
return res;
}
void update(int x,int y,int z)
{
z%=mod;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
change(1,id[top[x]],id[x],z);
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
change(1,id[x],id[y],z);
}
int que(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
swap(x,y);
ans+=query(1,id[top[x]],id[x]);
ans%=mod;
x=fa[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans+=query(1,id[x],id[y]);
ans%=mod;
return ans;
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>q>>rt>>mod;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
int u,v;
for(int i=1;i<n;i++)
{
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(rt,0,1);
dfs2(rt,rt);
build(1,1,n);
int op,x,y,z;
while(q--)
{
cin>>op;
if(op==1)
{
cin>>x>>y>>z;
update(x,y,z);
}
if(op==2)
{
cin>>x>>y;
cout<<que(x,y)<<"\n";
}
if(op==3)
{
cin>>x>>z;
change(1,id[x],id[x]+siz[x]-1,z);
}
if(op==4)
{
cin>>x;
cout<<query(1,id[x],id[x]+siz[x]-1)%mod<<"\n";
}
}
}
最后求赞勿喷。