maomao90's Library
A C++20 library for competitive programming.
Loading...
Searching...
No Matches
splaytree.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <cassert>
4#include <vector>
5
7
8// Modified from https://judge.yosupo.jp/submission/144167,
9// https://judge.yosupo.jp/submission/136748,
10// https://judge.yosupo.jp/submission/278235
11
12namespace maomao90 {
13using namespace std;
14
16template <Monoid T, Lazy<T> L, bool store_reverse> struct Node {
17 Node *l, *r;
18 int sz;
19 bool rev;
20 T val, sum;
21 L lz;
22 Node(T val = T::id())
23 : l(nullptr), r(nullptr), sz(1), rev(false), val(val), sum(val),
24 lz(L::id()) {}
25};
26
27template <Monoid T, Lazy<T> L> struct Node<T, L, true> {
28 Node *l, *r;
29 int sz;
30 bool rev;
32 L lz;
33 Node(T val = T::id())
34 : l(nullptr), r(nullptr), sz(1), rev(false), val(val), sum(val),
35 rev_sum(val), lz(L::id()) {}
36};
37} // namespace internal::splaytree
38
54template <Monoid T, Lazy<T> L, bool store_reverse = false> struct SplayTree {
55private:
56 using splaytree = SplayTree<T, L, store_reverse>;
58 node *root;
59
60 int size(node *v) { return !v ? 0 : v->sz; }
61 void update(node *v) {
62 v->sz = 1;
63 v->sum = v->val;
64 if constexpr (store_reverse) {
65 v->rev_sum = v->val;
66 }
67 if (v->l) {
68 v->sz += v->l->sz;
69 v->sum = v->l->sum.merge(v->sum);
70 if constexpr (store_reverse) {
71 v->rev_sum = v->rev_sum.merge(v->l->rev_sum);
72 }
73 }
74 if (v->r) {
75 v->sz += v->r->sz;
76 v->sum = v->sum.merge(v->r->sum);
77 if constexpr (store_reverse) {
78 v->rev_sum = v->r->rev_sum.merge(v->rev_sum);
79 }
80 }
81 }
82 void push_down(node *v) {
83 if (!v) {
84 return;
85 }
86 if (v->l) {
87 propagate(v->l, v->lz);
88 }
89 if (v->r) {
90 propagate(v->r, v->lz);
91 }
92 v->lz = L::id();
93 if (v->rev) {
94 if (v->l) {
95 reverse(v->l);
96 }
97 if (v->r) {
98 reverse(v->r);
99 }
100 v->rev = false;
101 }
102 }
103 void propagate(node *v, L x) {
104 v->lz = x.merge(v->lz);
105 v->val = x.apply(v->val, 1);
106 v->sum = x.apply(v->sum, v->sz);
107 if constexpr (store_reverse) {
108 v->rev_sum = x.apply(v->rev_sum, v->sz);
109 }
110 }
111 void reverse(node *v) {
112 swap(v->l, v->r);
113 if constexpr (store_reverse) {
114 swap(v->sum, v->rev_sum);
115 }
116 v->rev ^= 1;
117 }
118 node *rotate_right(node *v) {
119 node *l = v->l;
120 v->l = l->r;
121 l->r = v;
122 update(v);
123 update(l);
124 return l;
125 }
126 node *rotate_left(node *v) {
127 node *r = v->r;
128 v->r = r->l;
129 r->l = v;
130 update(v);
131 update(r);
132 return r;
133 }
134 node *splay_top_down(node *v, int k) {
135 push_down(v);
136 int szl = v->l ? v->l->sz : 0;
137 if (k == szl) {
138 return v;
139 }
140 if (k < szl) {
141 push_down(v->l);
142 int szll = size(v->l->l);
143 if (k == szll) {
144 v = rotate_right(v);
145 } else if (k < szll) {
146 v->l->l = splay_top_down(v->l->l, k);
147 v = rotate_right(v);
148 v = rotate_right(v);
149 } else {
150 v->l->r = splay_top_down(v->l->r, k - szll - 1);
151 v->l = rotate_left(v->l);
152 v = rotate_right(v);
153 }
154 } else {
155 push_down(v->r);
156 k -= szl + 1;
157 int szrl = size(v->r->l);
158 if (k == szrl) {
159 v = rotate_left(v);
160 } else if (k < szrl) {
161 v->r->l = splay_top_down(v->r->l, k);
162 v->r = rotate_right(v->r);
163 v = rotate_left(v);
164 } else {
165 v->r->r = splay_top_down(v->r->r, k - szrl - 1);
166 v = rotate_left(v);
167 v = rotate_left(v);
168 }
169 }
170 update(v);
171 return v;
172 }
173 node *merge_inner(node *l, node *r) {
174 if (!l || !r) {
175 return !l ? r : l;
176 }
177 r = splay_top_down(r, 0);
178 r->l = l;
179 update(r);
180 return r;
181 }
182 pair<node *, node *> split_inner(node *v, int k) {
183 int n = size(v);
184 if (k >= n) {
185 return {v, nullptr};
186 }
187 v = splay_top_down(v, k);
188 node *l = v->l;
189 v->l = nullptr;
190 update(v);
191 return {l, v};
192 }
193 tuple<node *, node *, node *> split3_inner(node *v, int l, int r) {
194 if (l == 0) {
195 auto [b, c] = split_inner(v, r);
196 return {nullptr, b, c};
197 }
198 v = splay_top_down(v, l - 1);
199 auto [b, c] = split_inner(v->r, r - l);
200 v->r = nullptr;
201 update(v);
202 return {v, b, c};
203 }
204 // Only can be used to merge if it was split using `split3_inner`. O(1)
205 // compared to general `merge3_inner`.
206 node *inv_split3_inner(node *a, node *b, node *c) {
207 node *v = merge_inner(b, c);
208 if (!a) {
209 return v;
210 }
211 a->r = v;
212 update(a);
213 return a;
214 }
215 node *merge3_inner(node *a, node *b, node *c) {
216 node *v = merge_inner(b, c);
217 return merge_inner(a, v);
218 }
219 node *set_inner(node *v, int k, T x) {
220 v = splay_top_down(v, k);
221 v->val = x;
222 update(v);
223 return v;
224 }
225 node *get_inner(node *v, int k, T &x) {
226 v = splay_top_down(v, k);
227 x = v->val;
228 return v;
229 }
230 node *update_inner(node *v, int l, int r, L x) {
231 if (r == l) {
232 return v;
233 }
234 auto [a, b, c] = split3_inner(v, l, r);
235 propagate(b, x);
236 return inv_split3_inner(a, b, c);
237 }
238 node *query_inner(node *v, int l, int r, T &res) {
239 if (r == l) {
240 return v;
241 }
242 auto [a, b, c] = split3_inner(v, l, r);
243 res = b->sum;
244 return inv_split3_inner(a, b, c);
245 }
246 template <typename P>
247 node *max_right_inner(node *v, int l, P pred, int &res) {
248 res = l;
249 if (l == size(v)) {
250 return v;
251 }
252 v = splay_top_down(v, l);
253 if (!pred(v->val)) {
254 return v;
255 }
256 res++;
257 push_down(v);
258 v->r = max_right_inner(v->r, v->val, pred, res);
259 if (v->r) {
260 v = rotate_left(v);
261 }
262 update(v);
263 return v;
264 }
265 template <typename P>
266 node *max_right_inner(node *v, T sum, P pred, int &res) {
267 if (!v) {
268 return v;
269 }
270 push_down(v);
271 T lsum = sum;
272 if (v->l) {
273 lsum = lsum.merge(v->l->sum);
274 }
275 lsum = lsum.merge(v->val);
276 if (pred(lsum)) {
277 int szl = v->l ? v->l->sz : 0;
278 res += szl + 1;
279 v->r = max_right_inner(v->r, lsum, pred, res);
280 if (v->r) {
281 v = rotate_left(v);
282 }
283 } else {
284 v->l = max_right_inner(v->l, sum, pred, res);
285 if (v->l) {
286 v = rotate_right(v);
287 }
288 }
289 update(v);
290 return v;
291 }
292 template <typename P> node *min_left_inner(node *v, int r, P pred, int &res) {
293 res = r;
294 if (r == 0) {
295 return v;
296 }
297 v = splay_top_down(v, r - 1);
298 if (!pred(v->val)) {
299 return v;
300 }
301 res--;
302 push_down(v);
303 v->l = min_left_inner(v->l, v->val, pred, res);
304 if (v->l) {
305 v = rotate_right(v);
306 }
307 update(v);
308 return v;
309 }
310 template <typename P> node *min_left_inner(node *v, T sum, P pred, int &res) {
311 if (!v) {
312 return v;
313 }
314 push_down(v);
315 T rsum = sum;
316 if (v->r) {
317 rsum = v->r->sum.merge(rsum);
318 }
319 rsum = v->val.merge(rsum);
320 if (pred(rsum)) {
321 int szr = v->r ? v->r->sz : 0;
322 res -= szr + 1;
323 v->l = min_left_inner(v->l, rsum, pred, res);
324 if (v->l) {
325 v = rotate_right(v);
326 }
327 } else {
328 v->r = min_left_inner(v->r, sum, pred, res);
329 if (v->r) {
330 v = rotate_left(v);
331 }
332 }
333 update(v);
334 return v;
335 }
336 node *reverse_inner(node *v, int l, int r) {
337 if (r == l) {
338 return v;
339 }
340 auto [a, b, c] = split3_inner(v, l, r);
341 if (b) {
342 reverse(b);
343 }
344 return inv_split3_inner(a, b, c);
345 }
346 node *insert_inner(node *v, int k, node *u) {
347 if (k == size(v)) {
348 u->l = v;
349 update(u);
350 return u;
351 }
352 if (k == 0) {
353 u->r = v;
354 update(u);
355 return u;
356 }
357 v = splay_top_down(v, k);
358 u->l = v->l;
359 v->l = u;
360 update(u);
361 update(v);
362 return v;
363 }
364 node *erase_inner(node *v, int k) {
365 v = splay_top_down(v, k);
366 return merge_inner(v->l, v->r);
367 }
368 node *build(const vector<T> &v, int l, int r) {
369 int m = (l + r) >> 1;
370 node *u = new node(v[m]);
371 if (m > l) {
372 u->l = build(v, l, m);
373 }
374 if (r > m + 1) {
375 u->r = build(v, m + 1, r);
376 }
377 update(u);
378 return u;
379 }
380 SplayTree(node *r) : root(r) {}
381
382public:
386 SplayTree() : root(nullptr) {}
393 explicit SplayTree(int n) : SplayTree(vector<T>(n, T::id())) {}
399 explicit SplayTree(const vector<T> &v) : root(nullptr) {
400 if (!v.empty()) {
401 root = build(v, 0, v.size());
402 }
403 }
404
410 int size() { return size(root); }
419 void set(int k, T x) {
420 assert(0 <= k && k < size());
421 root = set_inner(root, k, x);
422 }
423
431 T get(int k) {
432 assert(0 <= k && k < size());
433 T res = T::id();
434 root = get_inner(root, k, res);
435 return res;
436 }
437
446 void update(int l, int r, L x) {
447 assert(0 <= l && l <= r && r <= size());
448 root = update_inner(root, l, r, x);
449 }
450
459 T query(int l, int r) {
460 assert(0 <= l && l <= r && r <= size());
461 T res = T::id();
462 root = query_inner(root, l, r, res);
463 return res;
464 }
465
482 template <typename P> int max_right(int l, P pred) {
483 assert(0 <= l && l <= size());
484 int res = l;
485 root = max_right_inner(root, l, pred, res);
486 return res;
487 }
488
491 template <bool (*pred)(T)> int max_right(int l) {
492 return max_right(l, [](T x) { return pred(x); });
493 }
494
510 template <typename P> int min_left(int r, P pred) {
511 assert(0 <= r && r <= size());
512 int res = r;
513 root = min_left_inner(root, r, pred, res);
514 return res;
515 }
516
519 template <bool (*pred)(T)> int min_left(int r) {
520 return min_left(r, [](T x) { return pred(x); });
521 }
522
531 void reverse(int l, int r) {
532 assert(0 <= l && l <= r && r <= size());
533 root = reverse_inner(root, l, r);
534 }
535
546 void insert(int k, T x) {
547 assert(0 <= k && k <= size());
548 root = insert_inner(root, k, new node(x));
549 }
550
557 void erase(int k) {
558 assert(0 <= k && k < size());
559 root = erase_inner(root, k);
560 }
561
569 splaytree split(int k) {
570 assert(0 <= k && k <= size());
571 auto [a, b] = split_inner(root, k);
572 root = a;
573 return splaytree(b);
574 }
575
588 pair<splaytree, splaytree> split(int l, int r) {
589 assert(0 <= l && l <= r && r <= size());
590 auto [a, b, c] = split3_inner(root, l, r);
591 root = a;
592 return {splaytree(b), splaytree(c)};
593 }
594
602 void merge(splaytree &o) {
603 root = merge_inner(root, o.root);
604 o.root = nullptr;
605 }
606
615 void merge(splaytree &b, splaytree &c) {
616 root = merge3_inner(root, b.root, c.root);
617 b.root = c.root = nullptr;
618 }
619};
620} // namespace maomao90
Definition splaytree.hpp:15
Definition hashmap.hpp:8
int max_right(int l)
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition splaytree.hpp:491
int min_left(int r, P pred)
Finds the smallest x such that the predicate returns true for the left-associative fold on the half-open interval [x, r).
Definition splaytree.hpp:510
int min_left(int r)
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition splaytree.hpp:519
void merge(splaytree &b, splaytree &c)
Merge splay tree b to the right of this, then c to the right of this and b.
Definition splaytree.hpp:615
void reverse(int l, int r)
Reverses the half-open interval [l, r).
Definition splaytree.hpp:531
SplayTree(int n)
Initialises the splay tree with n elements, all equal to the identity element T::id().
Definition splaytree.hpp:393
splaytree split(int k)
Splits this into two parts, then, set this to be the left part and returns the right part.
Definition splaytree.hpp:569
SplayTree(const vector< T > &v)
Initialises the splay tree with values from vector v.
Definition splaytree.hpp:399
void merge(splaytree &o)
Merge splay tree o to the right of this.
Definition splaytree.hpp:602
void set(int k, T x)
Set the k-th index (0-indexed) to x.
Definition splaytree.hpp:419
SplayTree()
Initialises empty splay tree.
Definition splaytree.hpp:386
pair< splaytree, splaytree > split(int l, int r)
Splits this into three parts, then sets this to be the left part and returns the middle and right parts.
Definition splaytree.hpp:588
void insert(int k, T x)
Inserts x into the k-th index (0-indexed).
Definition splaytree.hpp:546
T query(int l, int r)
Query the half-open interval [l, r).
Definition splaytree.hpp:459
int max_right(int l, P pred)
Finds the largest x such that the predicate returns true for the left-associative fold on the half-open interval [l, x).
Definition splaytree.hpp:482
void update(int l, int r, L x)
Apply update x to the half-open interval [l, r).
Definition splaytree.hpp:446
T get(int k)
Get the value at the k-th index (0-indexed).
Definition splaytree.hpp:431
int size()
Gets the number of elements in the splay tree.
Definition splaytree.hpp:410
void erase(int k)
Erases the value at the k-th index (0-indexed).
Definition splaytree.hpp:557
Node * l
Definition splaytree.hpp:28
Node(T val=T::id())
Definition splaytree.hpp:33
Node * r
Definition splaytree.hpp:28
Definition splaytree.hpp:16
int sz
Definition splaytree.hpp:18
Node * l
Definition splaytree.hpp:17
Node * r
Definition splaytree.hpp:17
T sum
Definition splaytree.hpp:20
L lz
Definition splaytree.hpp:21
T val
Definition splaytree.hpp:20
Node(T val=T::id())
Definition splaytree.hpp:22
bool rev
Definition splaytree.hpp:19