This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub toyama1710/cpp_library
#include "bbst/persistent_avl_set.hpp"
#ifndef PERSISTENT_AVL_SET_HPP #define PERSISTENT_AVL_SET_HPP #include "avl_set.hpp" // insert/erase base AVLtree // multiset // merge,split are not implemented template <class T> struct PersistentAVLSet : public AVLSet<T> { using Set = AVLSet<T>; using Node = typename AVLSet<T>::Node; PersistentAVLSet(Node *root = nullptr) : Set(root){}; static Node *copy(Node *u) { if (u == nullptr) return nullptr; if (Set::balance_factor(u) == 2) { u->ch[0] = new Node(u->ch[0]); if (Set::balance_factor(u->ch[0]) == -1) u->ch[0]->ch[1] = new Node(u->ch[0]->ch[1]); } else if (Set::balance_factor(u) == -2) { u->ch[1] = new Node(u->ch[1]); if (Set::balance_factor(u->ch[1]) == 1) u->ch[1]->ch[0] = new Node(u->ch[1]->ch[0]); } return u; }; PersistentAVLSet insert(const T &dat) const { Node *nv = new Node(dat); return PersistentAVLSet(insert(this->root, nv)); }; Node *insert(Node *u, Node *nv) const { if (u == nullptr) return nv; u = new Node(u); if (u->dat < nv->dat) u->ch[1] = insert(u->ch[1], nv); else u->ch[0] = insert(u->ch[0], nv); return Set::balance(Set::recalc(u)); }; PersistentAVLSet erase(const T &dat) const { return PersistentAVLSet(erase(this->root, dat)); }; Node *erase(Node *u, const T &dat) const { if (u == nullptr) return nullptr; u = new Node(u); if (u->dat < dat) { u->ch[1] = erase(u->ch[1], dat); } else if (dat < u->dat) { u->ch[0] = erase(u->ch[0], dat); } else { u = isolate_node(u); } return Set::balance(copy(Set::recalc(u))); }; Node *isolate_node(Node *u) const { if (u->ch[0] == nullptr || u->ch[1] == nullptr) { return u->ch[0] != nullptr ? 
u->ch[0] : u->ch[1]; } else { auto [l, nv] = split_rightest_node(u->ch[0]); nv = new Node(nv); nv->ch[0] = l; nv->ch[1] = u->ch[1]; return Set::balance(copy(Set::recalc(nv))); } }; std::pair<Node *, Node *> split_rightest_node(Node *v) const { if (v->ch[1] != nullptr) { v = new Node(v); auto [l, ret] = split_rightest_node(v->ch[1]); v->ch[1] = l; return {Set::balance(copy(Set::recalc(v))), ret}; } else { return {isolate_node(v), v}; } }; }; #endif
#line 1 "bbst/persistent_avl_set.hpp"
#line 1 "bbst/avl_set.hpp"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// insert/erase base AVLtree
// multiset
//
// Ordered multiset stored as an AVL tree. Elements are compared with
// operator< only. Every node caches the size and height of its subtree.
//
// Copying an AVLSet copies only the root pointer: both objects then refer to
// the same nodes. There is no destructor; nodes are released only by erase().
template <class T>
struct AVLSet {
    struct Node {
        int sz, hi;   // subtree size / subtree height (a leaf has hi == 1)
        T dat;        // stored value
        Node *ch[2];  // ch[0]: left child, ch[1]: right child

        // Clones `x`: same value, same cached fields, same child pointers.
        Node(const Node *x)
            : sz(x->sz), hi(x->hi), dat(x->dat), ch{x->ch[0], x->ch[1]} {}
        // Creates a leaf holding `dat`.
        Node(T dat) : sz(1), hi(1), dat(dat), ch{nullptr, nullptr} {}
    };

    Node *root;

    AVLSet(Node *r = nullptr) : root(r) {}
    AVLSet(const AVLSet &x) : root(x.root) {}
    AVLSet &operator=(const AVLSet &x) {
        root = x.root;
        return *this;
    }

    // Number of stored elements (duplicates counted).
    int size() const { return size(root); }
    // Size of the subtree rooted at `u`; 0 for nullptr.
    static int size(Node *u) { return u != nullptr ? u->sz : 0; }
    // Height of the subtree rooted at `u`; 0 for nullptr.
    static int height(Node *u) { return u != nullptr ? u->hi : 0; }

    // Lifts child ch[d] of `u` into u's place and returns it.
    template <int d>  // 0: left, 1: right
    static Node *rotate(Node *u) {
        assert(u != nullptr && u->ch[d] != nullptr);
        Node *v = u->ch[d];
        u->ch[d] = v->ch[d ^ 1];
        v->ch[d ^ 1] = u;
        recalc(u);  // u is now the child, so it must be refreshed first
        recalc(v);
        return v;
    }

