学习笔记:后缀数组
2025-02-02

非常后知后觉地意识到 SA(Suffix Array) 和 SAM(Suffix Automaton) 的 A 不是同一个 A


#定义

显而易见一个长度为 n 的字符串中有 n 个长度分别为 1n 的后缀,如果我们对其按字典序排序,分别存储下排名 i 的后缀 sai 和每个后缀 i 的排名 rki。虽然看着挺没头没尾的,但是很有用。

#求解

#哈希 + 排序

直接把所有后缀拿来排序的话,字符串比较是 O(n) 的。如果我们用哈希 + 二分优化比较过程,就可以把整个排序优化到 O(nlog2n)

#倍增

先对所有后缀按 第一个字符 排序,记排序后排名序列为 a

那么怎么按 前两个字符 排序呢?对于第 i 组字符,我们用 (ai,ai+1) 双关键字排序即可。记此时排名序列为 b,那么如果需要按照前四个字符排序,用 (bi,bi+2) 进行双关键字排序即可。总共需要进行 logn 次排序。复杂度为 O(nlog2n)

此时我们注意到排名数组的值域为 n,那么我们用桶排就能少一个 log

#实现

哈希很好实现,这里就按下不表,主要讲解倍增法的实现。

描述起来很简单,实现起来很要命。OI wiki 上的实现算是相对好理解的:

首先了解双关键字桶排的方法,首先用单关键字桶排完成对 第二关键字 的排序;对于第一关键字,令桶 i 记录前 i 个元素的数量;遍历排序后的第二关键字数组,将元素放到桶中记录数值对应的下标中,并将桶中数值 1。实际上桶 c 充当计算下标范围的作用,(ci1,ci] 即为 i 分布的范围。

显然,当且仅当排名种类为 n,即没有并列排名时,排序完成。设本轮区间长度为 w,对于一轮操作:

  1. 计算每个区间按后半段 w2 长度字符排序的结果:(nw,n] 开头的区间后半段均为空,直接放在序列首端;接着按照上一轮 sa 结果,把能够作为后半段的元素依次放入。
  2. 依照上一轮的 rk 作为前半段排名,进行双关键字桶排。
  3. 依照 sa 和第二关键字(处理并列),求出 rk
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
    std::vector<int> c(m + 1);
    for (int i = 1; i <= n; ++i)
        ++c[rk[i]];
    std::partial_sum(c.begin(), c.end(), c.begin());
    for (int i = n; i; --i)
        sa[c[rk[i]]--] = i;
} 
for (int w = 1, p; ; w <<= 1, m = p) {
    std::vector<int> id(1);
    for (int i = n - w + 1; i <= n; ++i)
        id.push_back(i);
    for (int i = 1; i <= n; ++i)
        if (sa[i] > w)
            id.push_back(sa[i] - w);
    std::vector<int> c(m + 1);
    for (int i = 1; i <= n; ++i)
        ++c[rk[i]];
    std::partial_sum(c.begin(), c.end(), c.begin());
    for (int i = n; i; --i)
        sa[c[rk[id[i]]]--] = id[i];
    p = 0;
    std::copy(rk.begin(), rk.end(), la.begin());
    for (int i = 1; i <= n; ++i)
        if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
            rk[sa[i]] = p;
        else
            rk[sa[i]] = ++p;
    if (p == n)
        break;
}
查看代码

#纯 SA 的应用

#最小表示法

模板:https://www.luogu.com.cn/problem/P1368

对于循环位移相关要求,首先考虑将字符串重复一遍。

ss 中找到排名第一个 sain 即为答案。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n;
    std::cin >> n;
    std::vector<int> s(2 * n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> s[i], s[n + i] = s[i];
    std::vector<int> sa(2 * n + 1), rk(s);
    {
        int m = 29;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= 2 * n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = 2 * n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = 2 * n - w + 1; i <= 2 * n; ++i)
                id.push_back(i);
            for (int i = 1; i <= 2 * n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= 2 * n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = 2 * n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            auto la(rk);
            p = 0;
            for (int i = 1; i <= 2 * n; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == 2 * n)
                break;
        }
    }
    for (int i = 1; i <= 2 * n; ++i)
        if (sa[i] <= n) {
            for (int j = sa[i]; j < n + sa[i]; ++j)
                std::cout << s[j] << ' ';
            std::cout << '\n';
            break;
        }
    return 0;
}
查看代码

