cpp_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub toyama1710/cpp_library

✔ test/aoj/2270.test.cpp

Depends on:

- bbst/avl_set.hpp
- bbst/persistent_avl_set.hpp
- tree/doubling_tree.hpp

Code

#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2270"
#include <iostream>

#include "../../bbst/persistent_avl_set.hpp"
#include "../../tree/doubling_tree.hpp"

// rep(...) dispatches on its argument count via _overload, which expands to
// its fifth argument:
//   rep(i, N)          -> for (i64 i = 0; i < N; i += 1)
//   rep(i, a, b)       -> for (i64 i = a; i < b; i += 1)
//   rep(i, a, b, step) -> for (i64 i = a; i < b; i += step)
// Arguments are substituted unparenthesized, and b is re-evaluated on every
// iteration; i64 must be declared before the first use of rep.
#define _overload(_1, _2, _3, _4, name, ...) name
#define _rep1(Itr, N) _rep3(Itr, 0, N, 1)
#define _rep2(Itr, a, b) _rep3(Itr, a, b, 1)
#define _rep3(Itr, a, b, step) for (i64 Itr = a; Itr < b; Itr += step)
#define repeat(...) _overload(__VA_ARGS__, _rep3, _rep2, _rep1)(__VA_ARGS__)
#define rep(...) repeat(__VA_ARGS__)

// Iterator pair over a whole container, e.g. sort(ALL(v)).
#define ALL(X) begin(X), end(X)

using namespace std;
using i64 = long long;
using u64 = unsigned long long;

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    int n, q;
    cin >> n >> q;
    vector<int> x(n);
    for (auto &vs : x) cin >> vs;

    DoublingTreeBuilder buider(n);
    rep(i, n - 1) {
        i64 a, b;
        cin >> a >> b;
        --a, --b;
        buider.add_edge(a, b);
    }

    auto tr = buider.build();
    vector<PersistentAVLSet<int>> st(n);
    st[0] = st[0].insert(x[0]);
    auto dfs = [&](int u, auto &&f) {
        if (st[u].size() > 0) return;
        f(tr.climb(u, 1).value(), f);
        st[u] = st[tr.climb(u, 1).value()].insert(x[u]);
        return;
    };
    rep(i, 1, n) {
        //
        dfs(i, dfs);
    }

    auto count = [&](int u, int v, int r, int t) {
        return st[u].count_lower(t) + st[v].count_lower(t) -
               2 * st[r].count_lower(t) + (x[r] < t);
    };
    auto find = [&](int u, int v, int l) {
        int valid = 0;
        int invalid = 1 << 30;
        int lca = tr.lca(u, v);
        while (abs(valid - invalid) > 1) {
            i64 mid = (valid + invalid) / 2;
            if (count(u, v, lca, mid) < l)
                valid = mid;
            else
                invalid = mid;
        }
        return valid;
    };

    rep(_, q) {
        int u, v, l;
        cin >> u >> v >> l;
        --u, --v;

        cout << find(u, v, l) << '\n';
    }

    return 0;
}
#line 1 "test/aoj/2270.test.cpp"
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=2270"
#include <iostream>

#line 1 "bbst/persistent_avl_set.hpp"



#line 1 "bbst/avl_set.hpp"



#include <algorithm>
#include <cassert>
#line 7 "bbst/avl_set.hpp"
#include <optional>
#include <utility>
#include <vector>

// AVL-balanced binary search tree used as an ordered multiset.
// Elements are ordered with operator< only; equal elements may coexist
// (insert() places a new element to the left of an equal existing one).
// Memory: nodes are allocated with `new`; erase() frees the removed node, but
// there is no destructor, and copying an AVLSet copies only the root pointer,
// so copies share nodes (PersistentAVLSet relies on this).
template <class T>
struct AVLSet {
    struct Node {
        int sz, hi;   // size and height of the subtree rooted here (leaf: 1, 1)
        T dat;
        Node *ch[2];  // ch[0]: left child, ch[1]: right child
        // Shallow copy of *x; the new node shares x's children.
        Node(const Node *x)
            : sz(x->sz), hi(x->hi), dat(x->dat), ch{x->ch[0], x->ch[1]} {};
        Node(T dat) : sz(1), hi(1), dat(dat), ch{nullptr, nullptr} {};
    };

