题意
给了 $n$ 个串,要求从任意两个串中拿出一对前缀拼在一起的不同串数,可以同时选一个。
其中 $1 \le n \le 10^4$,串长至多为 $30$。
分析
感觉上是 Trie 树上结点数的平方,实际上有一堆重复。
重复的本质是一个 Trie 树结点的祖先,祖先到当前结点的路径是整个 Tire 树的某一个前缀。
构建出 $fail$ 树,$dfs$ 找到子树大小。
去重时,枚举每个结点,使用 $fail$ 指针向上跳到那个最远祖先,这个祖先内的所有串不包括自身,均重复了一次。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
| #include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <algorithm> #include <vector> #include <utility> #include <queue> #define ms(a,b) memset(a,b,sizeof(a)) using namespace std; typedef long long ll; typedef pair<int,int> PII; const int mod = 998244353; const int inf = 1 << 30; const int maxn = 100000 + 5;
struct ACAM { static const int maxp = 500000 + 5; int ch[maxp][26], val[maxp], fail[maxp], fa[maxp], sz; void clear() { for (int i = 0; i <= sz; i++) { ms(ch[i], 0); fail[i] = fa[i] = 0; } sz = 1; for (int i = 0; i < 26; i++) ch[0][i] = 1; } ACAM() { clear(); } void insert(char* s, int x) { int len = strlen(s), now = 1; for (int i = 0; i < len; i++) { int v = s[i] - 'a'; if (!ch[now][v]) ch[now][v] = ++sz, fa[sz] = now; now = ch[now][v]; } } int siz[maxp], deg[maxp]; vector<int> edge[maxp]; void dfs(int u) { siz[u] = 1; for (int& v: edge[u]) { dfs(v); siz[u] += siz[v]; } } void build() { queue<int> q; q.push(1); while (!q.empty()) { int t = q.front(); q.pop(); for (int i = 0; i < 26; i++) { if (ch[t][i]) { fail[ch[t][i]] = ch[fail[t]][i]; q.push(ch[t][i]); } else { ch[t][i] = ch[fail[t]][i]; } } } for (int i = 0; i <= sz; i++) edge[i].clear(); for (int i = 2; i <= sz; i++) edge[fail[i]].push_back(i); dfs(1); } ll cal() { ll ans = 1ll * (sz - 1) * (sz - 1); for (int i = 1; i <= sz; i++) { if (fail[i] <= 1) continue; int x = i, y = fail[i]; while (y > 1) y = fa[y], x = fa[x]; ans -= siz[x] - 1; } return ans; } } f;
int n; char s[50];
int main() { while (scanf("%d", &n) == 1 && n) { f.clear(); for (int i = 1; i <= n; i++) { scanf("%s", s); f.insert(s, i); } f.build(); printf("%lld\n", f.cal()); } return 0; }
|