跳轉到

SegmentTree

這頁整理幾種常用的 segment tree 模板。每個段落都依照「適用情境、維護資訊、模板、參考題型」的順序整理,方便之後依題型快速套用。

Point Update + Range Query

適用於已知初始陣列、更新只影響單一位置、查詢需要合併一段區間資訊的題型。這裡以區間總和為例,合併時直接把左右子樹相加。

維護資訊

  • sum[p]:節點 p 對應區間的總和。
  • build:由初始陣列建立 segment tree。
  • update:把單一位置 k 的值改成 v
  • query:回傳 [ql, qr] 的區間總和。

模板

C++
#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
public:
    int n;
    vector<int> sum;

    SegmentTree(vector<int>& arr) {
        n = arr.size();
        sum.assign(4 * n, 0);
        build(1, 0, n - 1, arr);
    }

    void build(int p, int l, int r, vector<int>& arr) {
        if (l == r) {
            sum[p] = arr[l];
            return;
        }

        int mid = (l + r) / 2;
        build(p << 1, l, mid, arr);
        build(p << 1 | 1, mid + 1, r, arr);

        sum[p] = sum[p << 1] + sum[p << 1 | 1];
    }

    void update(int p, int l, int r, int k, int v) {
        if (l == r) {
            sum[p] = v;
            return;
        }

        int mid = (l + r) / 2;
        if (k <= mid)
            update(p << 1, l, mid, k, v);
        else
            update(p << 1 | 1, mid + 1, r, k, v);

        sum[p] = sum[p << 1] + sum[p << 1 | 1];
    }

    int query(int p, int l, int r, int ql, int qr) {
        if (qr < l || r < ql) return 0;

        if (ql <= l && r <= qr) {
            return sum[p];
        }

        int mid = (l + r) / 2;
        int a = query(p << 1, l, mid, ql, qr);
        int b = query(p << 1 | 1, mid + 1, r, ql, qr);

        return a + b;
    }
};

參考題型

C++
class Solution {
public:
    vector<int> minDeletions(string s, vector<vector<int>>& queries) {
        int n = s.size();

        if (n == 1) {
            vector<int> ret;
            for (auto& q : queries) {
                if (q[0] == 2) ret.push_back(0);
            }
            return ret;
        }

        vector<int> arr(n - 1);
        for (int i = 0; i < n - 1; ++i) {
            arr[i] = (s[i] == s[i + 1]);
        }

        SegmentTree tree(arr);
        vector<int> ret;

        for (auto& q : queries) {
            if (q[0] == 1) {
                int idx = q[1];

                if (idx < n - 1) {
                    arr[idx] ^= 1;
                    tree.update(1, 0, n - 2, idx, arr[idx]);
                }

                if (idx > 0) {
                    arr[idx - 1] ^= 1;
                    tree.update(1, 0, n - 2, idx - 1, arr[idx - 1]);
                }
            } else {
                int l = q[1];
                int r = q[2];
                int v = tree.query(1, 0, n - 2, l, r - 1);
                ret.push_back(v);
            }
        }

        return ret;
    }
};

Range Assign + Range Query

適用於每次把一段區間覆蓋成同一個值,並查詢區間統計資訊的題型。這裡以區間最大值為例,更新時會把整段 assign 成同一個高度。

維護資訊

  • st[p]:節點 p 對應區間的最大值。
  • lazy:記錄尚未往下推的 assign value。
  • is_lazy:標記該節點是否有 pending assignment,因為 assign value 可能是 0

模板

C++
#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
public:
    int n;
    vector<int> st, lazy;
    vector<bool> is_lazy;

    SegmentTree(int n) {
        this->n = n;
        st.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
        is_lazy.assign(4 * n, false);
    }

    void apply(int p, int v) {
        lazy[p] = v;
        st[p] = v;
        is_lazy[p] = true;
    }

    void push(int p) {
        if (!is_lazy[p]) return;

        apply(p << 1, lazy[p]);
        apply(p << 1 | 1, lazy[p]);
        is_lazy[p] = false;
    }

    void pull(int p) {
        st[p] = max(st[p << 1], st[p << 1 | 1]);
    }

    void update(int p, int l, int r, int ql, int qr, int val) {
        if (ql <= l && r <= qr) {
            apply(p, val);
            return;
        }

        if (qr < l || r < ql) return;

        push(p);

        int mid = (l + r) / 2;
        update(p << 1, l, mid, ql, qr, val);
        update(p << 1 | 1, mid + 1, r, ql, qr, val);

        pull(p);
    }

    int query(int p, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr) {
            return st[p];
        }

        if (qr < l || r < ql) return 0;

        push(p);

        int mid = (l + r) / 2;
        int a = query(p << 1, l, mid, ql, qr);
        int b = query(p << 1 | 1, mid + 1, r, ql, qr);

        return max(a, b);
    }
};