    Node *root;

    AVLSet(Node *r = nullptr) : root(r){};
    // Shallow copy: both sets refer to the same nodes afterwards.
    AVLSet(const AVLSet &x) : root(x.root){};

    AVLSet &operator=(const AVLSet &x) {
        root = x.root;
        return *this;
    };

    // Number of stored elements (duplicates counted).
    int size() const {
        return size(root);
    };

    static int size(Node *u) {
        return u != nullptr ? u->sz : 0;
    };

    static int height(Node *u) {
        return u != nullptr ? u->hi : 0;
    };

    // Lifts u->ch[d] into u's position and returns it:
    // d == 0 rotates right (left child rises), d == 1 rotates left.
    template <int d>
    static Node *rotate(Node *u) {
        assert(u != nullptr && u->ch[d] != nullptr);
        Node *v = u->ch[d];
        u->ch[d] = v->ch[d ^ 1];
        v->ch[d ^ 1] = u;
        recalc(u);  // u is now below v, so it must be refreshed first
        recalc(v);
        return v;
    };

    // height(left) - height(right); 0 for an empty subtree.
    static int balance_factor(Node *u) {
        if (u == nullptr) return 0;
        return height(u->ch[0]) - height(u->ch[1]);
    };

    // Restores the AVL condition at u, whose subtrees must already be valid
    // AVL trees with heights differing by at most 2. Uses a single rotation,
    // or a double rotation when the taller child leans the other way.
    // Returns the new subtree root.
    static Node *balance(Node *u) {
        if (u == nullptr) return nullptr;
        assert(-2 <= balance_factor(u) && balance_factor(u) <= 2);
        if (balance_factor(u) == 2) {
            if (balance_factor(u->ch[0]) == -1) u->ch[0] = rotate<1>(u->ch[0]);
            u = rotate<0>(u);
        } else if (balance_factor(u) == -2) {
            if (balance_factor(u->ch[1]) == 1) u->ch[1] = rotate<0>(u->ch[1]);
            u = rotate<1>(u);
        }
        return u;
    };

    // Recomputes u->sz and u->hi from its children; returns u.
    static Node *recalc(Node *u) {
        if (u == nullptr) return nullptr;
        u->sz = size(u->ch[0]) + size(u->ch[1]) + 1;
        u->hi = std::max(height(u->ch[0]), height(u->ch[1])) + 1;
        return u;
    };

    // Inserts one copy of dat. Returns *this for chaining.
    AVLSet &insert(const T &dat) {
        Node *u = new Node(dat);
        root = insert(root, u);
        return *this;
    };

    // Inserts node nv into subtree u; returns the new subtree root.
    Node *insert(Node *u, Node *nv) {
        if (u == nullptr) return nv;
        if (u->dat < nv->dat)
            u->ch[1] = insert(u->ch[1], nv);
        else
            u->ch[0] = insert(u->ch[0], nv);

        return balance(recalc(u));
    };

    // Removes one occurrence of dat, if present, and frees its node.
    // Returns *this for chaining.
    AVLSet &erase(const T &dat) {
        root = erase(root, dat);
        return *this;
    };

    // Removes one occurrence of dat from subtree u; returns the new root.
    Node *erase(Node *u, const T &dat) {
        if (u == nullptr) return nullptr;
        if (u->dat < dat) {
            u->ch[1] = erase(u->ch[1], dat);
        } else if (dat < u->dat) {
            u->ch[0] = erase(u->ch[0], dat);
        } else {
            Node *del = u;
            u = isolate_node(u);
            delete del;
        }
        return balance(recalc(u));
    };