#字符串匹配

二分,复杂度 O(|S|log|T|)。求出现次数则二分左右边界。

太麻烦了且没有实际应用价值,代码略。


#height 数组

定义 hi=lcp(sai,sai1),特别地,h1=0

有引理:hrkihrki11

假设已经求出 hrki1,那么可以从 hrki11 出发暴力看下一个字符是否相等得到答案。那么我们会发现从前往后 h 值每次最多 1,所以复杂度摊下来是 O(n) 的。

记住记住一定是 rki1 而不是下意识的 rki1!!!所以为了保证求解顺序循环枚举的一定是下标而非排名。但是注意定义却是和 rki1 的 lcp!!!所以求 height 的写法是相对固定的,不能觉得好像是对的就随便乱改。


#height 数组的应用

相当于背板子,因为应用太多且形式大多固定。

#求任意两个后缀的 lcp

易得 lcp(sai,saj)=min{hi+1,,hj}故应将一些复杂 lcp 问题的解决方式和 RMQ 联系起来


#子串大小关系

即比较 Sl1,r1Sl2,r2 的大小关系。比较导致 lcp 不能继续延伸的元素大小即可。


#本质不同子串数量

子串等价于「后缀的前缀」。按顺序枚举每个后缀,减去和已枚举的所有后缀的 lcp 即可。鉴于 min{hj+1,,hi} 单调不减,直接减去 hi 即可。

最后答案即为 n(n1)2ni=2hi


#至少出现 k 次子串的最大长度

模板:https://www.luogu.com.cn/problem/P2852

出现 k 在后缀数组中连续出现 k 是任意连续 k1h 的最小值,需要最大化该最小值,考虑滑动窗口。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen("P2852_7.in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, k;
    std::cin >> n >> k, --k;
    std::vector<int> s(n + 1);
    for (int i = 1; i <= n; ++i)
        std::cin >> s[i];
    std::vector<int> sa(n + 1), rk(s), h(n + 1);
    {
        int m = 1000001;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            auto la(rk);
            p = 0;
            for (int i = 1; i <= n; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
    }
    std::vector<int> q(n + 1);
    int res = 0;
    for (int i = 1, l = 1, r = 0; i <= n; ++i) {
        // printf("%d\n", h[i]);
        for (; l <= r && i - q[l] >= k; ++l);
        for (; l <= r && h[i] <= h[q[r]]; --r);
        q[++r] = i;
        if (i >= k)
            res = std::max(res, h[q[l]]);
    }
    std::cout << res << '\n';
    return 0;
}
查看代码

#最长不重叠多次出现子串

bb:定式太多太杂以至于让人怀疑某些定式是否存在应用场景

发现满足单调性,二分子串长度 len,那么显然 lcplen;将 h 划分为连续 len 的段,在每段内找到下标极差与 len 比较即可。

也可以用于判定是否存在不重叠多次出现子串。

甚至可以考虑限制至少出现次数为 k,那大概多个 log,看看一段里有没有 k 个相互相差 len 的。排序贪心求解。

那么上面的至少出现 k 次子串也可以用这个方法来解,但是多个 log 没必要。

也可以限制多次出现但长度至少为 len,那甚至少了二分的 log,直接跑一遍 check 即可。

???到底为什么会有这么多奇怪的定式,是因为真的有题这么出吗???


#最长公共子串问题

ST 的最长公共子串(注意不是 LCS)。设 S 长为 nT 长为 m,那么将 ST 拼接,答案就是 max{lcp(i,j)},in<j

但这里不直接枚举 ij,还是照例先从 h 下手再卡条件,若 sai1n<sai(或者反过来),就可以用 hi 更新答案。容易证明这样总可以找到最大值。

eg1. 找相同字符

https://www.luogu.com.cn/problem/P3181

要求方案数,那么答案为 lcp(i,j),in<j。(我已经帮你们试过了容斥比直接做更麻烦),考虑用单调栈维护左 / 右侧区间 lcp 求解右 / 左侧答案。关于单调栈的描述可见 本页后部内容

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, n1;
    std::string s, t;
    std::cin >> s >> t;
    n = (int)s.length(), n1 = n + (int)t.length() + 1;
    s = "#" + s + "$" + t;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    std::vector<std::pair<int, long long> > q1, q2;
    std::vector<int> tot1(n1 + 1), tot2(n1 + 1);
    for (int i = 1; i <= n1; ++i) {
        tot1[i] = tot1[i - 1] + (sa[i] <= n);
        tot2[i] = tot2[i - 1] + (sa[i] > n + 1);
    }
    long long res = 0ll;
    q1.emplace_back(1, 0ll), q2.emplace_back(1, 0ll);
    for (int i = 1; i <= n1; ++i) {
        for (; !q1.empty() && h[i] < h[q1.back().first]; q1.pop_back());
        q1.emplace_back(i, (tot1[i - 1] - tot1[q1.back().first - 1]) * h[i] + q1.back().second);
        if (sa[i] > n + 1)
            res += q1.back().second;
        for (; !q2.empty() && h[i] < h[q2.back().first]; q2.pop_back());
        q2.emplace_back(i, (tot2[i - 1] - tot2[q2.back().first - 1]) * h[i] + q2.back().second);
        if (sa[i] <= n)
            res += q2.back().second;
    }
    std::cout << res << '\n';
    return 0;
}
查看代码

