Codeforces 1097G Vladislav and a Great Legend 题解

传送门:Hello 2019 G. Vladislav and a Great Legend

题面

给一棵有根树 $T=(V,E)$,$X$ 是 $V$ 的非空子集,记 $f(X)$ 表示 $T$ 中使得点集 $X$ 联通的最小联通子图中的边数,求

$$
\sum_{X\subseteq V,X \neq \varnothing} (f(X))^k
$$

其中,$3 \le |V| \le 10^5,1\le k \le 200$ 。

分析

显然,$k=1$ 直接扫一遍算贡献即可。

考虑 $k>1$ 的情况,由第二类斯特林数展开

$$
(f(X))^k=\sum_{i=0}^{k}S(k,i)i!{f(X) \choose i}
$$

前面一部分可以预处理,也就是要求对于 $i (0\le i \le k)​$

$$
\sum_{X\subseteq V,X \neq \varnothing} {f(X) \choose i}
$$

设 $E’ \subseteq E$ 且 $|E’|=i$,$E_X$ 为 $f(X)$ 对应联通块的边集,上式可以写成

$$
\sum_{X\subseteq V,X \neq \varnothing} \sum_{e_j \in E’,1\le j \le i} [e_1,e_2,\dots,e_i \in E_X]
$$

于是,交换求和顺序

$$
\sum_{e_j \in E’,1\le j \le i} \sum_{X\subseteq V,X \neq \varnothing} [e_1,e_2,\dots,e_i \in E_X]
$$

上式含义是枚举边集 $E$ 的大小为 $i$ 的子集 $E’$,计算 $E’$ 出现在多少个点集 $V$ 的非空子集 $X$ 对应的联通块内。

考虑 $dp(i,j)$ 表示以 $i$ 为根的子树内大小为 $j$ 的边集,在子树内对上式的贡献。

设当前计算的根为 $u$,有三种情况:

  1. $u​$ 到子树的边单独构成边集。
  2. $u$ 到子树的边与这棵子树的边集合并。
  3. $u$ 的子树的边集之间合并,子树的边集包含前两种情况。

对于以 $v$ 为根的子树,更新 $v$ 的 $dp$ 状态,情况一就是 $dp(v,1)+2^{size(v)}-1$,而情况二是用 $dp(v,i)$ 更新 $dp(u,i+1)$ 。得到每一个子树 $v$ 的 $dp$ 状态,做一个卷积即可得到根 $u$ 的 $dp$ 状态。

考虑卷积的过程,如果 $v_1$ 和 $v_2$ 合并,没有考虑到 $u$ 的其它子树对其的贡献。

为了解决这个问题,卷积过程中对于 $0$ 次项,设置其为 $2^{size(v)}$,这样没有参与合并的子树贡献就被计算进去了,注意这个部分的贡献不需要 $-1$。因为只要参与了合并,那么只有端点的子树才是必须非空,路径上的点都是可取可不取。

但是,状态转移的过程中,实际上已经对最终答案产生贡献,因为 $dp(i,j)$ 考虑的点集都是以 $i$ 为 $LCA$ 的。

对答案的贡献也有类似的三种情况。

情况一和情况二只需要考虑 $u$ 外的节点是否被取,更新答案即可。

情况三需要考虑合并后的边集对答案贡献,注意到卷积后实际上有一些是没有合并的,我们需要从 $u$ 的 $dp$ 状态中减去前两种情况,这有卷积结果就只剩下参与过合并的边集了。这部分边集同样是考虑 $u$ 外的节点是否选取。

于是,我们就得到了一个 $ans$ 数组表示大小为 $i$ 的边集 $E’$ 的贡献,使用第二类斯特林数即可得到最终答案。

时间复杂度 $O(nk^2)$。

最后参考树形依赖背包的优化原理,上面的卷积过程,实际上与子树大小有关,用类似的方法,我们最终得到时间复杂度为 $O(nk)​$ 的解法。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
const int mod = 1e9 + 7;
const int maxn = 100000 + 5;

ll two[maxn], S[300][300];
int f[maxn], inv[maxn], finv[maxn];
void init(){
S[0][0] = 1;
for (int i = 1; i <= 200; i++) for (int j = 1; j <= 200; j++) S[i][j] = (S[i - 1][j - 1] + S[i - 1][j] * j % mod) % mod;
inv[1] = 1;
for (int i = 2; i < maxn; i++) inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod;
f[0] = finv[0] = 1;
for (int i = 1; i < maxn; i++) {
f[i] = f[i - 1] * 1ll * i % mod;
finv[i] = finv[i - 1] * 1ll * inv[i] % mod;
}
two[0] = 1; for (int i = 1; i < maxn; i++) two[i] = 2ll * two[i - 1] % mod;
}
inline int C(int n, int m){
if (m < 0 || m > n) return 0;
return f[n] * 1ll * finv[n - m] % mod * finv[m] % mod;
}
inline void add(ll& x, ll y) {
x += y; if (x >= mod) x -= mod;
}

vector<int> edge[maxn];
int n, k, siz[maxn];
vector<ll> dp[maxn];
ll ans[maxn];

vector<ll> mul(vector<ll>& x, vector<ll>& y, int k1, int k2) {
k1 = min(k1, k); k2 = min(k2, k);
vector<ll> ans(k + 1, 0);
for (int i = 0; i <= k1; i++) {
for (int j = 0; j <= k2; j++) {
if (i + j > k) break;
add(ans[i + j], x[i] * y[j] % mod);
}
}
return ans;
}

void dfs(int u, int f) {
siz[u] = 1;
vector<ll> x(k + 1, 0); x[0] = 2;
for (int& v: edge[u]) {
if (v == f) continue;
dfs(v, u);
vector<ll> y = vector<ll>(dp[v]);
y[0] = two[siz[v]];
y[1] = (y[1] + two[siz[v]] - 1 + mod) % mod;
for (int i = 2; i <= k; i++) y[i] = (y[i] + dp[v][i - 1]) % mod;
x = mul(x, y, siz[u], siz[v]);
siz[u] += siz[v];
}
dp[u] = x;
for (int& v: edge[u]) {
if (v == f) continue;
ll t = (two[n - siz[v]] - 1 + mod) % mod;
add(ans[1], t * (two[siz[v]] - 1 + mod) % mod);
for (int i = 1; i < k; i++) {
add(ans[i + 1], t * dp[v][i] % mod);
}
ll rs = two[siz[u] - siz[v]] % mod;
x[1] = (x[1] - (two[siz[v]] - 1) * rs % mod + mod) % mod;
for (int i = 2; i <= k; i++) {
x[i] = (x[i] - dp[v][i - 1] * rs % mod + mod) % mod;
}
for (int i = 1; i <= k; i++) {
x[i] = (x[i] - dp[v][i] * rs % mod + mod) % mod;
}
}
for (int i = 2; i <= k; i++) {
add(ans[i], x[i] * two[n - siz[u]] % mod);
}
}

int main(){
init();
scanf("%d%d", &n, &k);
for (int i = 2, u, v; i <= n; i++) {
scanf("%d%d", &u, &v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1, 0);
ll r = 0;
for (int i = 0; i <= k; i++) {
r = (r + S[k][i] * f[i] % mod * ans[i]) % mod;
}
cout << r << endl;
return 0;
}