非常后知后觉地意识到 SA(Suffix Array) 和 SAM(Suffix Automaton) 的 A 不是同一个 A
#定义
显而易见一个长度为 n 的字符串中有 n 个长度分别为 1∼n 的后缀,如果我们对其按字典序排序,分别存储下排名 i 的后缀 sai 和每个后缀 i 的排名 rki。虽然看着挺没头没尾的,但是很有用。
#求解
#哈希 + 排序
直接把所有后缀拿来排序的话,字符串比较是 O(n) 的。如果我们用哈希 + 二分优化比较过程,就可以把整个排序优化到 O(nlog2n)。
#倍增
先对所有后缀按 第一个字符 排序,记排序后排名序列为 a。
那么怎么按 前两个字符 排序呢?对于第 i 组字符,我们用 (ai,ai+1) 双关键字排序即可。记此时排名序列为 b,那么如果需要按照前四个字符排序,用 (bi,bi+2) 进行双关键字排序即可。总共需要进行 logn 次排序。复杂度为 O(nlog2n)。
此时我们注意到排名数组的值域为 n,那么我们用桶排就能少一个 log。
#实现
哈希很好实现,这里就按下不表,主要讲解倍增法的实现。
描述起来很简单,实现起来很要命。OI wiki 上的实现算是相对好理解的:
首先了解双关键字桶排的方法,首先用单关键字桶排完成对 第二关键字 的排序;对于第一关键字,令桶 i 记录前 i 个元素的数量;遍历排序后的第二关键字数组,将元素放到桶中记录数值对应的下标中,并将桶中数值 −1。实际上桶 c 充当计算下标范围的作用,(ci−1,ci] 即为 i 分布的范围。
显然,当且仅当排名种类为 n,即没有并列排名时,排序完成。设本轮区间长度为 w,对于一轮操作:
- 计算每个区间按后半段 w2 长度字符排序的结果:(n−w,n] 开头的区间后半段均为空,直接放在序列首端;接着按照上一轮 sa 结果,把能够作为后半段的元素依次放入。
- 依照上一轮的 rk 作为前半段排名,进行双关键字桶排。
- 依照 sa 和第二关键字(处理并列),求出 rk。
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = n - w + 1; i <= n; ++i)
id.push_back(i);
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
std::copy(rk.begin(), rk.end(), la.begin());
for (int i = 1; i <= n; ++i)
if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n)
break;
}
查看代码
#纯 SA 的应用
#最小表示法
模板:https://www.luogu.com.cn/problem/P1368。
对于循环位移相关要求,首先考虑将字符串重复一遍。
在 ss 中找到排名第一个 sai≤n 即为答案。
#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n;
std::cin >> n;
std::vector<int> s(2 * n + 1);
for (int i = 1; i <= n; ++i)
std::cin >> s[i], s[n + i] = s[i];
std::vector<int> sa(2 * n + 1), rk(s);
{
int m = 29;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= 2 * n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = 2 * n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = 2 * n - w + 1; i <= 2 * n; ++i)
id.push_back(i);
for (int i = 1; i <= 2 * n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= 2 * n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = 2 * n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
auto la(rk);
p = 0;
for (int i = 1; i <= 2 * n; ++i)
if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == 2 * n)
break;
}
}
for (int i = 1; i <= 2 * n; ++i)
if (sa[i] <= n) {
for (int j = sa[i]; j < n + sa[i]; ++j)
std::cout << s[j] << ' ';
std::cout << '\n';
break;
}
return 0;
}
查看代码
#字符串匹配
二分,复杂度 O(|S|log|T|)。求出现次数则二分左右边界。
太麻烦了且没有实际应用价值,代码略。
#height 数组
定义 hi=lcp(sai,sai−1),特别地,h1=0。
有引理:hrki≥hrki−1−1。
假设已经求出 hrki−1,那么可以从 hrki−1−1 出发暴力看下一个字符是否相等得到答案。那么我们会发现从前往后 h 值每次最多 −1,所以复杂度摊下来是 O(n) 的。
记住记住一定是 rki−1 而不是下意识的 rki−1!!!所以为了保证求解顺序循环枚举的一定是下标而非排名。但是注意定义却是和 rki−1 的 lcp!!!所以求 height 的写法是相对固定的,不能觉得好像是对的就随便乱改。
#height 数组的应用
相当于背板子,因为应用太多且形式大多固定。
#求任意两个后缀的 lcp
易得 lcp(sai,saj)=min{hi+1,⋯,hj}。故应将一些复杂 lcp 问题的解决方式和 RMQ 联系起来。
#子串大小关系
即比较 Sl1,r1 和 Sl2,r2 的大小关系。比较导致 lcp 不能继续延伸的元素大小即可。
#本质不同子串数量
子串等价于「后缀的前缀」。按顺序枚举每个后缀,减去和已枚举的所有后缀的 lcp 即可。鉴于 min{hj+1,⋯,hi} 单调不减,直接减去 hi 即可。
最后答案即为 n(n−1)2−n∑i=2hi。
#至少出现 k 次子串的最大长度
模板:https://www.luogu.com.cn/problem/P2852。
出现 k 次 ⟺ 在后缀数组中连续出现 k 次 ⟺ 是任意连续 k−1 个 h 的最小值,需要最大化该最小值,考虑滑动窗口。
#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen("P2852_7.in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, k;
std::cin >> n >> k, --k;
std::vector<int> s(n + 1);
for (int i = 1; i <= n; ++i)
std::cin >> s[i];
std::vector<int> sa(n + 1), rk(s), h(n + 1);
{
int m = 1000001;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = n - w + 1; i <= n; ++i)
id.push_back(i);
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
auto la(rk);
p = 0;
for (int i = 1; i <= n; ++i)
if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n)
break;
}
for (int i = 1, to = 0; i <= n; ++i)
if (rk[i]) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
}
std::vector<int> q(n + 1);
int res = 0;
for (int i = 1, l = 1, r = 0; i <= n; ++i) {
// printf("%d\n", h[i]);
for (; l <= r && i - q[l] >= k; ++l);
for (; l <= r && h[i] <= h[q[r]]; --r);
q[++r] = i;
if (i >= k)
res = std::max(res, h[q[l]]);
}
std::cout << res << '\n';
return 0;
}
查看代码
#最长不重叠多次出现子串
bb:定式太多太杂以至于让人怀疑某些定式是否存在应用场景
发现满足单调性,二分子串长度 len,那么显然 lcp≥len;将 h 划分为连续 ≥len 的段,在每段内找到下标极差与 len 比较即可。
也可以用于判定是否存在不重叠多次出现子串。
甚至可以考虑限制至少出现次数为 k,那大概多个 log,看看一段里有没有 ≥k 个相互相差 ≥len 的。排序贪心求解。
那么上面的至少出现 k 次子串也可以用这个方法来解,但是多个 log 没必要。
也可以限制多次出现但长度至少为 len,那甚至少了二分的 log,直接跑一遍 check 即可。
???到底为什么会有这么多奇怪的定式,是因为真的有题这么出吗???
#最长公共子串问题
求 S 和 T 的最长公共子串(注意不是 LCS)。设 S 长为 n,T 长为 m,那么将 S 与 T 拼接,答案就是 max{lcp(i,j)},i≤n<j。
但这里不直接枚举 i 和 j,还是照例先从 h 下手再卡条件,若 sai−1≤n<sai(或者反过来),就可以用 hi 更新答案。容易证明这样总可以找到最大值。
eg1. 找相同字符
https://www.luogu.com.cn/problem/P3181
要求方案数,那么答案为 lcp(i,j),i≤n<j。(我已经帮你们试过了容斥比直接做更麻烦),考虑用单调栈维护左 / 右侧区间 lcp 求解右 / 左侧答案。关于单调栈的描述可见 本页后部内容。
#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, n1;
std::string s, t;
std::cin >> s >> t;
n = (int)s.length(), n1 = n + (int)t.length() + 1;
s = "#" + s + "$" + t;
std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
{
std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(w + 1);
std::iota(id.begin() + 1, id.end(), n1 - w + 1);
for (int i = 1; i <= n1; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
auto la(rk);
for (int i = 1; i <= n1; ++i)
if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n1)
break;
}
for (int i = 1, to = 0; i <= n1; ++i) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
}
std::vector<std::pair<int, long long> > q1, q2;
std::vector<int> tot1(n1 + 1), tot2(n1 + 1);
for (int i = 1; i <= n1; ++i) {
tot1[i] = tot1[i - 1] + (sa[i] <= n);
tot2[i] = tot2[i - 1] + (sa[i] > n + 1);
}
long long res = 0ll;
q1.emplace_back(1, 0ll), q2.emplace_back(1, 0ll);
for (int i = 1; i <= n1; ++i) {
for (; !q1.empty() && h[i] < h[q1.back().first]; q1.pop_back());
q1.emplace_back(i, (tot1[i - 1] - tot1[q1.back().first - 1]) * h[i] + q1.back().second);
if (sa[i] > n + 1)
res += q1.back().second;
for (; !q2.empty() && h[i] < h[q2.back().first]; q2.pop_back());
q2.emplace_back(i, (tot2[i - 1] - tot2[q2.back().first - 1]) * h[i] + q2.back().second);
if (sa[i] <= n)
res += q2.back().second;
}
std::cout << res << '\n';
return 0;
}
查看代码
eg2. 公共串
https://www.luogu.com.cn/problem/P5546
要求多串最长公共子串,仍然考虑将多个串拼在一起。仿照前面二分的方式处理,问题转化为找到最长的 len,使得存在一段最小值 ≥len 的区间,其覆盖了 n 段串。
#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, l = 0, r = 0;
std::cin >> n;
std::string s;
std::vector<std::pair<int, int> > lim(n + 1);
for (int i = 1; i <= n; ++i) {
std::string t;
std::cin >> t;
lim[i] = { (int)s.length() + 1, s.length() + t.length() };
s += "#" + t;
r = std::max(r, (int)t.length());
// printf("[%d, %d]\n", lim[i].first, lim[i].second);
}
int n1 = lim.back().second;
std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
{
std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(w + 1);
std::iota(id.begin() + 1, id.end(), n1 - w + 1);
for (int i = 1; i <= n1; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
auto la(rk);
for (int i = 1; i <= n1; ++i)
if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n1)
break;
}
for (int i = 1, to = 0; i <= n1; ++i) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
}
// for (int i = 1; i <= n1; ++i)
// printf("h[%d] = %d\n", sa[i], h[i]);
int res = 0;
auto check = [&](int len) {
// printf("check %d: \n", len);
std::vector<int> cnt(n + 1);
for (int i = 1; i <= n1; ++i) {
if (h[i] < len) {
if (*std::min_element(cnt.begin() + 1, cnt.end()))
return 1;
cnt.assign(n + 1, 0);
}
else
for (int j = 1; j <= n; ++j) {
if (lim[j].first <= sa[i - 1] && sa[i - 1] <= lim[j].second)
cnt[j] = 1;
if (lim[j].first <= sa[i] && sa[i] <= lim[j].second)
cnt[j] = 1;
}
}
// printf("\n%d\n", *std::min_element(cnt.begin() + 1, cnt.end()));
return *std::min_element(cnt.begin() + 1, cnt.end());
};
for (int mid; l <= r; ) {
mid = (l + r) >> 1;
if (check(mid))
l = mid + 1, res = mid;
else
r = mid - 1;
}
std::cout << res << '\n';
return 0;
}
查看代码
但是看了题解发现居然还有线性做法(当然不看建 SA 的 log),对于覆盖全部 n 段串找区间最小值,发现需要最小化区间,考虑双指针。
区间最小值用单调队列求解,细想可能会觉得不太对劲,但是容易证明答案不大于队首且不小于最大队首,所以最大队首就是答案。
#include <bits/stdc++.h>
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int n, l = 0, r = 0;
std::cin >> n;
std::string s;
std::vector<std::pair<int, int> > lim(n + 1);
for (int i = 1; i <= n; ++i) {
std::string t;
std::cin >> t;
lim[i] = { (int)s.length() + 1, s.length() + t.length() };
s += (char)('A' + i - 1) + t;
r = std::max(r, (int)t.length());
}
int n1 = lim.back().second;
std::vector<int> sa(n1 + 1), rk(n1 + 1), h(n1 + 1);
{
std::copy(s.begin() + 1, s.end(), rk.begin() + 1);
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(w + 1);
std::iota(id.begin() + 1, id.end(), n1 - w + 1);
for (int i = 1; i <= n1; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n1; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n1; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
auto la(rk);
for (int i = 1; i <= n1; ++i)
if (i != 1 && la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n1)
break;
}
for (int i = 1, to = 0; i <= n1; ++i) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
}
int res = 0;
std::vector<int> q(n1 + 1), cnt(n + 1);
// for (int i = 1; i <= n1; ++i)
// printf("%d: %d\n", sa[i], h[i]);
for (int l = 1, r = 0, ql = 1, qr = 0; l <= n1; ++l) {
for (; r < n1 && !*std::min_element(cnt.begin() + 1, cnt.end()); ) {
++r;
for (int i = 1; i <= n; ++i)
if (lim[i].first <= sa[r] && sa[r] <= lim[i].second) {
++cnt[i];
break;
}
for (; ql <= qr && h[r] <= h[q[qr]]; --qr);
q[++qr] = r;
}
if (*std::min_element(cnt.begin() + 1, cnt.end())) {
// printf("[%d, %d]: %d\n", l, r, h[q[ql]]);
res = std::max(res, h[q[ql]]);
}
for (; ql <= qr && q[ql] <= l; ++ql);
if (l != 1) {
for (int i = 1; i <= n; ++i)
if (lim[i].first <= sa[l - 1] && sa[l - 1] <= lim[i].second) {
--cnt[i];
break;
}
}
}
std::cout << res << '\n';
return 0;
}
查看代码
#AA 式子串处理
即对于连续相同子串问题的处理,有一个定的思路,由例题分析。
eg1. 优秀的拆分
https://www.luogu.com.cn/problem/P1117
还是从中间分开,按前后分别处理。这里有个 trick,我们枚举 B 的长度 len,在 S 中每隔 len 打一个标记。那么显然,任意一个长度为 2×len 的子串都会经过恰好两个标记(充分的),这样就可以筛选出所有可能的串。
我们枚举所有连续两个标记(总复杂度为调和级数),求它们对应后缀的 lcp 和对应前缀的 lcs(翻转求 SA 即可),如果二者加起来 ≥len 就说明存在这样的 AA。在 lcs+lcp 中任取 len 长度即为一对 AA。用差分给可能的起点和终点区间加即可。
小细节:lcp 和 lcs 均需要对 len 取 min,否则取到的串可能不会经过当前选中的两个标记。
#include <bits/stdc++.h>
class SA {
public:
std::vector<int> sa, rk, h;
std::vector<std::vector<int> > st;
SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = n - w + 1; i <= n; ++i)
id.push_back(i);
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
std::copy(rk.begin(), rk.end(), la.begin());
for (int i = 1; i <= n; ++i)
if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n)
break;
}
for (int i = 1, to = 0; i <= n; ++i)
if (rk[i]) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
for (int i = 1; i <= n; ++i)
st[0][i] = h[i];
for (int j = 1; (1 << j) <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
rk.emplace_back();
return;
}
private:
int ask(int l, int r) {
// fprintf(stderr, "l = %d, r = %d\n", l, r);
int k = std::__lg(r - l + 1);
return std::min(st[k][l], st[k][r - (1 << k) + 1]);
}
public:
int lcp(int l, int r) {
return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
}
};
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int T;
for (std::cin >> T; T--; ) {
std::string s;
std::cin >> s;
int n = (int)s.length();
s = "#" + s;
SA p(n, s);
std::reverse(s.begin() + 1, s.end());
SA q(n, s);
std::vector<int> f(n + 2), g(n + 2);
for (int len = 1; len <= n / 2; ++len)
for (int i = len; i + len <= n; i += len) {
int l = i, r = i + len, lcp = std::min(len, p.lcp(l, r)), lcs = std::min(len - 1, q.lcp(n - l + 2, n - r + 2));
if (lcp + lcs >= len) {
int t = lcp + lcs - len + 1;
// fprintf(stderr, "(%d, %d), %d, %d\n", l, r, lcp, lcs);
++g[l - lcs], --g[l - lcs + t], ++f[r + lcp - t], --f[r + lcp];
}
}
std::partial_sum(f.begin(), f.end(), f.begin());
std::partial_sum(g.begin(), g.end(), g.begin());
long long res = 0ll;
for (int i = 1; i < n; ++i)
res += (long long)f[i] * g[i + 1];
std::cout << res << '\n';
}
return 0;
}
查看代码
eg2. tandem
https://www.codechef.com/problems/TANDEM
注意到多了一个限制,前一个好处理,找到经过 3 个标记的串即可。对于后一个限制,画图可以发现对于 interesting ones,每次只会出现最多一个;当且仅当 lcp>len 时不存在。
对于 uninteresting ones,用每次能提供的总数减去 interesting ones 的数量即可。
#include <bits/stdc++.h>
class SA {
public:
std::vector<int> sa, rk, h;
std::vector<std::vector<int> > st;
SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = n - w + 1; i <= n; ++i)
id.push_back(i);
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
std::copy(rk.begin(), rk.end(), la.begin());
for (int i = 1; i <= n; ++i)
if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n)
break;
}
for (int i = 1, to = 0; i <= n; ++i)
if (rk[i]) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
for (int i = 1; i <= n; ++i)
st[0][i] = h[i];
for (int j = 1; (1 << j) <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
rk.emplace_back();
return;
}
private:
int ask(int l, int r) {
// fprintf(stderr, "l = %d, r = %d\n", l, r);
int k = std::__lg(r - l + 1);
return std::min(st[k][l], st[k][r - (1 << k) + 1]);
}
public:
int lcp(int l, int r) {
return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
}
};
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
std::string s;
std::cin >> s;
int n = (int)s.length();
s = "#" + s;
SA p(n, s);
std::reverse(s.begin() + 1, s.end());
SA q(n, s);
std::reverse(s.begin() + 1, s.end());
long long res1 = 0ll, res2 = 0ll;
for (int len = 1; len <= n; ++len)
for (int i = len, j = 2 * len, k = 3 * len; k <= n; i += len, j += len, k += len) {
int lcp = std::min(p.lcp(i, j), p.lcp(j, k)), lcs = std::min({ len - 1, q.lcp(n - i + 2, n - j + 2), q.lcp(n - j + 2, n - k + 2) });
if (std::min(len, lcp) + lcs >= len) {
// printf("(%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
int t = (lcp <= len);
res1 += t, res2 += std::min(len, lcp) + lcs - len + 1 - t;
}
// else
// printf("# (%d, %d, %d), %d, %d, %d\n", i, j, k, lcs, lcp, len);
}
std::cout << res1 << ' ' << res2 << '\n';
return 0;
}
查看代码
eg3. repeats
https://www.spoj.com/problems/REPEATS/
重复次数最多,只需经过标记点最多。显然经过标记点的数量就是该字符串长除以 len 向下取整就可以得到重复次数减 1 的值。
选择两个连续标记点,对于 lcp 和 lcs(显然此时不需要对 len 取 min),计算 lcp+lcslen+1 取最大即可。
#include <bits/stdc++.h>
class SA {
public:
std::vector<int> sa, rk, h;
std::vector<std::vector<int> > st;
SA(int n, std::string s): sa(n + 1), rk(n + 2), h(n + 1), st(20, std::vector<int> (n + 1)) {
std::vector<int> la(n + 2);
std::copy(s.begin(), s.end(), rk.begin());
int m = 128;
{
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[i]]--] = i;
}
for (int w = 1, p; ; w <<= 1, m = p) {
std::vector<int> id(1);
for (int i = n - w + 1; i <= n; ++i)
id.push_back(i);
for (int i = 1; i <= n; ++i)
if (sa[i] > w)
id.push_back(sa[i] - w);
std::vector<int> c(m + 1);
for (int i = 1; i <= n; ++i)
++c[rk[i]];
std::partial_sum(c.begin(), c.end(), c.begin());
for (int i = n; i; --i)
sa[c[rk[id[i]]]--] = id[i];
p = 0;
std::copy(rk.begin(), rk.end(), la.begin());
for (int i = 1; i <= n; ++i)
if (la[sa[i]] == la[sa[i - 1]] && la[sa[i] + w] == la[sa[i - 1] + w])
rk[sa[i]] = p;
else
rk[sa[i]] = ++p;
if (p == n)
break;
}
for (int i = 1, to = 0; i <= n; ++i)
if (rk[i]) {
for (to = std::max(to - 1, 0); s[i + to] == s[sa[rk[i] - 1] + to]; ++to);
h[rk[i]] = to;
}
for (int i = 1; i <= n; ++i)
st[0][i] = h[i];
for (int j = 1; (1 << j) <= n; ++j)
for (int i = 1; i + (1 << j) - 1 <= n; ++i)
st[j][i] = std::min(st[j - 1][i], st[j - 1][i + (1 << (j - 1))]);
rk.emplace_back();
return;
}
private:
int ask(int l, int r) {
// fprintf(stderr, "l = %d, r = %d\n", l, r);
int k = std::__lg(r - l + 1);
return std::min(st[k][l], st[k][r - (1 << k) + 1]);
}
public:
int lcp(int l, int r) {
return ask(std::min(rk[l], rk[r]) + 1, std::max(rk[l], rk[r]));
}
};
int main() {
#ifdef ONLINE_JUDGE
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr), std::cout.tie(nullptr);
#else
std::freopen(".in", "r", stdin);
std::freopen(".out", "w", stdout);
#endif
int T;
for (std::cin >> T; T--; ) {
int n;
std::cin >> n;
std::string s = "#";
for (int i = 1; i <= n; ++i) {
char t;
std::cin >> t;
s.push_back(t);
}
SA p(n, s);
std::reverse(s.begin() + 1, s.end());
SA q(n, s);
int res = 0;
for (int len = 1; len <= n; ++len)
for (int i = len, j = 2 * len; j <= n; i += len, j += len) {
int lcp = p.lcp(i, j), lcs = q.lcp(n - i + 2, n - j + 2);
if (lcp + lcs >= len)
res = std::max(res, (lcp + lcs) / len + 1);
}
std::cout << res << '\n';
}
return 0;
}