    // height(left) - height(right); 0 for nullptr.
    static int balance_factor(Node *u) {
        if (u == nullptr) return 0;
        return height(u->ch[0]) - height(u->ch[1]);
    }

    // Restores the AVL condition at `u` with a single or double rotation.
    // Requires the balance factor of `u` to be within [-2, 2].
    // Returns the root of the rebalanced subtree.
    static Node *balance(Node *u) {
        if (u == nullptr) return nullptr;
        assert(-2 <= balance_factor(u) && balance_factor(u) <= 2);

        if (balance_factor(u) == 2) {
            if (balance_factor(u->ch[0]) == -1) u->ch[0] = rotate<1>(u->ch[0]);
            u = rotate<0>(u);
        } else if (balance_factor(u) == -2) {
            if (balance_factor(u->ch[1]) == 1) u->ch[1] = rotate<0>(u->ch[1]);
            u = rotate<1>(u);
        }
        return u;
    }

    // Recomputes the cached size and height of `u` from its children.
    static Node *recalc(Node *u) {
        if (u == nullptr) return nullptr;
        u->sz = size(u->ch[0]) + size(u->ch[1]) + 1;
        u->hi = std::max(height(u->ch[0]), height(u->ch[1])) + 1;
        return u;
    }

    // Adds `dat` (duplicates are kept). Returns *this.
    AVLSet &insert(const T &dat) {
        Node *u = new Node(dat);
        root = insert(root, u);
        return *this;
    }

    // Inserts node `nv` into the subtree rooted at `u`; returns the new
    // subtree root. A value equal to u->dat is placed in the left subtree.
    Node *insert(Node *u, Node *nv) {
        if (u == nullptr) return nv;
        if (u->dat < nv->dat) {
            u->ch[1] = insert(u->ch[1], nv);
        } else {
            u->ch[0] = insert(u->ch[0], nv);
        }
        return balance(recalc(u));
    }

    // Removes one occurrence of `dat`, if any. Returns *this.
    AVLSet &erase(const T &dat) {
        root = erase(root, dat);
        return *this;
    }

    // Removes and deletes one node equal to `dat` from the subtree rooted at
    // `u`; returns the new subtree root.
    Node *erase(Node *u, const T &dat) {
        if (u == nullptr) return nullptr;
        if (u->dat < dat) {
            u->ch[1] = erase(u->ch[1], dat);
        } else if (dat < u->dat) {
            u->ch[0] = erase(u->ch[0], dat);
        } else {
            Node *del = u;
            u = isolate_node(u);  // reads del's children, so delete afterwards
            delete del;
        }
        return balance(recalc(u));
    }

    // Returns the subtree that remains once `u` is taken out. `u` itself is
    // neither modified nor deleted. With two children, the right-most node
    // of the left subtree becomes the new subtree root.
    Node *isolate_node(Node *u) {
        if (u->ch[0] == nullptr || u->ch[1] == nullptr) {
            return u->ch[0] != nullptr ? u->ch[0] : u->ch[1];
        }
        auto [l, nv] = split_rightest_node(u->ch[0]);
        nv->ch[0] = l;
        nv->ch[1] = u->ch[1];
        return balance(recalc(nv));
    }

    // Detaches the right-most node of the subtree rooted at `v`.
    // Returns {root of the remaining subtree, detached node}. The detached
    // node's child pointers and cached fields are left stale.
    std::pair<Node *, Node *> split_rightest_node(Node *v) {
        if (v->ch[1] == nullptr) {
            return {isolate_node(v), v};
        }
        auto [l, ret] = split_rightest_node(v->ch[1]);
        v->ch[1] = l;
        return {balance(recalc(v)), ret};
    }

    // True when an element equivalent to `dat` is stored.
    bool contains(const T &dat) const {
        Node *u = root;
        while (u != nullptr) {
            if (dat < u->dat) {
                u = u->ch[0];
            } else if (u->dat < dat) {
                u = u->ch[1];
            } else {
                return true;
            }
        }
        return false;
    }

    // Smallest element that is not less than `x`, or nullopt.
    std::optional<T> lower_bound(const T &x) const {
        return lower_bound(root, x);
    }
    std::optional<T> lower_bound(Node *u, const T &x) const {
        if (u == nullptr) return std::nullopt;
        if (u->dat < x) {
            return lower_bound(u->ch[1], x);
        }
        auto ret = lower_bound(u->ch[0], x);
        if (ret) return ret;
        return u->dat;
    }

    // Smallest element that is greater than `x`, or nullopt.
    std::optional<T> upper_bound(const T &x) const {
        return upper_bound(root, x);
    }
    std::optional<T> upper_bound(Node *u, const T &x) const {
        if (u == nullptr) return std::nullopt;
        if (x < u->dat) {
            auto ret = upper_bound(u->ch[0], x);
            if (ret) return ret;
            return u->dat;
        }
        return upper_bound(u->ch[1], x);
    }

