树套树,就是线段树、平衡树、树状数组等数据结构的嵌套。
最简单的是线段树套set,可以解决一些比较简单的问题,而且代码根线段树是一样的只是一些细节不太一样。
本题中用的是线段树套splay,代码较长。
树套树中的splay和单一的splay原理是一样的,只不过是建了很多的splay树,因为不止一个,所以跟板子不同的是,大部分函数都要传splay的根节点规定起点。
而线段树中存储的就是每个区间对应的splay的root节点。
只要线段树和splay板子都懂了,这一题就很好理解。
const int mod = 1e9 + 7, INF = 2147483647;
const int N = 1e7+ 10;
int n, m;
struct Node {
int s[2], p, v; // 左右儿子、父节点、值
int size, cnt; // 子树大小、懒标记
void init(int _v, int _p) { // 初始化函数
v = _v, p = _p;
cnt = size = 1;
}
} tr[N];
int L[N], R[N], T[N], idx;
int w[N];
void pushup(int u) { // 向上更新传递,与线段树一样
tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + tr[u].cnt;
}
void rotate(int x) { // 核心函数
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
pushup(y), pushup(x);
}
void splay(int& root, int x, int k) { // 将x节点旋转到k节点下
while(tr[x].p != k) { //
int y = tr[x].p; // x节点的父节点
int z = tr[y].p; // x节点的父节点的父节点
if(z != k) // 向上旋转
if((tr[y].s[1] == x) != (tr[z].s[1] == y)) rotate(x); // 转一次x
else rotate(y); // 转一次y
rotate(x); // 转一次x
}
if(!k) root = x; // 更新root节点
}
void upper(int& root, int v) { // 将v值节点转到根节点
int u = root; // 根节点
while(tr[u].s[v > tr[u].v] && tr[u].v != v) // 存在则找到v值节点,不存在则找到v值节点的前驱或者后继节点
u = tr[u].s[v > tr[u].v]; // 向下寻找
splay(root, u, 0); // 将u节点旋转到跟节点
}
int get_prev(int& root, int v) { // 获取v值的前驱节点,严格小于v的最大节点
upper(root, v); // 将v值节点转到根节点
if(tr[root].v < v) return root; // 若是该值在树中不存在,根节点就是v的前驱或者后继节点
int u = tr[root].s[0]; // 前驱节点在左子树的最右边
while(tr[u].s[1]) u = tr[u].s[1]; // 找到最右边的一个节点
return u;
}
int get_next(int& root, int v) { // 获取某值的后继节点,严格大于v的最小节点
upper(root, v); // 将v值节点转到根节点
if(tr[root].v > v) return root; // 若是该值在树中不存在,根节点就是v的前驱或者后继节点
int u = tr[root].s[1]; // 后继节点在右子树的最左边
while(tr[u].s[0]) u = tr[u].s[0]; // 找到最左的节点,就是最小的节点
return u; // 返回节点
}
void insert(int& root, int v) { // 在二叉树中插入一个值
int u = root, p = 0; // p维护为当前节点的父节点
while(u && tr[u].v != v) // 没找到则一直向下寻找
p = u, u = tr[u].s[v > tr[u].v]; // 更新父节点,更新当前节点
if(u) tr[u].cnt ++; // v值的节点已经存在则直接加一即可
else { // 不存在则创建节点
u = ++ idx; // 分配节点序号
if(p) tr[p].s[v > tr[p].v] = u; // 将父节点也就是前驱节点指向当前节点
tr[u].init(v, p); // 初始化当前节点的值、父节点信息
}
splay(root, u, 0); // 将u节点旋转到根节点下
}
int get_k(int root, int v) { // 获得树中有多少比v小的数
int u = root, res = 0;
while(u) {
if(tr[u].v < v) res += tr[tr[u].s[0]].size + tr[u].cnt, u = tr[u].s[1];
else u = tr[u].s[0];
}
return res;
}
void remove(int& root, int v) { // 删除一个值为v的节点
int prev = get_prev(root, v), nex = get_next(root, v); // 获取该节点的前驱以及后继节点。
splay(root, prev, 0), splay(root, nex, prev); // 将前继节点旋转到根节点,将后继节点旋转到前驱节点下面也就是根节点下面
int w = tr[nex].s[0]; // 后继节点的左子树就是v的节点
if(tr[w].cnt > 1) tr[w].cnt --, splay(root, w, 0); // 该节点的v不止存在一个,减一,w节点旋转到根节点
else tr[nex].s[0] = 0, splay(root, nex, 0); // 唯一,那么直接把后继节点的左子树指向空也就是0即可
}
void update(int& root, int x, int y) { // 将一个x值点改为y值
remove(root, x); // 先删除
insert(root, y); // 再插入
}
void build(int u, int l, int r) {
L[u] = l, R[u] = r; // 存储某个节点的左右边界
insert(T[u], -INF), insert(T[u], INF); // 插入哨兵
for(int i = l; i <= r; i ++) insert(T[u], w[i]); // 初始化线段树每个节点的平衡树
if(l == r) return ;
int mid = l + r >> 1;
build(u << 1, l, mid); // 建左子树
build(u << 1 | 1, mid + 1, r); // 建右子树
}
int query(int u, int a, int b, int x) { // 查询区间a,b之间有多少比x值小的数
if(a <= L[u] && R[u] <= b) return get_k(T[u], x) - 1;
int mid = L[u] + R[u] >> 1, res = 0;
if(a <= mid) res += query(u << 1, a, b, x); // 查询左子树中有多少是该区间并且小于x的数
if(mid < b) res += query(u << 1 | 1, a, b, x); // 查询右子树中有多少是该区间并且小于x的数
return res;
}
void change(int u, int p, int x) { // 将线段树中p位置数值改为x
update(T[u], w[p], x); // 修改当前节点中平衡树中的值
if(L[u] == R[u]) return ;
int mid = L[u] + R[u] >> 1;
if(p <= mid) change(u << 1, p, x); // 修改左子树
else change(u << 1 | 1, p, x); // 修改右子树
}
int query_prev(int u, int a, int b, int x) { // 查询再该区间中x的前驱节点
if(a <= L[u] && R[u] <= b) return tr[get_prev(T[u], x)].v; // 该函数为查找当前子树中x的前驱节点
int mid = L[u] + R[u] >> 1, res = -INF;
if(a <= mid) res = max(res, query_prev(u << 1, a, b, x)); // 递归左子树
if(mid < b) res = max(res, query_prev(u << 1 | 1, a, b, x)); // 递归右子树
return res; // 返回左右子树中的最大值
}
int query_next(int u, int a, int b, int x) { // 查询再该区间中x的后继节点
if(a <= L[u] && R[u] <= b) return tr[get_next(T[u], x)].v; // 该函数为查找当前子树中x的后继节点
int mid = L[u] + R[u] >> 1, res = INF;
if(a <= mid) res = min(res, query_next(u << 1, a, b, x));
if(mid < b) res = min(res, query_next(u << 1 | 1, a, b, x));
return res; // 返回左右子树中的最小值
}
int get_rank_to_tr(int a, int b, int x) { // 查找区间内排名第x的数
int l = 0, r = 1e8;
while(l < r) { // 通过二分获得答案,因为只能判断某个数在区间内的排名。
int mid = l + r + 1 >> 1;
if(query(1, a, b, mid) + 1 <= x) l = mid; //
else r = mid - 1;
}
return r;
}
inline void sovle() {
cin >> n >> m;
for(int i = 1; i <= n; i ++)cin >> w[i];
build(1, 1, n);
while(m --) {
int op, a, b, x;
cin >> op >> a >> b;
if(op != 3) cin >> x;
if(op == 1) cout << query(1, a, b, x) + 1 << endl;
if(op == 2) cout << get_rank_to_tr(a, b, x) << endl;
if(op == 3) {
change(1, a, b);
w[a] = b;
}
if(op == 4) cout << query_prev(1, a, b, x) << endl;
if(op == 5) cout << query_next(1, a, b, x) << endl;
}
}