跳轉到

Tree

樹上問題常需要快速判斷祖先關係、查詢兩點最近共同祖先,或在兩點路徑上維護資訊。這頁整理 Binary Lifting / LCA 和 Heavy-Light Decomposition 兩種模板。

Binary lifting 時間複雜度:

  • DFS 預處理:O(n log n)
  • 判斷祖先關係:O(1)
  • LCA 查詢:O(log n)
  • k 個祖先:O(log n)

Binary Lifting + LCA

先選定根節點,把樹轉成 rooted tree,並記錄每個節點的進入時間 tin、離開時間 tout。若 uv 的祖先,則 tin[u] <= tin[v]tout[v] <= tout[u]。查詢 LCA 時,先處理其中一方本來就是祖先的情況;否則從最大的 jump 開始,把 u 往上跳到不會越過 LCA 的最高位置。

維護資訊

  • up[i][u]:節點 u 往上跳 2^i 步後的祖先。
  • tin[u] / tout[u]:DFS order,用來判斷祖先關係。
  • depth[u]:節點深度,可用來計算距離或處理路徑題型。

模板

C++
#include <bits/stdc++.h>
using namespace std;

class LCA {
public:
    int n, logn, timer;
    vector<vector<int>> adj, up;
    vector<int> tin, tout, depth;

    LCA(const vector<vector<int>>& graph, int root = 0) {
        n = graph.size();
        adj = graph;
        timer = 0;

        logn = 1;
        while ((1 << logn) <= n) {
            ++logn;
        }

        tin.assign(n, 0);
        tout.assign(n, 0);
        depth.assign(n, 0);
        up.assign(logn, vector<int>(n, root));

        dfs(root, root);
    }

    void dfs(int cur, int parent) {
        tin[cur] = timer++;
        up[0][cur] = parent;

        for (int i = 1; i < logn; ++i) {
            up[i][cur] = up[i - 1][up[i - 1][cur]];
        }

        for (int next : adj[cur]) {
            if (next == parent) continue;

            depth[next] = depth[cur] + 1;
            dfs(next, cur);
        }

        tout[cur] = timer++;
    }

    bool is_ancestor(int u, int v) {
        return tin[u] <= tin[v] && tout[v] <= tout[u];
    }

    int lca(int u, int v) {
        if (is_ancestor(u, v)) return u;
        if (is_ancestor(v, u)) return v;

        for (int i = logn - 1; i >= 0; --i) {
            if (!is_ancestor(up[i][u], v)) {
                u = up[i][u];
            }
        }

        return up[0][u];
    }

    int kth_ancestor(int u, int k) {
        for (int i = 0; i < logn; ++i) {
            if (k & (1 << i)) {
                u = up[i][u];
            }
        }

        return u;
    }

    int parent(int u) {
        return up[0][u];
    }

    int distance(int u, int v) {
        int w = lca(u, v);
        return depth[u] + depth[v] - 2 * depth[w];
    }
};

參考題型

本題整體複雜度為 O(n log n + q log n),其中 q 是 trips 數量;後序累加與樹 DP 各為 O(n)

C++
using ll = long long;
using pll = pair<ll, ll>;