    // 0-indexed
    // k-th smallest element, or nullopt when k is out of range.
    std::optional<T> find_Kth(int k) const {
        if (size() <= k || k < 0) return std::nullopt;
        return find_Kth(root, k)->dat;
    }
    // Node holding the k-th smallest element of the subtree rooted at `u`.
    // Requires 0 <= k < size(u); no range check is performed here.
    Node *find_Kth(Node *u, int k) const {
        if (size(u->ch[0]) == k) return u;
        if (size(u->ch[0]) > k) return find_Kth(u->ch[0], k);
        return find_Kth(u->ch[1], k - size(u->ch[0]) - 1);
    }

    // Number of elements equivalent to `x`.
    int count(const T &x) const {
        return size() - count_upper(x) - count_lower(x);
    }

    // Number of elements less than `x`.
    int count_lower(const T &x) const { return count_lower(x, root); }
    int count_lower(const T &x, Node *u) const {
        if (u == nullptr) return 0;
        if (u->dat < x) return count_lower(x, u->ch[1]) + size(u->ch[0]) + 1;
        return count_lower(x, u->ch[0]);
    }

    // Number of elements greater than `x`.
    int count_upper(const T &x) const { return count_upper(x, root); }
    int count_upper(const T &x, Node *u) const {
        if (u == nullptr) return 0;
        if (x < u->dat) return count_upper(x, u->ch[0]) + size(u->ch[1]) + 1;
        return count_upper(x, u->ch[1]);
    }

    // Appends all elements of `r` after the elements of this set and leaves
    // `r` empty. No element comparisons are performed: the caller must make
    // sure no element of `r` is less than an element of this set.
    // Merging a set with itself, or with an empty set, is a no-op.
    AVLSet &merge_with(AVLSet &r) {
        if (&r == this || r.size() == 0) return *this;

        if (size() == 0) {
            root = r.root;
        } else {
            auto [l, tmp] = split_rightest_node(root);
            root = merge(tmp, l, r.root);
        }
        // The nodes now belong to this set in every branch; leaving r.root
        // set would make both sets refer to (and later delete) the same nodes.
        r.root = nullptr;
        return *this;
    }

    // Joins subtree `l`, the single node `root` and subtree `r` (in that
    // order) into one balanced tree and returns its root. The child pointers
    // of `root` are overwritten.
    Node *merge(Node *root, Node *l, Node *r) {
        if (std::abs(height(l) - height(r)) <= 2) {
            root->ch[0] = l;
            root->ch[1] = r;
            return balance(recalc(root));
        }
        if (height(l) > height(r)) {
            l->ch[1] = merge(root, l->ch[1], r);
            return balance(recalc(l));
        }
        r->ch[0] = merge(root, l, r->ch[0]);
        return balance(recalc(r));
    }

    // Splits into {the k smallest elements, the rest}. This set becomes
    // empty. Requires 0 <= k <= size().
    std::pair<AVLSet, AVLSet> split(int k) {
        assert(k >= 0 && k <= size());
        auto [l, r] = split(root, k);
        root = nullptr;
        return {AVLSet(l), AVLSet(r)};
    }

    // Splits the subtree rooted at `u` into {first k nodes, remaining nodes}.
    std::pair<Node *, Node *> split(Node *u, int k) {
        if (u == nullptr) return {nullptr, nullptr};

        int lsize = size(u->ch[0]);
        Node *l = u->ch[0];
        Node *r = u->ch[1];
        u->ch[0] = u->ch[1] = nullptr;  // u is reused as a stand-alone node

        if (lsize == k) {
            return {l, insert(r, recalc(u))};
        }
        if (k < lsize) {
            auto [x, y] = split(l, k);
            return {x, merge(recalc(u), y, r)};
        }
        auto [x, y] = split(r, k - lsize - 1);
        return {merge(recalc(u), l, x), y};
    }

    // All elements in ascending order.
    std::vector<T> list() const {
        std::vector<T> ret;
        ret.reserve(size());
        auto dfs = [&](Node *u, auto &&f) -> void {
            if (u == nullptr) return;
            f(u->ch[0], f);
            ret.emplace_back(u->dat);
            f(u->ch[1], f);
        };
        dfs(root, dfs);
        return ret;
    }