    // Returns the subtree that replaces u once u is unlinked: its only child,
    // or, for two children, u's in-order predecessor re-parented over both
    // subtrees. u itself is neither modified nor freed.
    Node *isolate_node(Node *u) {
        if (u->ch[0] == nullptr || u->ch[1] == nullptr) {
            return u->ch[0] != nullptr ? u->ch[0] : u->ch[1];
        } else {
            auto [l, nv] = split_rightest_node(u->ch[0]);
            nv->ch[0] = l;
            nv->ch[1] = u->ch[1];
            return balance(recalc(nv));
        }
    };

    // Detaches the rightmost (maximum) node of the non-empty subtree v.
    // Returns {remaining subtree, detached node}; the detached node's child
    // pointers are stale and must be overwritten by the caller.
    std::pair<Node *, Node *> split_rightest_node(Node *v) {
        if (v->ch[1] != nullptr) {
            auto [l, ret] = split_rightest_node(v->ch[1]);
            v->ch[1] = l;
            return {balance(recalc(v)), ret};
        } else {
            return {isolate_node(v), v};
        }
    };

    // True if some element is equivalent to dat.
    bool contains(const T &dat) const {
        Node *u = root;
        while (u != nullptr) {
            if (dat < u->dat) {
                u = u->ch[0];
            } else if (u->dat < dat) {
                u = u->ch[1];
            } else {
                return true;
            }
        }
        return false;
    };

    // Smallest element >= x, or std::nullopt.
    std::optional<T> lower_bound(const T &x) const {
        return lower_bound(root, x);
    };

    std::optional<T> lower_bound(Node *u, const T &x) const {
        if (u == nullptr) return std::nullopt;
        if (u->dat < x) {
            return lower_bound(u->ch[1], x);
        } else {
            auto ret = lower_bound(u->ch[0], x);
            if (ret)
                return ret;
            else
                return u->dat;
        }
    };

    // Smallest element > x, or std::nullopt.
    std::optional<T> upper_bound(const T &x) const {
        return upper_bound(root, x);
    };

    std::optional<T> upper_bound(Node *u, const T &x) const {
        if (u == nullptr) return std::nullopt;
        if (x < u->dat) {
            auto ret = upper_bound(u->ch[0], x);
            if (ret)
                return ret;
            else
                return u->dat;
        } else {
            return upper_bound(u->ch[1], x);
        }
    };

    // k-th smallest element (0-indexed), or std::nullopt if k is out of range.
    std::optional<T> find_Kth(int k) const {
        if (size() <= k || k < 0)
            return std::nullopt;
        else
            return find_Kth(root, k)->dat;
    };

    // Requires 0 <= k < size(u).
    Node *find_Kth(Node *u, int k) const {
        if (size(u->ch[0]) == k)
            return u;
        else if (size(u->ch[0]) > k)
            return find_Kth(u->ch[0], k);
        else
            return find_Kth(u->ch[1], k - size(u->ch[0]) - 1);
    };

    // Number of elements equivalent to x.
    int count(const T &x) const {
        return size() - count_upper(x) - count_lower(x);
    };

    // Number of elements strictly less than x.
    int count_lower(const T &x) const {
        return count_lower(x, root);
    };

    int count_lower(const T &x, Node *u) const {
        if (u == nullptr) return 0;
        if (u->dat < x)
            return count_lower(x, u->ch[1]) + size(u->ch[0]) + 1;
        else
            return count_lower(x, u->ch[0]);
    };

    // Number of elements strictly greater than x.
    int count_upper(const T &x) const {
        return count_upper(x, root);
    };

    int count_upper(const T &x, Node *u) const {
        if (u == nullptr) return 0;
        if (x < u->dat)
            return count_upper(x, u->ch[0]) + size(u->ch[1]) + 1;
        else
            return count_upper(x, u->ch[1]);
    };

