题意
给了 $n$ 个串,要求从任意两个串中拿出一对前缀拼在一起的不同串数,可以同时选一个。
其中 $1 \le n \le 10^4$,串长至多为 $30$。
分析
感觉上是 Trie 树上结点数的平方,实际上有一堆重复。
重复的本质是一个 Trie 树结点的祖先,祖先到当前结点的路径是整个 Tire 树的某一个前缀。
构建出 $fail$ 树,$dfs$ 找到子树大小。
去重时,枚举每个结点,使用 $fail$ 指针向上跳到那个最远祖先,这个祖先内的所有串不包括自身,均重复了一次。
代码
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 
 | #include <iostream>#include <cstdio>
 #include <cstring>
 #include <cmath>
 #include <algorithm>
 #include <vector>
 #include <utility>
 #include <queue>
 #define ms(a,b) memset(a,b,sizeof(a))
 using namespace std;
 typedef long long ll;
 typedef pair<int,int> PII;
 const int mod = 998244353;
 const int inf = 1 << 30;
 const int maxn = 100000 + 5;
 
 struct ACAM {
 static const int maxp = 500000 + 5;
 int ch[maxp][26], val[maxp], fail[maxp], fa[maxp], sz;
 void clear() {
 for (int i = 0; i <= sz; i++) {
 ms(ch[i], 0); fail[i] = fa[i] = 0;
 }
 sz = 1;
 for (int i = 0; i < 26; i++) ch[0][i] = 1;
 }
 ACAM() {
 clear();
 }
 void insert(char* s, int x) {
 int len = strlen(s), now = 1;
 for (int i = 0; i < len; i++) {
 int v = s[i] - 'a';
 if (!ch[now][v]) ch[now][v] = ++sz, fa[sz] =  now;
 now = ch[now][v];
 }
 }
 int siz[maxp], deg[maxp];
 vector<int> edge[maxp];
 void dfs(int u) {
 siz[u] = 1;
 for (int& v: edge[u]) {
 dfs(v); siz[u] += siz[v];
 }
 }
 void build() {
 queue<int> q; q.push(1);
 while (!q.empty()) {
 int t = q.front(); q.pop();
 for (int i = 0; i < 26; i++) {
 if (ch[t][i]) {
 fail[ch[t][i]] = ch[fail[t]][i];
 q.push(ch[t][i]);
 } else {
 ch[t][i] = ch[fail[t]][i];
 }
 }
 }
 for (int i = 0; i <= sz; i++) edge[i].clear();
 for (int i = 2; i <= sz; i++) edge[fail[i]].push_back(i);
 dfs(1);
 }
 ll cal() {
 ll ans = 1ll * (sz - 1) * (sz - 1);
 for (int i = 1; i <= sz; i++) {
 if (fail[i] <= 1) continue;
 int x = i, y = fail[i];
 while (y > 1) y = fa[y], x = fa[x];
 ans -= siz[x] - 1;
 }
 return ans;
 }
 } f;
 
 int n;
 char s[50];
 
 int main() {
 while (scanf("%d", &n) == 1 && n) {
 f.clear();
 for (int i = 1; i <= n; i++) {
 scanf("%s", s); f.insert(s, i);
 }
 f.build();
 printf("%lld\n", f.cal());
 }
 return 0;
 }
 
 |