传送门:Hello 2019 G. Vladislav and a Great Legend
题面
给一棵有根树 $T=(V,E)$,$X$ 是 $V$ 的非空子集,记 $f(X)$ 表示 $T$ 中使得点集 $X$ 联通的最小联通子图中的边数,求
$$
\sum_{X\subseteq V,X \neq \varnothing} (f(X))^k
$$
其中,$3 \le |V| \le 10^5,1\le k \le 200$ 。
分析
显然,$k=1$ 直接扫一遍算贡献即可。
考虑 $k>1$ 的情况,由第二类斯特林数展开
$$
(f(X))^k=\sum_{i=0}^{k}S(k,i)i!{f(X) \choose i}
$$
前面一部分可以预处理,也就是要求对于 $i (0\le i \le k)$
$$
\sum_{X\subseteq V,X \neq \varnothing} {f(X) \choose i}
$$
设 $E’ \subseteq E$ 且 $|E’|=i$,$E_X$ 为 $f(X)$ 对应联通块的边集,上式可以写成
$$
\sum_{X\subseteq V,X \neq \varnothing} \sum_{e_j \in E’,1\le j \le i} [e_1,e_2,\dots,e_i \in E_X]
$$
于是,交换求和顺序
$$
\sum_{e_j \in E’,1\le j \le i} \sum_{X\subseteq V,X \neq \varnothing} [e_1,e_2,\dots,e_i \in E_X]
$$
上式含义是枚举边集 $E$ 的大小为 $i$ 的子集 $E’$,计算 $E’$ 出现在多少个点集 $V$ 的非空子集 $X$ 对应的联通块内。
考虑 $dp(i,j)$ 表示以 $i$ 为根的子树内大小为 $j$ 的边集,在子树内对上式的贡献。
设当前计算的根为 $u$,有三种情况:
- $u$ 到子树的边单独构成边集。
- $u$ 到子树的边与这棵子树的边集合并。
- $u$ 的子树的边集之间合并,子树的边集包含前两种情况。
对于以 $v$ 为根的子树,更新 $v$ 的 $dp$ 状态,情况一就是 $dp(v,1)+2^{size(v)}-1$,而情况二是用 $dp(v,i)$ 更新 $dp(u,i+1)$ 。得到每一个子树 $v$ 的 $dp$ 状态,做一个卷积即可得到根 $u$ 的 $dp$ 状态。
考虑卷积的过程,如果 $v_1$ 和 $v_2$ 合并,没有考虑到 $u$ 的其它子树对其的贡献。
为了解决这个问题,卷积过程中对于 $0$ 次项,设置其为 $2^{size(v)}$,这样没有参与合并的子树贡献就被计算进去了,注意这个部分的贡献不需要 $-1$。因为只要参与了合并,那么只有端点的子树才是必须非空,路径上的点都是可取可不取。
但是,状态转移的过程中,实际上已经对最终答案产生贡献,因为 $dp(i,j)$ 考虑的点集都是以 $i$ 为 $LCA$ 的。
对答案的贡献也有类似的三种情况。
情况一和情况二只需要考虑 $u$ 外的节点是否被取,更新答案即可。
情况三需要考虑合并后的边集对答案贡献,注意到卷积后实际上有一些是没有合并的,我们需要从 $u$ 的 $dp$ 状态中减去前两种情况,这有卷积结果就只剩下参与过合并的边集了。这部分边集同样是考虑 $u$ 外的节点是否选取。
于是,我们就得到了一个 $ans$ 数组表示大小为 $i$ 的边集 $E’$ 的贡献,使用第二类斯特林数即可得到最终答案。
时间复杂度 $O(nk^2)$。
最后参考树形依赖背包的优化原理,上面的卷积过程,实际上与子树大小有关,用类似的方法,我们最终得到时间复杂度为 $O(nk)$ 的解法。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
| #include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <vector> #define ms(a,b) memset(a,b,sizeof(a)) using namespace std; typedef long long ll; const int mod = 1e9 + 7; const int maxn = 100000 + 5;
ll two[maxn], S[300][300]; int f[maxn], inv[maxn], finv[maxn]; void init(){ S[0][0] = 1; for (int i = 1; i <= 200; i++) for (int j = 1; j <= 200; j++) S[i][j] = (S[i - 1][j - 1] + S[i - 1][j] * j % mod) % mod; inv[1] = 1; for (int i = 2; i < maxn; i++) inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod; f[0] = finv[0] = 1; for (int i = 1; i < maxn; i++) { f[i] = f[i - 1] * 1ll * i % mod; finv[i] = finv[i - 1] * 1ll * inv[i] % mod; } two[0] = 1; for (int i = 1; i < maxn; i++) two[i] = 2ll * two[i - 1] % mod; } inline int C(int n, int m){ if (m < 0 || m > n) return 0; return f[n] * 1ll * finv[n - m] % mod * finv[m] % mod; } inline void add(ll& x, ll y) { x += y; if (x >= mod) x -= mod; }
vector<int> edge[maxn]; int n, k, siz[maxn]; vector<ll> dp[maxn]; ll ans[maxn];
vector<ll> mul(vector<ll>& x, vector<ll>& y, int k1, int k2) { k1 = min(k1, k); k2 = min(k2, k); vector<ll> ans(k + 1, 0); for (int i = 0; i <= k1; i++) { for (int j = 0; j <= k2; j++) { if (i + j > k) break; add(ans[i + j], x[i] * y[j] % mod); } } return ans; }
void dfs(int u, int f) { siz[u] = 1; vector<ll> x(k + 1, 0); x[0] = 2; for (int& v: edge[u]) { if (v == f) continue; dfs(v, u); vector<ll> y = vector<ll>(dp[v]); y[0] = two[siz[v]]; y[1] = (y[1] + two[siz[v]] - 1 + mod) % mod; for (int i = 2; i <= k; i++) y[i] = (y[i] + dp[v][i - 1]) % mod; x = mul(x, y, siz[u], siz[v]); siz[u] += siz[v]; } dp[u] = x; for (int& v: edge[u]) { if (v == f) continue; ll t = (two[n - siz[v]] - 1 + mod) % mod; add(ans[1], t * (two[siz[v]] - 1 + mod) % mod); for (int i = 1; i < k; i++) { add(ans[i + 1], t * dp[v][i] % mod); } ll rs = two[siz[u] - siz[v]] % mod; x[1] = (x[1] - (two[siz[v]] - 1) * rs % mod + mod) % mod; for (int i = 2; i <= k; i++) { x[i] = (x[i] - dp[v][i - 1] * rs % mod + mod) % mod; } for (int i = 1; i <= k; i++) { x[i] = (x[i] - dp[v][i] * rs % mod + mod) % mod; } } for (int i = 2; i <= k; i++) { add(ans[i], x[i] * two[n - siz[u]] % mod); } }
int main(){ init(); scanf("%d%d", &n, &k); for (int i = 2, u, v; i <= n; i++) { scanf("%d%d", &u, &v); edge[u].push_back(v); edge[v].push_back(u); } dfs(1, 0); ll r = 0; for (int i = 0; i <= k; i++) { r = (r + S[k][i] * f[i] % mod * ans[i]) % mod; } cout << r << endl; return 0; }
|