    // Moves every element of r to the end of this set and leaves r empty.
    // Precondition (not checked): every element of *this is <= every element
    // of r. Merging a non-empty set into itself is rejected by an assert.
    AVLSet &merge_with(AVLSet &r) {
        if (r.size() == 0) return *this;
        assert(&r != this);
        if (size() == 0) {
            root = r.root;
        } else {
            auto [l, tmp] = split_rightest_node(root);
            root = merge(tmp, l, r.root);
        }
        // r gives up its nodes on every path; otherwise both sets would
        // alias (and later mutate) the same tree.
        r.root = nullptr;
        return *this;
    };

    // Joins l, the single node `root`, and r into one AVL tree, assuming
    // every element of l <= root->dat <= every element of r.
    // root's child pointers are overwritten. Returns the new root.
    Node *merge(Node *root, Node *l, Node *r) {
        const int hl = height(l);
        const int hr = height(r);
        if (hl - hr <= 2 && hr - hl <= 2) {
            root->ch[0] = l;
            root->ch[1] = r;
            return balance(recalc(root));
        } else if (hl > hr) {
            l->ch[1] = merge(root, l->ch[1], r);
            return balance(recalc(l));
        } else {
            r->ch[0] = merge(root, l, r->ch[0]);
            return balance(recalc(r));
        }
    };

    // Splits into {first k elements in order, remaining elements};
    // this set becomes empty.
    std::pair<AVLSet, AVLSet> split(int k) {
        assert(k >= 0 && k <= size());
        auto [l, r] = split(root, k);
        root = nullptr;
        return {AVLSet(l), AVLSet(r)};
    };

    std::pair<Node *, Node *> split(Node *u, int k) {
        if (u == nullptr) return {nullptr, nullptr};
        int lsize = size(u->ch[0]);
        Node *l = u->ch[0];
        Node *r = u->ch[1];
        u->ch[0] = u->ch[1] = nullptr;
        if (lsize == k) {
            return {l, insert(r, recalc(u))};
        } else if (k < lsize) {
            auto [x, y] = split(l, k);
            return {x, merge(recalc(u), y, r)};
        } else {
            auto [x, y] = split(r, k - lsize - 1);
            return {merge(recalc(u), l, x), y};
        }
    };

    // All elements in ascending order.
    std::vector<T> list() const {
        std::vector<T> ret;
        ret.reserve(size());
        auto dfs = [&](Node *u, auto &&f) {
            if (u == nullptr) return;
            f(u->ch[0], f);
            ret.emplace_back(u->dat);
            f(u->ch[1], f);
        };
        dfs(root, dfs);
        return ret;
    };

    // Debug print of the tree rotated 90 degrees (right subtree on top);
    // each node is shown as (value, size, height).
    void dump() const {
        auto f = [](auto &&f, int d, Node *u) -> void {
            if (u == nullptr) return;
            f(f, d + 1, u->ch[1]);
            for (int i = 0; i < d; i++) {
                std::cout << "      ";
            }
            std::cout << "(" << u->dat << ", " << u->sz << ", " << u->hi << ")"
                      << std::endl;
            f(f, d + 1, u->ch[0]);
        };
        f(f, 0, root);
    };
};


#line 5 "bbst/persistent_avl_set.hpp"

// Persistent AVL multiset built on AVLSet<T>.
// insert() and erase() are const: they copy the nodes they would modify
// (path copying) and return a new set. Every earlier version therefore stays
// valid, and untouched subtrees are shared between versions. Nodes are never
// freed. merge/split have no persistent implementation.
// NOTE(review): AVLSet's mutating members that are not hidden here
// (merge_with, split) are still callable and would rewrite nodes shared with
// other versions; do not use them on a PersistentAVLSet.
template <class T>
struct PersistentAVLSet : public AVLSet<T> {
    using Set = AVLSet<T>;
    using Node = typename AVLSet<T>::Node;
    PersistentAVLSet(Node *root = nullptr) : Set(root){};

