SegmentTree
這頁整理幾種常用的 segment tree 模板。每個段落都依照「適用情境、維護資訊、模板、參考題型」的順序整理,方便之後依題型快速套用。
Point Update + Range Query
適用於已知初始陣列、更新只影響單一位置、查詢需要合併一段區間資訊的題型。這裡以區間總和為例,合併時直接把左右子樹相加。
維護資訊
sum[p]:節點p對應區間的總和。build:由初始陣列建立 segment tree。update:把單一位置k的值改成v。query:回傳[ql, qr]的區間總和。
模板
C++
#include <bits/stdc++.h>
using namespace std;
class SegmentTree {
public:
int n;
vector<int> sum;
SegmentTree(vector<int>& arr) {
n = arr.size();
sum.assign(4 * n, 0);
build(1, 0, n - 1, arr);
}
void build(int p, int l, int r, vector<int>& arr) {
if (l == r) {
sum[p] = arr[l];
return;
}
int mid = (l + r) / 2;
build(p << 1, l, mid, arr);
build(p << 1 | 1, mid + 1, r, arr);
sum[p] = sum[p << 1] + sum[p << 1 | 1];
}
void update(int p, int l, int r, int k, int v) {
if (l == r) {
sum[p] = v;
return;
}
int mid = (l + r) / 2;
if (k <= mid)
update(p << 1, l, mid, k, v);
else
update(p << 1 | 1, mid + 1, r, k, v);
sum[p] = sum[p << 1] + sum[p << 1 | 1];
}
int query(int p, int l, int r, int ql, int qr) {
if (qr < l || r < ql) return 0;
if (ql <= l && r <= qr) {
return sum[p];
}
int mid = (l + r) / 2;
int a = query(p << 1, l, mid, ql, qr);
int b = query(p << 1 | 1, mid + 1, r, ql, qr);
return a + b;
}
};
參考題型
- LeetCode - Minimum Deletions to Make Alternating Substring:把
arr[i]定義成s[i] == s[i + 1],查詢[l, r]時只要統計arr[l..r-1]中有幾個相鄰字元相同;更新s[idx]時,只會影響idx - 1與idx兩個相鄰關係,因此適合用單點更新。
C++
class Solution {
public:
vector<int> minDeletions(string s, vector<vector<int>>& queries) {
int n = s.size();
if (n == 1) {
vector<int> ret;
for (auto& q : queries) {
if (q[0] == 2) ret.push_back(0);
}
return ret;
}
vector<int> arr(n - 1);
for (int i = 0; i < n - 1; ++i) {
arr[i] = (s[i] == s[i + 1]);
}
SegmentTree tree(arr);
vector<int> ret;
for (auto& q : queries) {
if (q[0] == 1) {
int idx = q[1];
if (idx < n - 1) {
arr[idx] ^= 1;
tree.update(1, 0, n - 2, idx, arr[idx]);
}
if (idx > 0) {
arr[idx - 1] ^= 1;
tree.update(1, 0, n - 2, idx - 1, arr[idx - 1]);
}
} else {
int l = q[1];
int r = q[2];
int v = tree.query(1, 0, n - 2, l, r - 1);
ret.push_back(v);
}
}
return ret;
}
};
Range Assign + Range Query
適用於每次把一段區間覆蓋成同一個值,並查詢區間統計資訊的題型。這裡以區間最大值為例,更新時會把整段 assign 成同一個高度。
維護資訊
st[p]:節點p對應區間的最大值。lazy:記錄尚未往下推的 assign value。is_lazy:標記該節點是否有 pending assignment,因為 assign value 可能是0。
模板
C++
#include <bits/stdc++.h>
using namespace std;
class SegmentTree {
public:
int n;
vector<int> st, lazy;
vector<bool> is_lazy;
SegmentTree(int n) {
this->n = n;
st.assign(4 * n, 0);
lazy.assign(4 * n, 0);
is_lazy.assign(4 * n, false);
}
void apply(int p, int v) {
lazy[p] = v;
st[p] = v;
is_lazy[p] = true;
}
void push(int p) {
if (!is_lazy[p]) return;
apply(p << 1, lazy[p]);
apply(p << 1 | 1, lazy[p]);
is_lazy[p] = false;
}
void pull(int p) {
st[p] = max(st[p << 1], st[p << 1 | 1]);
}
void update(int p, int l, int r, int ql, int qr, int val) {
if (ql <= l && r <= qr) {
apply(p, val);
return;
}
if (qr < l || r < ql) return;
push(p);
int mid = (l + r) / 2;
update(p << 1, l, mid, ql, qr, val);
update(p << 1 | 1, mid + 1, r, ql, qr, val);
pull(p);
}
int query(int p, int l, int r, int ql, int qr) {
if (ql <= l && r <= qr) {
return st[p];
}
if (qr < l || r < ql) return 0;
push(p);
int mid = (l + r) / 2;
int a = query(p << 1, l, mid, ql, qr);
int b = query(p << 1 | 1, mid + 1, r, ql, qr);
return max(a, b);
}
};
參考題型
- LeetCode - Falling Squares:先做座標離散化,把每個方塊覆蓋的半開區間
[l, l + size)轉成 segment tree 上的[L, R - 1]。每次先查詢該區間目前最大高度,再把整段 assign 成新的高度。
C++
class Solution {
public:
vector<int> fallingSquares(vector<vector<int>>& positions) {
vector<int> c;
for (auto& p : positions) {
int l = p[0];
int s = p[1];
c.push_back(l);
c.push_back(l + s);
}
sort(c.begin(), c.end());
c.erase(unique(c.begin(), c.end()), c.end());
int m = c.size();
SegmentTree tree(m);
vector<int> ans;
int gm = 0;
for (auto& p : positions) {
int l = p[0];
int r = l + p[1];
int L = lower_bound(c.begin(), c.end(), l) - c.begin();
int R = lower_bound(c.begin(), c.end(), r) - c.begin();
int base = tree.query(1, 0, m - 1, L, R - 1);
int nb = base + p[1];
tree.update(1, 0, m - 1, L, R - 1, nb);
gm = max(gm, nb);
ans.push_back(gm);
}
return ans;
}
};
Range Add + Range Query
適用於每次對一段區間加上同一個值,並查詢區間統計資訊的題型。這裡同時維護區間總和 sum 與平方和 sum2,可用來計算區間內所有 pair 的乘積總和。
維護資訊
sum[p]:節點p對應區間的總和。sum2[p]:節點p對應區間的平方和。lazy[p]:尚未往下推的 range add value。
若對長度為 len 的區間加上 v,更新公式如下:
sum += len * vsum2 += 2 * v * sum_old + len * v^2
注意 sum2 的更新需要用到更新前的 sum,所以在 apply 裡要先更新 sum2,再更新 sum。
查詢 pairwise product
若題目要查詢區間內所有 pair 的乘積總和:
\[
\sum_{l \le i < j \le r} A_i A_j
\]
可以透過以下公式計算:
\[
\frac{(\sum A_i)^2 - \sum A_i^2}{2}
\]
在模數為 998244353 時,2 的反元素是 (mod + 1) / 2。
模板
C++
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
const ll mod = 998244353;
class SegmentTree {
public:
vector<ll> sum, sum2, lazy;
SegmentTree(int n) {
sum.assign(4 * n + 5, 0);
sum2.assign(4 * n + 5, 0);
lazy.assign(4 * n + 5, 0);
}
ll norm(ll x) {
x %= mod;
if (x < 0) x += mod;
return x;
}
void apply(int p, int l, int r, ll v) {
v = norm(v);
ll len = r - l + 1;
sum2[p] = (
sum2[p]
+ 2LL * v % mod * sum[p] % mod
+ len % mod * v % mod * v % mod
) % mod;
sum[p] = (sum[p] + len % mod * v) % mod;
lazy[p] = (lazy[p] + v) % mod;
}
void push(int p, int l, int r) {
if (lazy[p] == 0) return;
int mid = (l + r) / 2;
apply(p << 1, l, mid, lazy[p]);
apply(p << 1 | 1, mid + 1, r, lazy[p]);
lazy[p] = 0;
}
void pull(int p) {
sum[p] = (sum[p << 1] + sum[p << 1 | 1]) % mod;
sum2[p] = (sum2[p << 1] + sum2[p << 1 | 1]) % mod;
}
void range_add(int p, int l, int r, int ql, int qr, ll v) {
if (qr < l || r < ql) return;
if (ql <= l && r <= qr) {
apply(p, l, r, v);
return;
}
push(p, l, r);
int mid = (l + r) / 2;
range_add(p << 1, l, mid, ql, qr, v);
range_add(p << 1 | 1, mid + 1, r, ql, qr, v);
pull(p);
}
pll query(int p, int l, int r, int ql, int qr) {
if (qr < l || r < ql) return {0, 0};
if (ql <= l && r <= qr) {
return {sum[p], sum2[p]};
}
push(p, l, r);
int mid = (l + r) / 2;
auto L = query(p << 1, l, mid, ql, qr);
auto R = query(p << 1 | 1, mid + 1, r, ql, qr);
return {
(L.first + R.first) % mod,
(L.second + R.second) % mod
};
}
};
參考題型
- AtCoder ABC455 F - Merge Slimes 2:每次對區間加值後,查詢同一段區間內所有 unordered pair 的乘積總和。核心是維護
sum與sum2,再用((sum^2 - sum2) / 2)得到答案。
C++
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
const ll inv2 = (mod + 1) / 2;
int n, q;
cin >> n >> q;
SegmentTree tree(n);
while (q--) {
int l, r;
ll a;
cin >> l >> r >> a;
tree.range_add(1, 1, n, l, r, a);
auto [s, s2] = tree.query(1, 1, n, l, r);
ll ans = (s * s % mod - s2 + mod) % mod;
ans = ans * inv2 % mod;
cout << ans << '\n';
}
return 0;
}
Segment Tree Binary Search
適用於要找「位置」而不是只查詢區間統計值的題型。做法是利用節點維護的資訊,從 root 往下走;每一層先判斷左子樹是否已經足夠決定答案,不足時再往右子樹走。
First Position With Max
當每個節點維護區間最大值時,可以找出第一個 arr[pos] >= k 的位置。呼叫 firstpos 前,先確認整棵樹的最大值是否足夠。
C++
#include <bits/stdc++.h>
using namespace std;
class SegmentTree {
public:
int n;
vector<int> st;
SegmentTree(int n) {
this->n = n;
st.assign(4 * n, 0);
}
void build(int p, int l, int r, vector<int>& arr) {
if (l == r) {
st[p] = arr[l];
return;
}
int mid = (l + r) / 2;
build(p << 1, l, mid, arr);
build(p << 1 | 1, mid + 1, r, arr);
st[p] = max(st[p << 1], st[p << 1 | 1]);
}
void update(int p, int l, int r, int k, int v) {
if (l == r) {
st[p] = v;
return;
}
int mid = (l + r) / 2;
if (k <= mid)
update(p << 1, l, mid, k, v);
else
update(p << 1 | 1, mid + 1, r, k, v);
st[p] = max(st[p << 1], st[p << 1 | 1]);
}
int query(int p, int l, int r, int ql, int qr) {
if (qr < l || r < ql) return INT_MIN;
if (ql <= l && r <= qr) {
return st[p];
}
int mid = (l + r) / 2;
return max(
query(p << 1, l, mid, ql, qr),
query(p << 1 | 1, mid + 1, r, ql, qr)
);
}
int firstpos(int p, int l, int r, int k) {
if (l == r) {
return l;
}
int mid = (l + r) / 2;
if (st[p << 1] >= k) {
return firstpos(p << 1, l, mid, k);
}
return firstpos(p << 1 | 1, mid + 1, r, k);
}
};
K-th Active Position
當每個節點維護區間內 active position 的數量時,可以找第 k 個 active 位置。若要限制在 [ql, qr],先查詢該區間數量是否足夠,再用 prefix count 把區間內第 k 個轉成全域第 k + pre 個。
C++
#include <bits/stdc++.h>
using namespace std;
class SegmentTree {
public:
int n;
vector<int> st;
SegmentTree(int n) {
this->n = n;
st.assign(4 * n, 0);
}
void update(int p, int l, int r, int k, int v) {
if (l == r) {
st[p] += v;
return;
}
int mid = (l + r) / 2;
if (k <= mid)
update(p << 1, l, mid, k, v);
else
update(p << 1 | 1, mid + 1, r, k, v);
st[p] = st[p << 1] + st[p << 1 | 1];
}
int query(int p, int l, int r, int ql, int qr) {
if (qr < l || r < ql) return 0;
if (ql <= l && r <= qr) {
return st[p];
}
int mid = (l + r) / 2;
return query(p << 1, l, mid, ql, qr)
+ query(p << 1 | 1, mid + 1, r, ql, qr);
}
int kth(int p, int l, int r, int k) {
if (l == r) {
return l;
}
int mid = (l + r) / 2;
if (st[p << 1] >= k) {
return kth(p << 1, l, mid, k);
}
return kth(p << 1 | 1, mid + 1, r, k - st[p << 1]);
}
int kth_inrange(int ql, int qr, int k) {
int cnt = query(1, 0, n - 1, ql, qr);
if (cnt < k) return -1;
int pre = query(1, 0, n - 1, 0, ql - 1);
int pos = kth(1, 0, n - 1, k + pre);
return pos > qr ? -1 : pos;
}
};
參考題型
- LeetCode - Fruits Into Baskets III:每個 basket 只能使用一次,segment tree 維護剩餘容量最大值。若全域最大值小於 fruit size,代表放不下;否則用
firstpos找最左邊容量足夠的 basket,再把該位置更新成-1。 - LeetCode - Next Greater Element IV:離線處理 index,segment tree 維護目前已啟用的位置。查詢
i右側第 2 個已啟用位置時,可用kth_inrange(i + 1, n - 1, 2)。