參考題型

  • LeetCode - Falling Squares:先做座標離散化,把每個方塊覆蓋的半開區間 [l, l + size) 轉成 segment tree 上的 [L, R - 1]。每次先查詢該區間目前最大高度,再把整段 assign 成新的高度。
C++
class Solution {
public:
    vector<int> fallingSquares(vector<vector<int>>& positions) {
        vector<int> c;
        for (auto& p : positions) {
            int l = p[0];
            int s = p[1];
            c.push_back(l);
            c.push_back(l + s);
        }

        sort(c.begin(), c.end());
        c.erase(unique(c.begin(), c.end()), c.end());

        int m = c.size();
        SegmentTree tree(m);
        vector<int> ans;
        int gm = 0;

        for (auto& p : positions) {
            int l = p[0];
            int r = l + p[1];
            int L = lower_bound(c.begin(), c.end(), l) - c.begin();
            int R = lower_bound(c.begin(), c.end(), r) - c.begin();

            int base = tree.query(1, 0, m - 1, L, R - 1);
            int nb = base + p[1];

            tree.update(1, 0, m - 1, L, R - 1, nb);
            gm = max(gm, nb);
            ans.push_back(gm);
        }

        return ans;
    }
};

Range Add + Range Query

適用於每次對一段區間加上同一個值,並查詢區間統計資訊的題型。這裡同時維護區間總和 sum 與平方和 sum2,可用來計算區間內所有 pair 的乘積總和。

維護資訊

  • sum[p]:節點 p 對應區間的總和。
  • sum2[p]:節點 p 對應區間的平方和。
  • lazy[p]:尚未往下推的 range add value。

若對長度為 len 的區間加上 v,更新公式如下:

  • sum += len * v
  • sum2 += 2 * v * sum_old + len * v^2

注意 sum2 的更新需要用到更新前的 sum,所以在 apply 裡要先更新 sum2,再更新 sum

查詢 pairwise product

若題目要查詢區間內所有 pair 的乘積總和:

\[ \sum_{l \le i < j \le r} A_i A_j \]

可以透過以下公式計算:

\[ \frac{(\sum A_i)^2 - \sum A_i^2}{2} \]

在模數為 998244353 時,2 的反元素是 (mod + 1) / 2

模板

C++
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pll = pair<ll, ll>;

const ll mod = 998244353;

class SegmentTree {
public:
    vector<ll> sum, sum2, lazy;

    SegmentTree(int n) {
        sum.assign(4 * n + 5, 0);
        sum2.assign(4 * n + 5, 0);
        lazy.assign(4 * n + 5, 0);
    }

    ll norm(ll x) {
        x %= mod;
        if (x < 0) x += mod;
        return x;
    }

    void apply(int p, int l, int r, ll v) {
        v = norm(v);
        ll len = r - l + 1;

        sum2[p] = (
            sum2[p]
            + 2LL * v % mod * sum[p] % mod
            + len % mod * v % mod * v % mod
        ) % mod;

        sum[p] = (sum[p] + len % mod * v) % mod;
        lazy[p] = (lazy[p] + v) % mod;
    }

    void push(int p, int l, int r) {
        if (lazy[p] == 0) return;

        int mid = (l + r) / 2;
        apply(p << 1, l, mid, lazy[p]);
        apply(p << 1 | 1, mid + 1, r, lazy[p]);

        lazy[p] = 0;
    }

    void pull(int p) {
        sum[p] = (sum[p << 1] + sum[p << 1 | 1]) % mod;
        sum2[p] = (sum2[p << 1] + sum2[p << 1 | 1]) % mod;
    }

    void range_add(int p, int l, int r, int ql, int qr, ll v) {
        if (qr < l || r < ql) return;

        if (ql <= l && r <= qr) {
            apply(p, l, r, v);
            return;
        }

        push(p, l, r);

        int mid = (l + r) / 2;
        range_add(p << 1, l, mid, ql, qr, v);
        range_add(p << 1 | 1, mid + 1, r, ql, qr, v);

        pull(p);
    }

    pll query(int p, int l, int r, int ql, int qr) {
        if (qr < l || r < ql) return {0, 0};

        if (ql <= l && r <= qr) {
            return {sum[p], sum2[p]};
        }

        push(p, l, r);

        int mid = (l + r) / 2;
        auto L = query(p << 1, l, mid, ql, qr);
        auto R = query(p << 1 | 1, mid + 1, r, ql, qr);

        return {
            (L.first + R.first) % mod,
            (L.second + R.second) % mod
        };
    }
};

參考題型

  • AtCoder ABC455 F - Merge Slimes 2:每次對區間加值後,查詢同一段區間內所有 unordered pair 的乘積總和。核心是維護 sumsum2,再用 ((sum^2 - sum2) / 2) 得到答案。