    // Called right before Set::balance(u) on the erase path. When u is
    // unbalanced (factor +-2), it clones the taller child. When a double
    // rotation will follow, it also clones that child's inner grandchild.
    // This way the rotations never write to nodes shared with another
    // version. Returns u.
    static Node *copy(Node *u) {
        if (u == nullptr) return nullptr;
        if (Set::balance_factor(u) == 2) {
            u->ch[0] = new Node(u->ch[0]);

            if (Set::balance_factor(u->ch[0]) == -1)
                u->ch[0]->ch[1] = new Node(u->ch[0]->ch[1]);
        } else if (Set::balance_factor(u) == -2) {
            u->ch[1] = new Node(u->ch[1]);

            if (Set::balance_factor(u->ch[1]) == 1)
                u->ch[1]->ch[0] = new Node(u->ch[1]->ch[0]);
        }
        return u;
    };

    // Returns a new version with one more copy of dat; *this is unchanged.
    PersistentAVLSet insert(const T &dat) const {
        Node *nv = new Node(dat);
        return PersistentAVLSet(insert(this->root, nv));
    };

    // Path-copying insert of node nv into subtree u; returns the new root.
    // No copy() is applied before balancing. This relies on the standard AVL
    // insertion property: the nodes a rotation rewrites all lie on the
    // insertion path, which was just copied.
    Node *insert(Node *u, Node *nv) const {
        if (u == nullptr) return nv;
        u = new Node(u);
        if (u->dat < nv->dat)
            u->ch[1] = insert(u->ch[1], nv);
        else
            u->ch[0] = insert(u->ch[0], nv);

        return Set::balance(Set::recalc(u));
    };

    // Returns a new version with one occurrence of dat removed (if present);
    // *this is unchanged.
    PersistentAVLSet erase(const T &dat) const {
        return PersistentAVLSet(erase(this->root, dat));
    };

    // Path-copying erase from subtree u; returns the new subtree root.
    Node *erase(Node *u, const T &dat) const {
        if (u == nullptr) return nullptr;
        if (u->dat < dat || dat < u->dat) {
            u = new Node(u);
            if (u->dat < dat)
                u->ch[1] = erase(u->ch[1], dat);
            else
                u->ch[0] = erase(u->ch[0], dat);
        } else {
            // Matched node: isolate_node() only reads u, so no copy of it is
            // allocated (a copy here would be discarded immediately).
            u = isolate_node(u);
        }
        return Set::balance(copy(Set::recalc(u)));
    };

    // Returns the subtree replacing u once u is removed: its only child, or a
    // fresh copy of u's in-order predecessor placed over both subtrees.
    // u itself is not modified.
    Node *isolate_node(Node *u) const {
        if (u->ch[0] == nullptr || u->ch[1] == nullptr) {
            return u->ch[0] != nullptr ? u->ch[0] : u->ch[1];
        } else {
            auto [l, nv] = split_rightest_node(u->ch[0]);
            nv = new Node(nv);
            nv->ch[0] = l;
            nv->ch[1] = u->ch[1];
            return Set::balance(copy(Set::recalc(nv)));
        }
    };

    // Path-copying removal of the rightmost node of the non-empty subtree v.
    // Returns {remaining subtree, detached node}. The detached node is still
    // shared with older versions, so callers must copy it before editing it.
    std::pair<Node *, Node *> split_rightest_node(Node *v) const {
        if (v->ch[1] != nullptr) {
            v = new Node(v);
            auto [l, ret] = split_rightest_node(v->ch[1]);
            v->ch[1] = l;
            return {Set::balance(copy(Set::recalc(v))), ret};
        } else {
            return {isolate_node(v), v};
        }
    };
};


#line 1 "tree/doubling_tree.hpp"


#include <iterator>
#line 6 "tree/doubling_tree.hpp"

