平衡树学习笔记

幸甚至哉,歌以咏肝

TIPS:文字量不多,感性理解。

BST性质

对于任何BST上的节点,其左子树的节点权值小于当前节点,右子树的节点的权值大于当前节点,因此这个性质能够很好地维护单调性,直接中序遍历即可得到答案。

当然也可以方便查询大小排名,前驱后继等等

也可以来搞LCT,也可以满足一些出题人的奇怪癖好,比如仙人掌上差分 .

Splay

Splay可以防止平衡树退化,通过不断旋转来不断维持自身的平衡性,

旋转

旋转时我们向子节点在父节点的方向的反方向旋转.

这样就可以保持BST的性质,又能维护平衡

于是,异或就可以完美地解决这个旋转。

直接修改相应的父子节点和节点信息即可

找方向

1
2
3
4
inline  int get(int x)
{
return f[f[x].fa].son[1]==x;
}

上传(维护size)

1
2
3
4
void push_up(int x)
{
f[x].size=(f[f[x].son[0]].size+f[f[x].son[1]].size+f[x].cnt);
}

旋转

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
inline void rotate(int x)
{
int fat=f[x].fa;
int gfa=f[fat].fa;
int now=get(x);
int ss=f[x].son[now^1];
f[fat].son[now]=ss;
f[ss].fa=fat;
f[gfa].son[get(fat)]=x;
f[x].fa=gfa;
f[x].son[now^1]=fat;
f[fat].fa=x;
push_up(fat);
push_up(x);
}

Splay

我们用Splay处理被更改的所有节点,使其再次满足平衡性质,这是为了保证查找复杂度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
inline void splay(int x,int go=0)
{
while(f[x].fa!=go)
{
int fat=f[x].fa;
int gfa=f[fat].fa;
if(gfa!=go)
{
if(get(x)==get(fat))
{
rotate(fat);
}
else
{
rotate(x);
}
}
rotate(x);
}
if(!go)
{

rt=x;
}
}

插入节点

注意第一个插入的节点要特判,之后需要不断向下寻找位置。

然后再进行Splay,维护平衡性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
inline void insert(int x)
{
int now=rt,p=0;
while(now&&f[now].val!=x)
{
p=now;
now=f[now].son[x>f[now].val];
}
if(now)
{
f[now].cnt++;
}
else
{
now=++cc;
if(p)
{
f[p].son[x>f[p].val]=now;
f[now].son[0]=f[now].son[1]=0;
f[now].fa=p;
f[now].val=x;
f[now].cnt=f[now].size=1;
}
else
{
f[now].son[0]=f[now].son[1]=0;
f[now].fa=p;
f[now].val=x;
f[now].cnt=f[now].size=1;
}
}
splay(now,0);
}

辅助函数

将小于等于x的树Splay到根

1
2
3
4
5
6
7
8
9
void find(int x)
{
int now=rt;
while(f[now].son[x>f[now].val]&&x!=f[now].val)
{
now=f[now].son[x>f[now].val];
}
splay(now);
}

查找第K大的数

因为满足了BST的性质,并且我们在建树的过程中记录了SIZE,那么我们就可以很容易地用SIZE和BST的性质来进行查找,这个过程可以看成不断“切树”的过程,直到x被切完

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int k_th(int x)
{
int now=rt;
while(1)
{
if(f[now].son[0]&&x<=f[f[now].son[0]].size)
{
now=f[now].son[0];
}
else
if(x>f[f[now].son[0]].size+f[now].cnt)
{
x-=f[f[now].son[0]].size+f[now].cnt;
now=f[now].son[1];
}
else
{
return now;
}
}
}

前驱

这个就没什么好讲的了,直接查找即可。后驱同理

1
2
3
4
5
6
7
8
9
10
11
int pre( int x)
{
find(x);
if(f[rt].val<x)return rt;
int now=f[rt].son[0];
while(f[now].son[1])
{
now=f[now].son[1];
}
return now;
}

后驱

1
2
3
4
5
6
7
8
9
10
11
int next(int x)
{
find(x);
if(f[rt].val>x)return rt;
int now=f[rt].son[1];
while(f[now].son[0])
{
now=f[now].son[0];
}
return now;
}

删除

若存在这个点,就直接将计数器减一,若只有一个,暴力,改变前驱后继,并Splay维护BST性质

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
void del(int x)
{
int last=pre(x),nxt=next(x);
splay(last,0);
splay(nxt,last);
int fk=f[nxt].son[0];
if(f[fk].cnt>1)
{
f[fk].cnt--;
splay(fk);
}
else
{
f[nxt].son[0]=0;
push_up(nxt);
push_up(rt);
}
}

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#include<bits/stdc++.h>
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
using namespace std;