eg2. 公共串

https://www.luogu.com.cn/problem/P5546

要求多串最长公共子串,仍然考虑将多个串拼在一起。仿照前面二分的方式处理,问题转化为找到最长的 len,使得存在一段最小值 len 的区间,其覆盖了 n 段串。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, l = 0, r = 0;
    std::cin >> n;
    std::string s;
    std::vector<std::pair<int, int> > lim(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::string t;
        std::cin >> t;
        lim[i] = { (int)s.length() + 1, s.length() + t.length() };
        s += "#" + t;
        r = std::max(r, (int)t.length());
        // printf("[%d, %d]\n", lim[i].first, lim[i].second);
    }
    int n1 = lim.back().second;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    // for (int i = 1; i <= n1; ++i)
    //     printf("h[%d] = %d\n", sa[i], h[i]);
    int res = 0;
    auto check = [&](int len) {
        // printf("check %d: \n", len);
        std::vector<int> cnt(n + 1);
        for (int i = 1; i <= n1; ++i) {
            if (h[i] < len) {
                if (*std::min_element(cnt.begin() + 1, cnt.end()))
                    return 1;
                cnt.assign(n + 1, 0);
            }
            else
                for (int j = 1; j <= n; ++j) {
                    if (lim[j].first <= sa[i - 1] && sa[i - 1] <= lim[j].second)
                        cnt[j] = 1;
                    if (lim[j].first <= sa[i] && sa[i] <= lim[j].second)
                        cnt[j] = 1;
                }
        }
        // printf("\n%d\n", *std::min_element(cnt.begin() + 1, cnt.end()));
        return *std::min_element(cnt.begin() + 1, cnt.end());
    };
    for (int mid; l <= r; ) {
        mid = (l + r) >> 1;
        if (check(mid))
            l = mid + 1, res = mid;
        else
            r = mid - 1;
    }
    std::cout << res << '\n';
    return 0;
}
查看代码

但是看了题解发现居然还有线性做法(当然不看建 SA 的 log),对于覆盖全部 n 段串找区间最小值,发现需要最小化区间,考虑双指针。

区间最小值用单调队列求解,细想可能会觉得不太对劲,但是容易证明答案不大于队首且不小于最大队首,所以最大队首就是答案。

