HDu5129 Yong Zheng's Death 题解

题意

给了 $n$ 个串,要求从任意两个串中拿出一对前缀拼在一起的不同串数,可以同时选一个。

其中 $1 \le n \le 10^4$,串长至多为 $30$。

分析

感觉上是 Trie 树上结点数的平方,实际上有一堆重复。

重复的本质是一个 Trie 树结点的祖先,祖先到当前结点的路径是整个 Tire 树的某一个前缀。

构建出 $fail$ 树,$dfs$ 找到子树大小。

去重时,枚举每个结点,使用 $fail$ 指针向上跳到那个最远祖先,这个祖先内的所有串不包括自身,均重复了一次。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <utility>
#include <queue>
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
const int mod = 998244353;
const int inf = 1 << 30;
const int maxn = 100000 + 5;

struct ACAM {
static const int maxp = 500000 + 5;
int ch[maxp][26], val[maxp], fail[maxp], fa[maxp], sz;
void clear() {
for (int i = 0; i <= sz; i++) {
ms(ch[i], 0); fail[i] = fa[i] = 0;
}
sz = 1;
for (int i = 0; i < 26; i++) ch[0][i] = 1;
}
ACAM() {
clear();
}
void insert(char* s, int x) {
int len = strlen(s), now = 1;
for (int i = 0; i < len; i++) {
int v = s[i] - 'a';
if (!ch[now][v]) ch[now][v] = ++sz, fa[sz] = now;
now = ch[now][v];
}
}
int siz[maxp], deg[maxp];
vector<int> edge[maxp];
void dfs(int u) {
siz[u] = 1;
for (int& v: edge[u]) {
dfs(v); siz[u] += siz[v];
}
}
void build() {
queue<int> q; q.push(1);
while (!q.empty()) {
int t = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (ch[t][i]) {
fail[ch[t][i]] = ch[fail[t]][i];
q.push(ch[t][i]);
} else {
ch[t][i] = ch[fail[t]][i];
}
}
}
for (int i = 0; i <= sz; i++) edge[i].clear();
for (int i = 2; i <= sz; i++) edge[fail[i]].push_back(i);
dfs(1);
}
ll cal() {
ll ans = 1ll * (sz - 1) * (sz - 1);
for (int i = 1; i <= sz; i++) {
if (fail[i] <= 1) continue;
int x = i, y = fail[i];
while (y > 1) y = fa[y], x = fa[x];
ans -= siz[x] - 1;
}
return ans;
}
} f;

int n;
char s[50];

int main() {
while (scanf("%d", &n) == 1 && n) {
f.clear();
for (int i = 1; i <= n; i++) {
scanf("%s", s); f.insert(s, i);
}
f.build();
printf("%lld\n", f.cal());
}
return 0;
}