inline void read(int &x)
{
x=0;
char ch=getchar();
int pd=1;
while(ch<'0'||ch>'9')
{
if(ch=='-')
{
pd=-pd;
}
ch=getchar();
}
while(ch>='0'&&ch<='9')
{
x=x*10+ch-'0';
ch=getchar();
}
x*=pd;
}
inline void write(const int &x)
{
char ggg[10001];
int s=0;
int tmp=x;
if(tmp==0)
{
putchar('0');
return;
}
if(tmp<0)
{
tmp=-tmp;
putchar('-');
}
while(tmp>0)
{
ggg[s++]=tmp%10+'0';
tmp/=10;
}
while(s>0)
{
putchar(ggg[--s]);
}
}
int rt,n;
int cc;
struct tree
{
int son[2],val,cnt,fa,size;
}f[1000010];

inline int get(int x)
{
return f[f[x].fa].son[1]==x;
}

void push_up(int x)
{
f[x].size=(f[f[x].son[0]].size+f[f[x].son[1]].size+f[x].cnt);
}

inline void rotate(int x)
{
int fat=f[x].fa;
int gfa=f[fat].fa;
int now=get(x);
int ss=f[x].son[now^1];
f[fat].son[now]=ss;
f[ss].fa=fat;
f[gfa].son[get(fat)]=x;
f[x].fa=gfa;
f[x].son[now^1]=fat;
f[fat].fa=x;
push_up(fat);
push_up(x);
}

inline void splay(int x,int go=0)
{
while(f[x].fa!=go)
{
int fat=f[x].fa;
int gfa=f[fat].fa;
if(gfa!=go)
{
if(get(x)==get(fat))
{
rotate(fat);
}
else
{
rotate(x);
}
}
rotate(x);
}
if(!go)
{

rt=x;
}
}

inline void insert(int x)
{
int now=rt,p=0;
while(now&&f[now].val!=x)
{
p=now;
now=f[now].son[x>f[now].val];
}
if(now)
{
f[now].cnt++;
}
else
{
now=++cc;
if(p)
{
f[p].son[x>f[p].val]=now;
f[now].son[0]=f[now].son[1]=0;
f[now].fa=p;
f[now].val=x;
f[now].cnt=f[now].size=1;
}
else
{
f[now].son[0]=f[now].son[1]=0;
f[now].fa=p;
f[now].val=x;
f[now].cnt=f[now].size=1;
}
}
splay(now,0);
}

void find(int x)
{
int now=rt;
while(f[now].son[x>f[now].val]&&x!=f[now].val)
{
now=f[now].son[x>f[now].val];
}
splay(now);
}

int k_th(int x)
{
int now=rt;
while(1)
{
if(f[now].son[0]&&x<=f[f[now].son[0]].size)
{
now=f[now].son[0];
}
else
if(x>f[f[now].son[0]].size+f[now].cnt)
{
x-=f[f[now].son[0]].size+f[now].cnt;
now=f[now].son[1];
}
else
{
return now;
}
}
}


int pre( int x)
{
find(x);
if(f[rt].val<x)return rt;
int now=f[rt].son[0];
while(f[now].son[1])
{
now=f[now].son[1];
}
return now;
}

int next(int x)
{
find(x);

if(f[rt].val>x)return rt;
int now=f[rt].son[1];
while(f[now].son[0])
{
now=f[now].son[0];
}
return now;
}

void del(int x)
{
int last=pre(x),nxt=next(x);
splay(last,0);
splay(nxt,last);
int fk=f[nxt].son[0];
if(f[fk].cnt>1)
{
f[fk].cnt--;
splay(fk);
}
else
{
f[nxt].son[0]=0;
push_up(nxt);
push_up(rt);
}
}

int main()
{
read(n);
insert(0x3f3f3f3f);
insert(-100000000);
for(register int i=1;i<=n;++i)
{
int aa;
read(aa);
int x;
read(x);
if(aa==1)
{
insert(x);
}
if(aa==2)
{
del(x);
}
if(aa==3)
{
find(x);
int ans=f[f[rt].son[0]].size;
write(ans);
puts("");
}
if(aa==4)
{
int ans=f[k_th(x+1)].val;
write(ans);
puts("");
}
if(aa==5)
{
int ans=f[pre(x)].val;
write(ans);
puts("");
}
if(aa==6)
{
int ans=f[next(x)].val;
write(ans);
puts("");
}
}
}