class Solution {
public:
    int minimumTotalPrice(
        int n,
        vector<vector<int>>& edges,
        vector<int>& price,
        vector<vector<int>>& trips
    ) {
        vector<vector<int>> adj(n);

        for (auto& edge : edges) {
            int u = edge[0];
            int v = edge[1];

            adj[u].push_back(v);
            adj[v].push_back(u);
        }

        LCA lca(adj);
        vector<int> cnt(n, 0);

        for (auto& trip : trips) {
            int u = trip[0];
            int v = trip[1];
            int w = lca.lca(u, v);

            ++cnt[u];
            ++cnt[v];
            --cnt[w];

            if (lca.parent(w) != w) {
                --cnt[lca.parent(w)];
            }
        }

        auto collect = [&](auto&& collect, int cur, int parent) -> void {
            for (int next : adj[cur]) {
                if (next == parent) continue;

                collect(collect, next, cur);
                cnt[cur] += cnt[next];
            }
        };

        collect(collect, 0, -1);

        auto dfs = [&](auto&& dfs, int cur, int parent) -> pll {
            ll not_half = 1LL * cnt[cur] * price[cur];
            ll half = 1LL * cnt[cur] * price[cur] / 2;

            for (int next : adj[cur]) {
                if (next == parent) continue;

                auto [child_not_half, child_half] = dfs(dfs, next, cur);
                not_half += min(child_not_half, child_half);
                half += child_not_half;
            }

            return {not_half, half};
        };

        auto [not_half, half] = dfs(dfs, 0, -1);
        return min(not_half, half);
    }
};

Heavy-Light Decomposition

Heavy-Light Decomposition,通常縮寫成 HLD,適合處理樹上路徑查詢與更新。它會把樹拆成多條 heavy chain,並把每個節點映射到一個連續序列上的位置。之後一條 u -> v 路徑可以被拆成數段連續區間,再交給 segment tree 維護。

時間複雜度

  • HLD 預處理:O(n)
  • 單點更新:O(log n)
  • 路徑查詢:O(log^2 n)

維護資訊

  • parent[u]:rooted tree 中 u 的 parent。
  • depth[u]:節點 u 的深度,用來判斷哪條 chain 比較深。
  • sz[u]:以 u 為根的子樹大小。
  • heavy[u]:節點 u 最大子樹的 child。u -> heavy[u] 這條邊稱為 heavy edge,會盡量延續同一條 chain。
  • head[u]:節點 u 所在 heavy chain 的頂端。
  • pos[u]:節點 u 在線性序列中的位置。節點值會放到 pos[u],再交給 segment tree 維護。

DFS 拆解

dfs1 負責把樹 root 起來,並找出每個節點的 heavy child:

  • 設定 parent[u]depth[u]
  • 計算 sz[u],也就是 u 的子樹大小。
  • 在所有 child 中挑出子樹最大的那個,記成 heavy[u]

dfs2 負責把每條 heavy chain 攤平成連續區間:

  • 設定 head[u],表示 u 目前所在 chain 的頂端。
  • 設定 pos[u] = cur++,把節點映射到 segment tree 的線性位置。
  • 先走 heavy[u],讓同一條 heavy chain 的節點在 pos 上連續。
  • 其他非 heavy child 會各自開新的 chain,因此呼叫 dfs2(v, v)

路徑查詢

queryPath(u, v) 會把原本樹上的路徑拆成幾段 segment tree 區間:

  1. head[u] != head[v],代表 uv 還在不同 chain。
  2. 每次選 chain head 比較深的那一端往上處理。若 depth[head[u]] < depth[head[v]],就交換 uv
  3. 對目前較深的 chain 查詢 [pos[head[u]], pos[u]],這段是連續區間,可以直接丟給 segment tree。
  4. 查完後把 u 移到這條 chain 頂端的 parent,也就是 u = parent[head[u]]
  5. 重複直到兩個點在同一條 chain。最後只需要查 [pos[u], pos[v]],其中較淺的點要放在左邊。

這份模板的 segment tree 維護 XOR,所以每段查詢結果都用 mask ^= ... 合併。若題目改成路徑和、最大值或最小值,只要把 segment tree 的合併方式換掉即可。

參考題型

  • LeetCode - Palindromic Path Queries in a Tree:每個節點維護字母出現次數的 parity bitmask。路徑查詢時把所有節點的 mask XOR 起來,若最後最多只有一個 bit 是 1,代表這條路徑上的字母可以重排成 palindrome。