// Binary-lifting ("doubling") table over a rooted forest; vertices are
// 0-indexed.
//   climb(u, d)    : ancestor d steps above u, or std::nullopt when d < 0 or
//                    d > depth[u]
//   lca(u, v)      : lowest common ancestor of u and v (same tree required)
//   distance(u, v) : number of edges on the u-v path
struct DoublingTree {
    // parent[k][u]: the 2^k-th ancestor of u, or std::nullopt if none.
    std::vector<std::vector<std::optional<int>>> parent;
    // depth[u]: edges between u and its root; -1 if u cannot be reached from
    // any root (e.g. it lies on or below a parent cycle).
    std::vector<int> depth;
    // Number of rows in `parent`: the smallest value >= 1 with 2^logn >= n.
    int logn;

    // [first, last) yields, for each vertex in order, its parent as an object
    // with has_value()/value() (e.g. std::optional<int>); empty marks a root.
    // The range is traversed twice, so a multi-pass iterator is required.
    template <class InputItr>
    DoublingTree(InputItr first, InputItr last) {
        const int n = static_cast<int>(std::distance(first, last));
        std::vector<std::optional<int>> p(n, std::nullopt);
        int i = 0;
        for (auto itr = first; itr != last; ++itr, ++i) {
            if (itr->has_value()) p[i] = static_cast<int>(itr->value());
        }
        build(p);
    };

    // (Re)builds depth and the ancestor table from p, where p[u] is the
    // parent of u or std::nullopt for a root. Depths are computed with an
    // explicit stack so deep (path-like) trees cannot overflow the call stack.
    void build(const std::vector<std::optional<int>> &p) {
        const int n = static_cast<int>(p.size());
        logn = 1;
        while ((1 << logn) < n) logn++;

        parent.assign(logn, std::vector<std::optional<int>>(n, std::nullopt));
        parent[0] = p;

        std::vector<std::vector<int>> children(n);
        std::vector<int> stack;  // seeded with the roots
        for (int u = 0; u < n; u++) {
            if (p[u].has_value())
                children[p[u].value()].push_back(u);
            else
                stack.push_back(u);
        }

        depth.assign(n, -1);
        for (int r : stack) depth[r] = 0;
        while (!stack.empty()) {
            const int u = stack.back();
            stack.pop_back();
            for (int v : children[u]) {
                depth[v] = depth[u] + 1;
                stack.push_back(v);
            }
        }

        for (int k = 1; k < logn; k++) {
            for (int u = 0; u < n; u++) {
                if (parent[k - 1][u].has_value())
                    parent[k][u] = parent[k - 1][parent[k - 1][u].value()];
            }
        }
    };

    // Ancestor of u that is d steps toward the root, or std::nullopt if
    // d is negative or exceeds depth[u].
    std::optional<int> climb(int u, int d) const {
        if (d < 0 || d > depth[u]) return std::nullopt;
        for (int k = 0; d > 0; d >>= 1, k++) {
            if (d & 1) u = parent[k][u].value();
        }
        return u;
    };

    // Lowest common ancestor of u and v. u and v must be in the same tree;
    // otherwise the final .value() is reached on an empty optional and
    // throws std::bad_optional_access.
    int lca(int u, int v) const {
        if (depth[u] > depth[v]) std::swap(u, v);
        v = climb(v, depth[v] - depth[u]).value();

        if (u == v) return u;
        for (int k = logn - 1; k >= 0; k--) {
            if (parent[k][u] != parent[k][v]) {
                u = parent[k][u].value();
                v = parent[k][v].value();
            }
        }
        return parent[0][u].value();
    };

    // Number of edges on the path between u and v (same tree required).
    int distance(int u, int v) const {
        return depth[u] + depth[v] - depth[lca(u, v)] * 2;
    };
};

// Collects the undirected edges of a tree/forest on vertices 0..n-1 and
// turns them into a DoublingTree rooted at the given vertices.
struct DoublingTreeBuilder {
    std::vector<std::vector<int>> g;  // adjacency lists
    DoublingTreeBuilder(int n) : g(n){};
    void add_edge(int a, int b) {
        g[a].push_back(b);
        g[b].push_back(a);
    };

