跳轉到

String Matching

String matching 題型常見的核心是比較子字串、尋找 pattern 出現位置,或計算每個 suffix/prefix 的匹配長度。這頁整理 Rolling Hash、KMP、Z Algorithm 三種模板。

Rolling Hash

Rolling hash 適合用來快速比較兩段序列是否相同。預處理 prefix hash 和 base power 後,每次查詢子區間 hash 都是 O(1),常搭配二分搜、集合交集或去重使用。

Codeforces 這類可能被刻意構造測資卡 hash 的平台,base 建議像模板一樣用亂數初始化。其他平台若沒有互相 hack 的情境,base 固定挑一組即可;mod 常用 10000000071000000009

時間複雜度

  • 建立 hash:O(n)
  • 查詢子區間 hash:O(1)
  • Longest common subpath:O(S log L) expected,其中 S 是所有 path 長度總和,L 是最短 path 長度

模板

C++
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using ull = unsigned long long;

ll rnd(ll l, ll r) {
    static mt19937_64 rng(
        chrono::steady_clock::now().time_since_epoch().count()
    );
    return uniform_int_distribution<ll>(l, r)(rng);
}

class Hash {
public:
    static constexpr ll mod1 = 1000000007;
    static constexpr ll mod2 = 1000000009;

    static ll base1;
    static ll base2;

    vector<ll> h1, p1, h2, p2;

    Hash(const vector<int>& s) {
        int n = s.size();

        h1.assign(n + 1, 0);
        p1.assign(n + 1, 1);
        h2.assign(n + 1, 0);
        p2.assign(n + 1, 1);

        for (int i = 0; i < n; ++i) {
            ll x = s[i] + 1007;

            h1[i + 1] = (h1[i] * base1 + x) % mod1;
            p1[i + 1] = p1[i] * base1 % mod1;

            h2[i + 1] = (h2[i] * base2 + x) % mod2;
            p2[i + 1] = p2[i] * base2 % mod2;
        }
    }

    ull query(int l, int r) {
        ++l;
        ++r;

        ll val1 = (h1[r] - h1[l - 1] * p1[r - l + 1] % mod1 + mod1) % mod1;
        ll val2 = (h2[r] - h2[l - 1] * p2[r - l + 1] % mod2 + mod2) % mod2;

        return ((ull)val1 << 32) | (ull)val2;
    }
};

ll Hash::base1 = rnd(256, Hash::mod1 - 2);
ll Hash::base2 = rnd(256, Hash::mod2 - 2);

參考題型

  • LeetCode - Longest Common Subpath:二分答案長度 len。對每條 path 取出所有長度為 len 的 subpath hash,若某個 hash 在每條 path 都出現過,代表存在共同 subpath。
C++
class Solution {
public:
    int longestCommonSubpath(int n, vector<vector<int>>& paths) {
        int m = paths.size();
        int min_len = INT_MAX;
        vector<Hash> hashes;

        for (auto& path : paths) {
            hashes.emplace_back(path);
            min_len = min(min_len, (int)path.size());
        }

        auto ok = [&](int len) -> bool {
            if (len == 0) return true;

            unordered_map<ull, int> cnt;

            for (int i = 0; i < m; ++i) {
                unordered_set<ull> seen;

                for (int j = 0; j + len <= (int)paths[i].size(); ++j) {
                    seen.insert(hashes[i].query(j, j + len - 1));
                }

                for (ull key : seen) {
                    if (++cnt[key] == m) {
                        return true;
                    }
                }
            }

            return false;
        };

        int l = 0;
        int r = min_len + 1;

        while (l < r) {
            int mid = (l + r) / 2;

            if (ok(mid)) {
                l = mid + 1;
            } else {
                r = mid;
            }
        }

        return l - 1;
    }
};

KMP

KMP 適合用來找出 pattern 在 text 中的所有出現位置。Prefix function pi[i] 表示 pattern[0..i] 的最長 proper prefix,同時也是 suffix 的長度。匹配失敗時,可以用 pi 直接跳到下一個可能狀態,不需要把 text 指標往回退。

時間複雜度

  • 建立 prefix function:O(m)
  • 在 text 中匹配 pattern:O(n)
  • Count cells 題型:O(R * C + |pattern|)

模板

C++
#include <bits/stdc++.h>
using namespace std;

vector<int> prefix_function(const string& pattern) {
    int n = pattern.size();
    vector<int> pi(n, 0);

    for (int i = 1; i < n; ++i) {
        int j = pi[i - 1];

        while (j > 0 && pattern[i] != pattern[j]) {
            j = pi[j - 1];
        }

        if (pattern[i] == pattern[j]) {
            ++j;
        }

        pi[i] = j;
    }

    return pi;
}

vector<int> covered_by_matches(const string& text, const string& pattern) {
    int n = text.size();
    int m = pattern.size();

    vector<int> cover(n, 0);
    if (m == 0) return cover;

    vector<int> pi = prefix_function(pattern);
    vector<int> diff(n + 1, 0);
    int j = 0;

    for (int i = 0; i < n; ++i) {
        while (j > 0 && text[i] != pattern[j]) {
            j = pi[j - 1];
        }

        if (text[i] == pattern[j]) {
            ++j;
        }

        if (j == m) {
            int start = i - m + 1;
            ++diff[start];
            --diff[i + 1];
            j = pi[j - 1];
        }
    }

    int cur = 0;
    for (int i = 0; i < n; ++i) {
        cur += diff[i];
        cover[i] = cur > 0;
    }

    return cover;
}

參考題型

C++
class Solution {
public:
    int row_id(int i, int j, int n) {
        return i * n + j;
    }

    int col_id(int i, int j, int m) {
        return j * m + i;
    }

    int countCells(vector<vector<char>>& grid, string pattern) {
        int m = grid.size();
        int n = grid[0].size();

        string rows, cols;

        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                rows += grid[i][j];
            }
        }

        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < m; ++i) {
                cols += grid[i][j];
            }
        }

        vector<int> row_cover = covered_by_matches(rows, pattern);
        vector<int> col_cover = covered_by_matches(cols, pattern);

        int ans = 0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                if (row_cover[row_id(i, j, n)] && col_cover[col_id(i, j, m)]) {
                    ++ans;
                }
            }
        }

        return ans;
    }
};

Z Algorithm

Z Algorithm 會計算 z[i],表示 s[i..] 和整個字串 s 的 longest common prefix 長度。它適合處理每個 suffix 和 prefix 的匹配長度,也可以用 pattern + '$' + text 來找 pattern 出現位置。

時間複雜度

  • 建立 Z array:O(n)
  • Sum of scores 題型:O(n)

模板

C++
#include <bits/stdc++.h>
using namespace std;

vector<int> z_function(const string& s) {
    int n = s.size();
    vector<int> z(n, 0);

    int l = 0;
    int r = 0;

    for (int i = 1; i < n; ++i) {
        if (i <= r) {
            z[i] = min(r - i + 1, z[i - l]);
        }

        while (i + z[i] < n && s[z[i]] == s[i + z[i]]) {
            ++z[i];
        }

        if (i + z[i] - 1 > r) {
            l = i;
            r = i + z[i] - 1;
        }
    }

    return z;
}

參考題型

C++
class Solution {
public:
    using ll = long long;

    long long sumScores(string s) {
        vector<int> z = z_function(s);
        ll ans = s.size();

        for (int x : z) {
            ans += x;
        }

        return ans;
    }
};