yww 与树上的回文串 题解

题意

给定一棵带边权 ${ 0, 1}$ 的无根树,求回文路径的个数。

其中 $3 \le n \le 5\times 10^4$。

分析

点分治,问题转化为求有多少条过当前重心的路径,使得其构成回文串。

将从重心开始的所有结点对应的串,插入 Trie 树中,最坏情况下 Trie 树与原树同构。

分析跨越重心 $root$ 路径的回文串结构,记两个端点分别为 $u,v$,其中 $u$ 的深度大于等于 $v$ 的深度。

深度相同,则 $root$ 到 $u$ 路径对应串 (简记为 $u$ 串) 必须等于 $root$ 到 $v$ 路径对应串 (简记为 $v$ 串)。令枚举的每个结点的出现次数为 $cnt$,这部分答案贡献为 $cnt \choose 2$。

深度不同,则 $v$ 串等于 $u$ 串的后缀,且 $u$ 串的前缀是一个回文串。

因此,我们考虑对 Trie 树构建 AC 自动机,则 $u$ 串的所有在 Trie 树中有对应结点的后缀就是 fail 树上的所有祖先。注意到,fail 树上的祖先可能很多,无法枚举并检查当前前缀是否为回文串。但是,根据 $Border$ 理论,一个串的回文后缀(前缀)可以表示为 $\log n$ 段等差数列。于是,我们处理所有 Trie 树结点的等差数列 $(l,r,d)$,对应表示首项、末项和公差。这可以在构建完 Trie 树后,dfs Trie 树,并维护正串和反串的哈希值得到。

最后,枚举所有 Trie 树结点,令其深度为 $len$,再枚举其等差数列,只需要询问 fail 树的祖先上长度为等差数列 $(len-r,len-l,d)$ 的出现次数和,进一步转化可以变成求所有长度 $l \equiv ((len-l) \bmod d) ( \bmod d)$ 的点值和,考虑根号分块容易维护,但是可能会取到一些等差数列值域外的点,可以将询问离线的最大长度上,并拆成两个前缀和,在 dfs fail 树的栈上二分即可获得对应结点。

总结

  1. 点分治。

  2. 插入从重心开始的 Trie 树。

  3. dfs Trie 树,求出每个点回文前缀的等差数列。

  4. dfs fail 树,维护栈上的结点,维护栈上支持快速查询模 $b$ 余 $r$ 处点值的分块数据结构。枚举等差数列,对于公差小的,栈上二分找到询问离线点,对于公差大的,暴力枚举。

  5. 枚举 Trie 树结点,计算回文前缀为空和不为空的匹配方案数。

  6. 点分治容斥掉来自同一子树的配对答案数。

时间复杂度:$T(n)=2T({n \over 2})+O(n \sqrt{n}+n\log^2 n)$,$T(n)=O(n \sqrt{n})$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include <queue>
#ifdef XLor
#define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
void err() { std::cout << std::endl; }
template<typename T, typename...Args>
void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
#define dbg(...)
#endif
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
const int mod = 998244353;
const int inf = 1 << 30;
const int maxn = 50000 + 5;
const int seed = 53;
const int p1 = 1e9 + 7, p2 = 1e9 + 9;

int n, vis[maxn], siz[maxn], mn, rt, sum;
ll ans;
vector<PII> edge[maxn];

void getrt(int u, int f) {
siz[u] = 1; int m = 0;
for (auto& x: edge[u]) {
int v = x.first;
if (v == f || vis[v]) continue;
getrt(v, u);
siz[u] += siz[v];
m = max(m, siz[v]);
}
m = max(m, sum - siz[u]);
if (m < mn) {
mn = m; rt = u;
}
}
int getrt(int u) {
sum = siz[u]; mn = 1e9;
getrt(u, 0); return rt;
}

struct Border {
int l, r, d;
};

struct Query {
int b, r, id, sgn;
};

struct DS {
int n, B, a[250][250], b[maxn];
void set(int nn) {
n = nn; B = sqrt(n);
}
void update(int l, int x) {
for (int i = 1; i <= B; i++) {
a[i][l % i] += x;
}
b[l] += x;
}
} g;