C++
#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
public:
    int n;
    vector<int> st;

    SegmentTree(int n = 0) {
        init(n);
    }

    void init(int _n) {
        n = _n;
        st.assign(n * 4, 0);
    }

    void update(int k, int v) {
        update(1, 0, n - 1, k, v);
    }

    int query(int l, int r) {
        if (l > r) return 0;
        return query(1, 0, n - 1, l, r);
    }

private:
    void update(int p, int l, int r, int k, int v) {
        if (l == r) {
            st[p] = v;
            return;
        }

        int mid = (l + r) / 2;

        if (k <= mid) {
            update(p << 1, l, mid, k, v);
        } else {
            update(p << 1 | 1, mid + 1, r, k, v);
        }

        st[p] = st[p << 1] ^ st[p << 1 | 1];
    }

    int query(int p, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) {
            return st[p];
        }

        if (r < ql || qr < l) {
            return 0;
        }

        int mid = (l + r) / 2;

        return query(p << 1, l, mid, ql, qr) ^
               query(p << 1 | 1, mid + 1, r, ql, qr);
    }
};

class HLD {
public:
    int n;
    int cur;

    vector<vector<int>> g;

    vector<int> parent;
    vector<int> depth;
    vector<int> sz;
    vector<int> heavy;
    vector<int> head;
    vector<int> pos;

    SegmentTree seg;

    HLD(int n) : n(n), seg(n) {
        cur = 0;

        g.assign(n, {});
        parent.assign(n, -1);
        depth.assign(n, 0);
        sz.assign(n, 0);
        heavy.assign(n, -1);
        head.assign(n, 0);
        pos.assign(n, 0);
    }

    void addEdge(int u, int v) {
        g[u].push_back(v);
        g[v].push_back(u);
    }

    void build(const string& s, int root = 0) {
        dfs1(root, -1);
        dfs2(root, root);

        for (int u = 0; u < n; ++u) {
            updateNode(u, s[u]);
        }
    }

    void updateNode(int u, char c) {
        int mask = 1 << (c - 'a');
        seg.update(pos[u], mask);
    }

    int queryPath(int u, int v) {
        int mask = 0;

        while (head[u] != head[v]) {
            if (depth[head[u]] < depth[head[v]]) {
                swap(u, v);
            }

            int top = head[u];

            mask ^= seg.query(pos[top], pos[u]);

            u = parent[top];
        }

        if (depth[u] > depth[v]) {
            swap(u, v);
        }

        mask ^= seg.query(pos[u], pos[v]);

        return mask;
    }

private:
    int dfs1(int u, int p) {
        parent[u] = p;
        sz[u] = 1;

        int maxChildSize = 0;

        for (int v : g[u]) {
            if (v == p) continue;

            depth[v] = depth[u] + 1;

            int childSize = dfs1(v, u);
            sz[u] += childSize;

            if (childSize > maxChildSize) {
                maxChildSize = childSize;
                heavy[u] = v;
            }
        }

        return sz[u];
    }

    void dfs2(int u, int h) {
        head[u] = h;
        pos[u] = cur++;

        if (heavy[u] != -1) {
            dfs2(heavy[u], h);
        }

        for (int v : g[u]) {
            if (v == parent[u]) continue;
            if (v == heavy[u]) continue;

            dfs2(v, v);
        }
    }
};

class Solution {
public:
    vector<bool> palindromePath(
        int n,
        vector<vector<int>>& edges,
        string s,
        vector<string>& queries
    ) {
        HLD hld(n);

        for (auto& e : edges) {
            hld.addEdge(e[0], e[1]);
        }

        hld.build(s);

        vector<bool> ret;

        for (auto& q : queries) {
            stringstream ss(q);
            string op;
            ss >> op;

            if (op == "query") {
                int a, b;
                ss >> a >> b;

                int mask = hld.queryPath(a, b);

                ret.push_back(__builtin_popcount(mask) <= 1);
            } else {
                int k;
                char c;
                ss >> k >> c;

                hld.updateNode(k, c);
            }
        }

        return ret;
    }
};