#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int n, l = 0, r = 0;
    std::cin >> n;
    std::string s;
    std::vector<std::pair<int, int> > lim(n + 1);
    for (int i = 1; i <= n; ++i) {
        std::string t;
        std::cin >> t;
        lim[i] = { (int)s.length() + 1, s.length() + t.length() };
        s += (char)('A' + i - 1) + t;
        r = std::max(r, (int)t.length());
    }
    int n1 = lim.back().second;
    std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
    {
        std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(w + 1);
            std::iota(id.begin() + 1, id.end(), n1 - w + 1);
            for (int i = 1; i <= n1; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n1; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n1; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            auto la(rk);
            for (int i = 1; i <= n1; ++i)
                if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n1)
                break;
        }
        for (int i = 1, to = 0; i <= n1; ++i) {
            for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
            h[rk[i]] = to;
        }
    }
    int res = 0;
    std::vector<int> q(n1 + 1), cnt(n + 1);
    // for (int i = 1; i <= n1; ++i)
    //     printf("%d: %d\n", sa[i], h[i]);
    for (int l = 1, r = 0, ql = 1, qr = 0; l <= n1; ++l) {
        for (; r < n1 && !*std::min_element(cnt.begin() + 1, cnt.end()); ) {
            ++r;
            for (int i = 1; i <= n; ++i)
                if (lim[i].first <= sa[r] && sa[r] <= lim[i].second) {
                    ++cnt[i];
                    break;
                }
            for (; ql <= qr && h[r] <= h[q[qr]]; --qr);
            q[++qr] = r;
        }
        if (*std::min_element(cnt.begin() + 1, cnt.end())) {
            // printf("[%d, %d]: %d\n", l, r, h[q[ql]]);
            res = std::max(res, h[q[ql]]);
        }
        for (; ql <= qr && q[ql] <= l; ++ql);
        if (l != 1) {
            for (int i = 1; i <= n; ++i)
                if (lim[i].first <= sa[l - 1] && sa[l - 1] <= lim[i].second) {
                    --cnt[i];
                    break;
                }
        }
    }
    std::cout << res << '\n';
    return 0;
}
查看代码

#AA 式子串处理

即对于连续相同子串问题的处理,有一个定的思路,由例题分析。

eg1. 优秀的拆分

https://www.luogu.com.cn/problem/P1117

还是从中间分开,按前后分别处理。这里有个 trick,我们枚举 B 的长度 len,在 S 中每隔 len 打一个标记。那么显然,任意一个长度为 2×len 的子串都会经过恰好两个标记(充分的),这样就可以筛选出所有可能的串。

我们枚举所有连续两个标记(总复杂度为调和级数),求它们对应后缀的 lcp 和对应前缀的 lcs(翻转求 SA 即可),如果二者加起来 len 就说明存在这样的 AA。在 lcs+lcp 中任取 len 长度即为一对 AA。用差分给可能的起点和终点区间加即可。

小细节:lcp 和 lcs 均需要对 lenmin,否则取到的串可能不会经过当前选中的两个标记。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int T;
    for (std::cin >> T; T--; ) {
        std::string s;
        std::cin >> s;
        int n = (int)s.length();
        s = "#" + s;
        SA p(n, s);
        std::reverse(s.begin() + 1, s.end());
        SA q(n, s);
        std::vector<int> f(n + 2), g(n + 2);
        for (int len = 1; len <= n / 2; ++len)
            for (int i = len; i + len <= n; i += len) {
                int l = i, r = i + len, lcp = std::min(len, p.lcp(l, r)), lcs = std::min(len - 1, q.lcp(n - l + 2, n - r + 2));
                if (lcp + lcs >= len) {
                    int t = lcp + lcs - len + 1;
                    // fprintf(stderr, "(%d, %d), %d, %d\n", l, r, lcp, lcs);
                    ++g[l - lcs], --g[l - lcs + t], ++f[r + lcp - t], --f[r + lcp];
                }
            }
        std::partial_sum(f.begin(), f.end(), f.begin());
        std::partial_sum(g.begin(), g.end(), g.begin());
        long long res = 0ll;
        for (int i = 1; i < n; ++i)
            res += (long long)f[i] * g[i + 1];
        std::cout << res << '\n';
    }
    return 0;
}
查看代码

eg2. tandem

https://www.codechef.com/problems/TANDEM

注意到多了一个限制,前一个好处理,找到经过 3 个标记的串即可。对于后一个限制,画图可以发现对于 interesting ones,每次只会出现最多一个;当且仅当 lcp>len 时不存在。