C++
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    const ll inv2 = (mod + 1) / 2;

    int n, q;
    cin >> n >> q;

    SegmentTree tree(n);

    while (q--) {
        int l, r;
        ll a;
        cin >> l >> r >> a;

        tree.range_add(1, 1, n, l, r, a);

        auto [s, s2] = tree.query(1, 1, n, l, r);

        ll ans = (s * s % mod - s2 + mod) % mod;
        ans = ans * inv2 % mod;

        cout << ans << '\n';
    }

    return 0;
}

適用於要找「位置」而不是只查詢區間統計值的題型。做法是利用節點維護的資訊,從 root 往下走;每一層先判斷左子樹是否已經足夠決定答案,不足時再往右子樹走。

First Position With Max

當每個節點維護區間最大值時,可以找出第一個 arr[pos] >= k 的位置。呼叫 firstpos 前,先確認整棵樹的最大值是否足夠。

C++
#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
public:
    int n;
    vector<int> st;

    SegmentTree(int n) {
        this->n = n;
        st.assign(4 * n, 0);
    }

    void build(int p, int l, int r, vector<int>& arr) {
        if (l == r) {
            st[p] = arr[l];
            return;
        }

        int mid = (l + r) / 2;
        build(p << 1, l, mid, arr);
        build(p << 1 | 1, mid + 1, r, arr);

        st[p] = max(st[p << 1], st[p << 1 | 1]);
    }

    void update(int p, int l, int r, int k, int v) {
        if (l == r) {
            st[p] = v;
            return;
        }

        int mid = (l + r) / 2;
        if (k <= mid)
            update(p << 1, l, mid, k, v);
        else
            update(p << 1 | 1, mid + 1, r, k, v);

        st[p] = max(st[p << 1], st[p << 1 | 1]);
    }

    int query(int p, int l, int r, int ql, int qr) {
        if (qr < l || r < ql) return INT_MIN;

        if (ql <= l && r <= qr) {
            return st[p];
        }

        int mid = (l + r) / 2;
        return max(
            query(p << 1, l, mid, ql, qr),
            query(p << 1 | 1, mid + 1, r, ql, qr)
        );
    }

    int firstpos(int p, int l, int r, int k) {
        if (l == r) {
            return l;
        }

        int mid = (l + r) / 2;
        if (st[p << 1] >= k) {
            return firstpos(p << 1, l, mid, k);
        }
        return firstpos(p << 1 | 1, mid + 1, r, k);
    }
};

K-th Active Position

當每個節點維護區間內 active position 的數量時,可以找第 k 個 active 位置。若要限制在 [ql, qr],先查詢該區間數量是否足夠,再用 prefix count 把區間內第 k 個轉成全域第 k + pre 個。

C++
#include <bits/stdc++.h>
using namespace std;

class SegmentTree {
public:
    int n;
    vector<int> st;

    SegmentTree(int n) {
        this->n = n;
        st.assign(4 * n, 0);
    }

    void update(int p, int l, int r, int k, int v) {
        if (l == r) {
            st[p] += v;
            return;
        }

        int mid = (l + r) / 2;
        if (k <= mid)
            update(p << 1, l, mid, k, v);
        else
            update(p << 1 | 1, mid + 1, r, k, v);

        st[p] = st[p << 1] + st[p << 1 | 1];
    }

    int query(int p, int l, int r, int ql, int qr) {
        if (qr < l || r < ql) return 0;

        if (ql <= l && r <= qr) {
            return st[p];
        }

        int mid = (l + r) / 2;
        return query(p << 1, l, mid, ql, qr)
             + query(p << 1 | 1, mid + 1, r, ql, qr);
    }

    int kth(int p, int l, int r, int k) {
        if (l == r) {
            return l;
        }

        int mid = (l + r) / 2;
        if (st[p << 1] >= k) {
            return kth(p << 1, l, mid, k);
        }
        return kth(p << 1 | 1, mid + 1, r, k - st[p << 1]);
    }

    int kth_inrange(int ql, int qr, int k) {
        int cnt = query(1, 0, n - 1, ql, qr);
        if (cnt < k) return -1;

        int pre = query(1, 0, n - 1, 0, ql - 1);
        int pos = kth(1, 0, n - 1, k + pre);

        return pos > qr ? -1 : pos;
    }
};

參考題型

  • LeetCode - Fruits Into Baskets III:每個 basket 只能使用一次,segment tree 維護剩餘容量最大值。若全域最大值小於 fruit size,代表放不下;否則用 firstpos 找最左邊容量足夠的 basket,再把該位置更新成 -1
  • LeetCode - Next Greater Element IV:離線處理 index,segment tree 維護目前已啟用的位置。查詢 i 右側第 2 個已啟用位置時,可用 kth_inrange(i + 1, n - 1, 2)