    // Orients every component that contains a vertex of `root` away from
    // that root and builds the DoublingTree. The traversal is iterative, so
    // deep trees cannot overflow the call stack. Roots that are out of range
    // (e.g. the default {0} on an empty graph) are skipped. So are roots
    // already reached from an earlier root. Vertices that no root reaches
    // become roots of their own (parent == std::nullopt).
    DoublingTree build(const std::vector<int> &root = {0}) {
        const int n = static_cast<int>(g.size());
        std::vector<std::optional<int>> parent(n, std::nullopt);
        std::vector<char> seen(n, 0);
        std::vector<int> stack;
        for (int r : root) {
            if (r < 0 || r >= n || seen[r]) continue;
            seen[r] = 1;
            stack.push_back(r);
            while (!stack.empty()) {
                const int u = stack.back();
                stack.pop_back();
                for (int v : g[u]) {
                    if (seen[v]) continue;
                    seen[v] = 1;
                    parent[v] = u;
                    stack.push_back(v);
                }
            }
        }
        return DoublingTree(parent.begin(), parent.end());
    };
};


#line 6 "test/aoj/2270.test.cpp"

// rep(...) dispatches on its argument count via _overload, which expands to
// its fifth argument:
//   rep(i, N)          -> for (i64 i = 0; i < N; i += 1)
//   rep(i, a, b)       -> for (i64 i = a; i < b; i += 1)
//   rep(i, a, b, step) -> for (i64 i = a; i < b; i += step)
// Arguments are substituted unparenthesized, and b is re-evaluated on every
// iteration; i64 must be declared before the first use of rep.
#define _overload(_1, _2, _3, _4, name, ...) name
#define _rep1(Itr, N) _rep3(Itr, 0, N, 1)
#define _rep2(Itr, a, b) _rep3(Itr, a, b, 1)
#define _rep3(Itr, a, b, step) for (i64 Itr = a; Itr < b; Itr += step)
#define repeat(...) _overload(__VA_ARGS__, _rep3, _rep2, _rep1)(__VA_ARGS__)
#define rep(...) repeat(__VA_ARGS__)

// Iterator pair over a whole container, e.g. sort(ALL(v)).
#define ALL(X) begin(X), end(X)

using namespace std;
using i64 = long long;
using u64 = unsigned long long;

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    int n, q;
    cin >> n >> q;
    vector<int> x(n);
    for (auto &vs : x) cin >> vs;

    DoublingTreeBuilder buider(n);
    rep(i, n - 1) {
        i64 a, b;
        cin >> a >> b;
        --a, --b;
        buider.add_edge(a, b);
    }

    auto tr = buider.build();
    vector<PersistentAVLSet<int>> st(n);
    st[0] = st[0].insert(x[0]);
    auto dfs = [&](int u, auto &&f) {
        if (st[u].size() > 0) return;
        f(tr.climb(u, 1).value(), f);
        st[u] = st[tr.climb(u, 1).value()].insert(x[u]);
        return;
    };
    rep(i, 1, n) {
        //
        dfs(i, dfs);
    }

    auto count = [&](int u, int v, int r, int t) {
        return st[u].count_lower(t) + st[v].count_lower(t) -
               2 * st[r].count_lower(t) + (x[r] < t);
    };
    auto find = [&](int u, int v, int l) {
        int valid = 0;
        int invalid = 1 << 30;
        int lca = tr.lca(u, v);
        while (abs(valid - invalid) > 1) {
            i64 mid = (valid + invalid) / 2;
            if (count(u, v, lca, mid) < l)
                valid = mid;
            else
                invalid = mid;
        }
        return valid;
    };

    rep(_, q) {
        int u, v, l;
        cin >> u >> v >> l;
        --u, --v;

        cout << find(u, v, l) << '\n';
    }

    return 0;
}
Back to top page