namespace acam {
int sz, ch[maxn][26], fail[maxn];
int cnt[maxn], len[maxn], ans[maxn];
vector<Border> bds[maxn];
vector<Query> que[maxn];
vector<int> fG[maxn];
int node(int l) {
ms(ch[++sz], 0);
len[sz] = l;
fail[sz] = cnt[sz] = ans[sz] = 0;
fG[sz].clear(); bds[sz].clear(); que[sz].clear();
return sz;
}
void clear() {
sz = 0; node(0);
for (int i = 0; i < 26; i++) ch[0][i] = 1;
}
void build() {
queue<int> q; q.push(1);
while (!q.empty()) {
int t = q.front(); q.pop();
for (int i = 0; i < 26; i++) {
if (ch[t][i]) {
fail[ch[t][i]] = ch[fail[t]][i];
q.push(ch[t][i]);
} else {
ch[t][i] = ch[fail[t]][i];
}
}
}
for (int i = 2; i <= sz; i++) fG[fail[i]].push_back(i);
}

vector<int> stk;
int get(int x) {
int l = 0, r = (int)stk.size() - 1, ans = -1;
while (l <= r) {
int m = (l + r) / 2;
if (len[stk[m]] <= x) ans = stk[m], l = m + 1;
else r = m - 1;
}
return ans;
}
void dfs(int u) {
stk.push_back(u);
g.update(len[u], cnt[u]);
for (auto& x: bds[u]) {
if (x.d <= g.B) {
int a = get(len[u] - x.r - 1);
int b = get(len[u] - x.l);
if (a > 0) que[a].push_back({ x.d, (len[u] - x.l) % x.d, u, -1 });
if (b > 0) que[b].push_back({ x.d, (len[u] - x.l) % x.d, u, +1 });
} else {
for (int i = x.l; i <= x.r; i += x.d) {
ans[u] += g.b[len[u] - i];
}
}
}
for (int v: fG[u]) dfs(v);
for (auto& q: que[u]) {
ans[q.id] += q.sgn * g.a[q.b][q.r];
}
que[u].clear();
g.update(len[u], -cnt[u]);
stk.pop_back();
}
}
using acam::ch;
using acam::len;
using acam::node;

void predfs(int u, int f, int nd) {
acam::cnt[nd]++;
for (auto& x: edge[u]) {
int v = x.first, w = x.second;
if (v == f || vis[v]) continue;
if (!ch[nd][w]) {
ch[nd][w] = node(len[nd] + 1);
}
predfs(v, u, ch[nd][w]);
}
}
ull xp1[maxn], xp2[maxn];
void dfs(int u, ull h1, ull h2, ull h3, ull h4, vector<Border> vec) {
for (int i = 0; i < 2; i++) {
if (!ch[u][i]) continue;
int v = ch[u][i];
ull nh1 = (h1 + xp1[len[u]] * i) % p1;
ull nh2 = (h2 * seed + i) % p1;
ull nh3 = (h3 + xp2[len[u]] * i) % p2;
ull nh4 = (h4 * seed + i) % p2;
acam::bds[v] = vec;
if (nh1 == nh2 && nh3 == nh4) {
auto& bd = acam::bds[v];
if (bd.empty()) {
bd.push_back({ len[v], len[v], len[v] });
} else {
Border& back = bd.back();
if (len[v] - back.r == back.d) back.r = len[v];
else {
bd.push_back({ len[v], len[v], len[v] - back.r });
}
}
}
dfs(v, nh1, nh2, nh3, nh4, acam::bds[v]);
}
}
ll cal(int u, int f) {
acam::clear();
ll ans = 0;
if (f == -1) {
predfs(u, 0, 1);
} else {
ch[1][f] = node(1);
predfs(u, 0, ch[1][f]);
}
dfs(1, 0, 0, 0, 0, vector<Border>(0));
acam::build();
g.set(siz[u]);
acam::dfs(1);
auto C = [](int n) -> ll {
return 1ll * n * (n - 1) / 2;
};
for (int i = 2; i <= acam::sz; i++) {
ans += 1ll * acam::cnt[i] * acam::ans[i] + C(acam::cnt[i]);
}
return ans;
}
void solve(int u) {
vis[u] = 1;
ans += cal(u, -1);
for (auto& x: edge[u]) {
int v = x.first;
if (vis[v]) continue;
ans -= cal(v, x.second);
solve(getrt(v));
}
}

int main() {
scanf("%d", &n);
xp1[0] = xp2[0] = 1;
for (int i = 1; i < maxn; i++) {
xp1[i] = xp1[i - 1] * seed % p1;
xp2[i] = xp2[i - 1] * seed % p2;
}
for (int i = 1, u, v, w; i < n; i++) {
scanf("%d%d%d", &u, &v, &w);
edge[u].push_back({v, w});
edge[v].push_back({u, w});
}
siz[1] = n;
solve(getrt(1));
cout << ans;
return 0;
}