对于 uninteresting ones,用每次能提供的总数减去 interesting ones 的数量即可。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    std::string s;
    std::cin >> s;
    int n = (int)s.length();
    s = "#" + s;
    SA p(n, s);
    std::reverse(s.begin() + 1, s.end());
    SA q(n, s);
    std::reverse(s.begin() + 1, s.end());
    long long res1 = 0ll, res2 = 0ll;
    for (int len = 1; len <= n; ++len)
        for (int i = len, j = 2 * len, k = 3 * len; k <= n; i += len, j += len, k += len) {
            int lcp = std::min(p.lcp(i, j), p.lcp(j, k)), lcs = std::min({ len - 1, q.lcp(n - i + 2, n - j + 2), q.lcp(n - j + 2, n - k + 2) });
            if (std::min(len, lcp) + lcs >= len) {
                // printf("(%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
                int t = (lcp <= len);
                res1 += t, res2 += std::min(len, lcp) + lcs - len + 1 - t;
            }
            // else
            //     printf("# (%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
        }
    std::cout << res1 << ' ' << res2 << '\n';
    return 0;
}
查看代码

eg3. repeats

https://www.spoj.com/problems/REPEATS/

重复次数最多,只需经过标记点最多。显然经过标记点的数量就是该字符串长除以 len 向下取整就可以得到重复次数减 1 的值。

选择两个连续标记点,对于 lcp 和 lcs(显然此时不需要对 lenmin),计算 lcp+lcslen+1 取最大即可。

#include <bits/stdc++.h>
class SA {
public:
    std::vector<int> sa, rk, h;
    std::vector<std::vector<int>  > st;
    SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
        std::vector<int> la(n + 2);
        std::copy(s.begin(), s.end(), rk.begin());
        int m = 128;
        {
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[i]]--] = i;
        }
        for (int w = 1, p; ; w <<= 1, m = p) {
            std::vector<int> id(1);
            for (int i = n - w + 1; i <= n; ++i)
                id.push_back(i);
            for (int i = 1; i <= n; ++i)
                if (sa[i] > w)
                    id.push_back(sa[i] - w);
            std::vector<int> c(m + 1);
            for (int i = 1; i <= n; ++i)
                ++c[rk[i]];
            std::partial_sum(c.begin(), c.end(), c.begin());
            for (int i = n; i; --i)
                sa[c[rk[id[i]]]--] = id[i];
            p = 0;
            std::copy(rk.begin(), rk.end(), la.begin());
            for (int i = 1; i <= n; ++i)
                if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
                    rk[sa[i]] = p;
                else
                    rk[sa[i]] = ++p;
            if (p == n)
                break;
        }
        for (int i = 1, to = 0; i <= n; ++i)
            if (rk[i]) {
                for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
                h[rk[i]] = to;
            }
        for (int i = 1; i <= n; ++i)
            st[0][i] = h[i];
        for (int j = 1; (1 << j) <= n; ++j)
            for (int i = 1; i + (1 << j) - 1 <= n; ++i)
                st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
        rk.emplace_back();
        return;
    }
private:
    int ask(int l, int r) {
        // fprintf(stderr, "l = %d, r = %d\n", l, r);
        int k = std::__lg(r - l + 1);
        return std::min(st[k][l], st[k][r - (1 << k) + 1]);
    }
public:
    int lcp(int l, int r) {
        return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
    }
};
int main() {
#ifdef ONLINE_JUDGE
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
    std::freopen(".in", "r", stdin);
    std::freopen(".out", "w", stdout);
#endif
    int T;
    for (std::cin >> T; T--; ) {
        int n;
        std::cin >> n;
        std::string s = "#";
        for (int i = 1; i <= n; ++i) {
            char t;
            std::cin >> t;
            s.push_back(t);
        }
        SA p(n, s);
        std::reverse(s.begin() + 1, s.end());
        SA q(n, s);
        int res = 0;
        for (int len = 1; len <= n; ++len)
            for (int i = len, j = 2 * len; j <= n; i += len, j += len) {
                int lcp = p.lcp(i, j), lcs = q.lcp(n - i + 2, n - j + 2);
                if (lcp + lcs >= len)
                    res = std::max(res, (lcp + lcs) / len + 1);
            }
        std::cout << res << '\n';
    }
    return 0;
}
查看代码

#结合并查集


#结合单调栈


言论

春有百花秋有月,夏有凉风冬有雪。
来自「颂古五十五首其一」
来发评论吧~
Powered By Valine
v1.5.2