HDu4416 Good Article Good sentence

后缀自动机求母串不包含给定串子串的不同子串数量

首先,需要知道求两个串最长公共子串,即维护第二个串的每个前缀与第一个串的最长公共后缀。

对母串建一个 $sam$,所有串在 $sam$ 上面跑,并维护每个 $sam$ 上每个状态与所有的最长公共后缀长度,因此每个状态对答案的贡献为 $len(s) - dep(s)$。

可以理解原来状态的对子串数量的贡献为 $len(s) - len(link(s))$,$s$ 代表的子串长度区间中每个后缀都对答案有贡献,现在我们将这个区间拆成两块更新答案。

具体来讲,在获取所有点的 $dep$ 值后,基数排序获得 $sam$ 的状态的逆序拓扑排序。

因为需要对于每个状态更新他父亲节点的 $dep$ 值,所以遍历逆序拓扑排序。

这里需要同时从自动机和后缀树两个角度同时思考,当前节点的父亲也是当前节点的后缀,但是计算 $dep$ 时是自动机视角,所以当前节点和他的父亲计算的 $dep$ 没有影响,但实际上当前节点 $dep$ 值对他的父亲是有影响的,例如当前 $dep$ 超过了父亲节点的 $len$,父亲节点对答案是没有贡献的。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
const int maxn = 100000 + 5;

ll ans[maxn << 1];

namespace sam{
int len[maxn << 1], link[maxn << 1], ch[maxn << 1][26], tot, last;
void init(){ ms(ans, -1); ms(ch, 0); tot = last = 1; }
void insert(int c){
int cur = ++tot, p = last;
len[cur] = len[last] + 1;
for (; p && !ch[p][c]; p = link[p]) ch[p][c] = cur;
if (!p) link[cur] = 1;
else {
int q = ch[p][c];
if (len[p] + 1 == len[q]) link[cur] = q;
else {
int nq = ++tot; len[nq] = len[p] + 1;
link[nq] = link[q]; link[q] = link[cur] = nq;
memcpy(ch[nq], ch[q], sizeof ch[q]);
for (; ch[p][c] == q; p = link[p]) ch[p][c] = nq;
}
}
last = cur;
}

int c[maxn << 1], a[maxn << 1];
void rsort(){
for (int i = 0; i <= tot; i++) c[i] = 0;
for (int i = 1; i <= tot; i++) c[len[i]]++;
for (int i = 1; i <= tot; i++) c[i] += c[i - 1];
for (int i = 1; i <= tot; i++) a[c[len[i]]--] = i;
}
}
using namespace sam;

char s[maxn];
int n, dep[maxn << 1];

int main(){
int T, kase = 0; scanf("%d", &T);
while (T--){
scanf("%d%s", &n, s);
init(); ms(dep, 0);
for (int i = 0; s[i]; i++) insert(s[i] - 'a');
while (n--){
scanf("%s", s);
int now = 1, l = 0;
for (int i = 0; s[i]; i++){
int c = s[i] - 'a';
if (ch[now][c]) l++, now = ch[now][c], dep[now] = max(dep[now], l);
else {
while (now && !ch[now][c]) now = link[now];
if (!now) now = 1, l = 0;
else l = len[now] + 1, now = ch[now][c], dep[now] = max(dep[now], l);
}
}
}
rsort(); ll ans = 0;
for (int i = tot; i; i--){
int now = a[i];
if (dep[now]){
dep[link[now]] = max(dep[link[now]], dep[now]);
if (dep[now] < len[now]) ans += len[now] - dep[now];
}
else ans += len[now] - len[link[now]];
}
printf("Case %d: %lld\n", ++kase, ans);
}
return 0;
}