treap实际上就是tree(BST,二叉搜索树)+heap(堆)
我们维护一个二叉树来储存值,但是为了避免二叉树由于值太特殊变成链式结构,我们对于每个点加入一个val值,这个是随机值,我们通过这个随机值来维护一个大根堆(只与val有关的大根堆),进而使得我们维护能够用一个比较对称的二叉树来维护所有数据,可以类比AVL树来思考。
那么我们为什么要用treap,或者换句话说treap可以进行哪些操作呢?
1.插入
2.删除
3.找前驱和后继(小于某个数最大的数和大于某个数最小的数,BST中不同节点数不同)
4.找最大/最小
5.求某个值的排名
6.求排名是rank的数是哪个
7.比某个数小的最大值(数可能不在我们维护的区间中)
8.比某个数大的最小值(同上)
对于这些操作,显然前几个是可以用set实现的,但是求某个值的排名和求排名是k的数是哪个,这两个操作显然用set没办法很好的实现。所以我们必须要使用Treap。能用set的一定可以用Treap,能用Treap的不一定能用set。
那么我们就要来看如何实现Treap,Treap是基于BST的,那么我们先来回顾一下BST,BST就是维护一棵二叉树,对于这个二叉树而言,它左子节点的值一定小于它,右子节点的值一定大于它。
我们一般是以第一个数为根节点的,但是如果要维护的区间是单调的,那么二叉树就会退化成链式结构,很影响时间复杂度。所以为了避免二叉树退化成链式结构,我们在每个点上增加一个随机值变量,然后根据这个随机值维护一个大根堆,也即父节点的值一定大于左右子节点。然后如果不满足的话,那么就把左右子节点中的大的点通过旋转操作放到根节点上去,旋转操作如下:
可以发现即使我们进行了旋转操作,维护的仍然是一个二叉树,不会对我们在二叉树上的操作产生影响。与此同时,这也给了我们一个删除中间节点的思路,我们可以把中间节点换到叶子节点上进行删除。
那么我们下面来看代码层面的实现:
首先是对每个节点的定义,最基本的需要一下几个值(当然由于题目不同,可能还要加一些别的变量):
struct node{
int l,r;//l,r表示左右子节点的下标,我们这里用下标表示指针
int key,val;//key表示我们实际存的值,val存的是我们随机赋的值,用以维护堆
}
然后我们来看有哪些基本操作:
新增节点:
int get_node()
{
tr[++idx].key=key;
tr[idx].val=rand();
return idx;
}
左旋:
右到父,父到左,右左到左右。先拔再转。
void zag(int &p)//左旋
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[p].l),pushup(p);
}
通过前三个赋值,我们已经实现了子节点的交换,相对于原来的p的父节点的更新我们该怎么实现了,显然我们没有存父节点的信息, 不能从子节点访问到父节点,但是我们肯定是在递归中实现的,我们只要在递归的时候传入tr[u].l的引用,那么在这一层进行修改的时候就可以实现对tr[u].l的修改。
右旋:
void zig(int &p)//右旋
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r),pushup(p);
}
这里的pushup就是用子节点更新父节点的一些信息,和线段树中的pushup差不多。
初始化:
初始化的时候,我们设置了两个哨兵,它们的值分别赋值为正无穷和负无穷,先令root为负无穷对应的点1,它的右子为正无穷对应的点2.然后根据它们的val,判断是否需要左旋。
void build()
{
get_node(-INF),get_node(INF);
root=1,tr[root].r=2;
pushup(root);
if(tr[1].val<tr[2].val) zag(root);//维护大根堆
}
插入
和二叉树的插入操作一样,注意判断是否需要旋转操作即可。
void insert(int &p,int key)
{
if(!p) get_node(key);
else if(tr[p].key<key)
{
insert(tr[p].r,key);
if(tr[p].val<tr[tr[p].r].val) zig(p);
}
else
{
insert(tr[p].l,key);
if(tr[p].val<tr[tr[p].l].val) zag(p);
}
}
删除
删除要分三种情况,一种是要删除的数不存在,那么最后搜到的位置就是0,那么直接返回即可;一种是要删除的数在叶子节点上,那么我们直接把这个叶子节点的下标修改成0即可,因为我们在上一层传入的是它父节点的tr[u].l或者tr[u].r的引用,所以这里直接赋成0就相当于将它父节点的左指针或者右指针的指向修改成空;最后一种最麻烦,就是要删除的数不在叶子节点上,那么就需要通过旋转将它换到叶子节点上,然后再用第二种方法执行删除。
void remove(int &p,int key)
{
if(!p) return;
if(tr[p].l||tr[p.r])
{
if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
{
zig(p);//p和原p的父节点的某个子节点都被修改成了原p的左子节点
remove(tr[p].r,key);
}
else
{
zag(p);
remove(tr[p].l,key);
}
}
else p=0;
}
找严格小于key的数中最大的数
int get_prev(int p,int key)
{
if(!p) return -INF;//没有小于key的数
if(tr[p].key>=key) return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}
找到严格大于key的最小数
int get_next(int p,int key)
{
if(!p) return INF;
if(p.key<=key) return get_next(tr[p].r,key);
return min(tr[p].key, get_next(tr[p].l,key) );
}
找最大:
int get_max(int p)
{
if(!p) return -INF;
return max(tr[p].key,get_max(tr[p].r));
}
找最小:
int get_min(int p)
{
if(!p) return INF;
return min(tr[p].key,get_max(tr[p].l));
}
求某个值的排名
求排名的时候为了方便,我们引入了以某个点为根节点的子树的大小size,如果有重复元素,我们还可以引入一个cnt变量,表示值为当前节点的点有多少个。
int get_rank_by_key(int p,int key)
{
if(!p) return 0;
if(tr[p].key==key) return tr[tr[p].l].size+1;
if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
return tr[tr[p].l].size()+get_rank_by_key(tr[p].r,key);
}
求排名是rank的数是哪个
int get_key_by_rank(int p,int rank)
{
if(!p) return INF;
if(tr[tr[p].l].size >= rank) return get_key_by_rank(tr[p].l,rank);
if(tr[p].size+tr[p].cnt>=key) return tr[p].key;//cnt表示的是当前数有多少个
return get_key_by_rank(tr[p].r,rank-tr[tr[p].l].size-tr[p].cnt);
}
差不多就是这些操作,剩下的我们根据具体的题目再来分析。
253. 普通平衡树(253. 普通平衡树 - AcWing题库)
这个就是一道比较裸的题目,这里要注意我们需要在原来节点定义的基础上引入size和cnt,因为这里有根据排名求数和根据数求排名两个操作。
#include<bits/stdc++.h>
using namespace std;
const int inf=0x3f3f3f3f;
struct node{
int l,r;
int key,val;
int size,cnt;
}tr[100010];
int root,idx;
int get_node(int key)
{
tr[++idx].key=key;
tr[idx].val=rand();
tr[idx].size=tr[idx].cnt=1;
return idx;
}
void pushup(int p)
{
tr[p].size=tr[tr[p].l].size+tr[p].cnt+tr[tr[p].r].size;
}
void left(int &p)//左旋
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
pushup(tr[p].l),pushup(p);
}
void right(int &p)//右旋
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
pushup(tr[p].r),pushup(p);
}
void build()
{
get_node(-inf),get_node(inf);
root=1,tr[1].r=2;
pushup(root);
if(tr[1].val<tr[2].val) left(root);
}
void insert(int &p,int key)
{
if(!p) p=get_node(key);//这里赋值了才算被插进去,因为p传入的是引用
else if(tr[p].key==key) tr[p].cnt++;
else if(tr[p].key>key)
{
insert(tr[p].l,key);
if(tr[p].val<tr[tr[p].l].val) right(p);
}
else
{
insert(tr[p].r,key);
if(tr[p].val<tr[tr[p].r].val) left(p);
}
pushup(p);
}
void remove(int &p,int key)
{
if(!p) return;
if(tr[p].key==key)
{
if(tr[p].cnt>1) tr[p].cnt--;
else if(tr[p].l||tr[p].r)
{
if(!tr[p].r||tr[tr[p].l].val>tr[tr[p].r].val)
{
right(p);
remove(tr[p].r,key);
}
else
{
left(p);
remove(tr[p].l,key);
}
}else p=0;
}
else if(tr[p].key>key) remove(tr[p].l,key);
else remove(tr[p].r,key);
pushup(p);
}
int get_rank_by_key(int p,int key)
{
if(!p) return 0;
if(tr[p].key==key) return tr[tr[p].l].size+1;
else if(tr[p].key>key) return get_rank_by_key(tr[p].l,key);
else return tr[tr[p].l].size+tr[p].cnt+get_rank_by_key(tr[p].r,key);
}
int get_key_by_rank(int p,int rank)
{
if(!p) return inf;
if(tr[tr[p].l].size>=rank) return get_key_by_rank(tr[p].l,rank);
else if(tr[tr[p].l].size+tr[p].cnt>=rank) return tr[p].key;
else get_key_by_rank(tr[p].r,rank-tr[p].cnt-tr[tr[p].l].size);
}
int get_prev(int p,int key)
{
if(!p) return -inf;
if(tr[p].key>=key) return get_prev(tr[p].l,key);
else return max(tr[p].key,get_prev(tr[p].r,key));
}
int get_next(int p,int key)
{
if(!p) return inf;
if(tr[p].key<=key) return get_next(tr[p].r,key);
else return min(tr[p].key,get_next(tr[p].l,key));
}
int main()
{
build();
int n;
scanf("%d",&n);
while(n--)
{
int op,x;
scanf("%d%d",&op,&x);
if(op==1) insert(root,x);
else if(op==2) remove(root,x);
else if(op==3) cout<<get_rank_by_key(root,x)-1<<endl;//因为有哨兵
else if(op==4) cout<<get_key_by_rank(root,x+1)<<endl;
else if(op==5) cout<<get_prev(root,x)<<endl;
else cout<<get_next(root,x)<<endl;
}
}
265. 营业额统计(265. 营业额统计 - AcWing题库)
思路:这题对于每个数要找到在它插入前距离它最近的数,那么实际上就是在每个数插入前找它的前驱(最大的最小值)和后继(最小的最大值),然后取离它更近的那个值。
#include<bits/stdc++.h>
using namespace std;
const int inf=0x3f3f3f3f;
struct node{
int l,r;
int key,val;
}tr[100010];
int root,idx;
int get_node(int key)
{
tr[++idx].key=key;
tr[idx].val=rand();
return idx;
}
void left(int &p)
{
int q=tr[p].r;
tr[p].r=tr[q].l,tr[q].l=p,p=q;
}
void right(int &p)
{
int q=tr[p].l;
tr[p].l=tr[q].r,tr[q].r=p,p=q;
}
void build()
{
get_node(-inf),get_node(inf);
root=1,tr[1].r=2;
if(tr[1].val<tr[2].val) left(root);
}
void insert(int &p,int key)
{
if(!p) p=get_node(key);
else if(tr[p].key==key) return;
else if(tr[p].key>key)
{
insert(tr[p].l,key);
if(tr[tr[p].l].val>tr[p].val) right(p);
}
else
{
insert(tr[p].r,key);
if(tr[tr[p].r].val>tr[p].val) left(p);
}
}
int get_prev(int p,int key)
{
if(!p) return -inf;
if(tr[p].key>key) return get_prev(tr[p].l,key);
return max(tr[p].key,get_prev(tr[p].r,key));
}
int get_next(int p,int key)
{
if(!p) return inf;
if(tr[p].key<key) return get_next(tr[p].r,key);
return min(tr[p].key,get_next(tr[p].l,key));
}
int main()
{
build();
int n;
scanf("%d",&n);
long long ans=0;
for(int i=1;i<=n;i++)
{
int x;
scanf("%d",&x);
if(i==1) ans+=x;
else ans+=min(x-get_prev(root,x),get_next(root,x)-x);
insert(root,x);
}
printf("%lld",ans);
}
在插入的时候,一定不要忘记判断是否需要左右旋。