    // Debug print to std::cout: one "(value, size, height)" line per node,
    // right subtree first, indented by depth.
    void dump() const {
        auto f = [](auto &&f, int d, Node *u) -> void {
            if (u == nullptr) return;
            f(f, d + 1, u->ch[1]);
            for (int i = 0; i < d; i++) {
                std::cout << " ";
            }
            std::cout << "(" << u->dat << ", " << u->sz << ", " << u->hi << ")"
                      << std::endl;
            f(f, d + 1, u->ch[0]);
        };
        f(f, 0, root);
    }
};
#line 5 "bbst/persistent_avl_set.hpp"
// insert/erase base AVLtree
// multiset
// merge,split are not implemented
//
// Persistent variant of AVLSet: insert() and erase() are const and return a
// new set instead of modifying this one. Every node on the search path is
// cloned before it is changed; untouched subtrees are shared between the old
// and the new version. Nodes are never deleted by this class.
// The inherited merge_with/split still modify nodes in place.
template <class T>
struct PersistentAVLSet : public AVLSet<T> {
    using Set = AVLSet<T>;
    using Node = typename AVLSet<T>::Node;

    // Wraps an existing tree; the default is the empty set.
    PersistentAVLSet(Node *root = nullptr) : Set(root) {}

    // Prepares `u` for Set::balance(). When u's balance factor is +2 / -2 the
    // upcoming rotation rewrites u's heavier child (and, in the double
    // rotation case, that child's inner child), so those nodes are replaced
    // by clones first. `u` itself must already be safe to modify.
    // Returns `u` (or nullptr when u is nullptr).
    static Node *copy(Node *u) {
        if (u == nullptr) return nullptr;

        if (Set::balance_factor(u) == 2) {
            u->ch[0] = new Node(u->ch[0]);
            if (Set::balance_factor(u->ch[0]) == -1) {
                u->ch[0]->ch[1] = new Node(u->ch[0]->ch[1]);
            }
        } else if (Set::balance_factor(u) == -2) {
            u->ch[1] = new Node(u->ch[1]);
            if (Set::balance_factor(u->ch[1]) == 1) {
                u->ch[1]->ch[0] = new Node(u->ch[1]->ch[0]);
            }
        }
        return u;
    }

    // Returns a new set that additionally contains `dat`
    // (duplicates are kept). This set is left unchanged.
    PersistentAVLSet insert(const T &dat) const {
        Node *nv = new Node(dat);
        return PersistentAVLSet(insert(this->root, nv));
    }

    // Inserts node `nv` into the subtree rooted at `u`, cloning every node
    // on the search path. Returns the root of the new subtree.
    Node *insert(Node *u, Node *nv) const {
        if (u == nullptr) return nv;

        u = new Node(u);
        if (u->dat < nv->dat) {
            u->ch[1] = insert(u->ch[1], nv);
        } else {
            u->ch[0] = insert(u->ch[0], nv);
        }
        return Set::balance(Set::recalc(u));
    }

    // Returns a new set with one occurrence of `dat` removed.
    // This set is left unchanged.
    PersistentAVLSet erase(const T &dat) const {
        return PersistentAVLSet(erase(this->root, dat));
    }

    // Removes one node equal to `dat` from the subtree rooted at `u` and
    // returns the root of the new subtree (nullptr when it becomes empty).
    Node *erase(Node *u, const T &dat) const {
        if (u == nullptr) return nullptr;

        const bool go_right = u->dat < dat;
        const bool go_left = !go_right && dat < u->dat;

        if (!go_right && !go_left) {
            // Match: isolate_node() only reads `u` and already returns a
            // balanced subtree, so no clone of `u` is needed. Cloning first
            // would allocate a node that no version can ever reach.
            return isolate_node(u);
        }

        u = new Node(u);
        if (go_right) {
            u->ch[1] = erase(u->ch[1], dat);
        } else {
            u->ch[0] = erase(u->ch[0], dat);
        }
        return Set::balance(copy(Set::recalc(u)));
    }

    // Builds the subtree that remains after removing node `u`, without
    // modifying `u`. With at most one child, that child (possibly nullptr)
    // is returned as is. With two children, a clone of the right-most node
    // of the left subtree becomes the new subtree root.
    Node *isolate_node(Node *u) const {
        if (u->ch[0] == nullptr || u->ch[1] == nullptr) {
            return u->ch[0] != nullptr ? u->ch[0] : u->ch[1];
        }

        auto [rest, rightmost] = split_rightest_node(u->ch[0]);
        Node *nv = new Node(rightmost);
        nv->ch[0] = rest;
        nv->ch[1] = u->ch[1];
        return Set::balance(copy(Set::recalc(nv)));
    }

    // Detaches the right-most node of the subtree rooted at `v`.
    // Returns {root of the remaining subtree, the right-most node}. Nodes on
    // the right spine are cloned; the returned right-most node is the
    // original (uncloned) node and must not be modified by the caller.
    std::pair<Node *, Node *> split_rightest_node(Node *v) const {
        if (v->ch[1] == nullptr) {
            return {isolate_node(v), v};
        }

        v = new Node(v);
        auto [rest, rightmost] = split_rightest_node(v->ch[1]);
        v->ch[1] = rest;
        return {Set::balance(copy(Set::recalc(v))), rightmost};
    }
};