Codeforces Global Round 2 F Niyaz and Small Degrees

题意

给定一棵带边权的无根树,要求删除一些边使得每个点的度数不超过 $k$,最小化费用,对所有 $[0,n-1]$ 中每个数回答询问。

其中 $2 \le n \le 250000$。

分析

首先,假如只有一个询问,记当前的最大度数为 $k$,设 $dp(x,0/1)$ 表示删除一些边使得 $x$ 点剩余度数为 $k$ 和 $k+1$ 的最小费用。转移时,不妨令所有边都没有被删除,将删除一条连向 $v$ 边权为 $w$ 的边产生的影响记录下来,即 $w+dp(v,1)-dp(v,0)$,如果是非负数直接选择,否则选择一些最小的。

回到原题,如何批量回答所有询问?

注意到一个事实:$\sum_{k=0}^{n-1} \sum_{u=1}^{n} [deg(u)>k]=2n-2$

考虑从小到大枚举最大度数 $k$,记度数大于 $k$ 的为关键点,否则为非关键点。

每个点维护一个堆,表示它删除所有连向非关键点的边的影响,因为非关键点度数小于等于 $k$,它的边随便删不删都行,因此这个影响等价于边权。

为了保证复杂度,每次只能 dfs 所有关键点构成的森林。因为非关键点的影响已经在堆内维护好了,只需要把关键点的影响丢进去即可,用类似与上述的方案求出相应的 $dp$ 值,注意出栈的时候,需要撤回一下关键点的操作。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <utility>
#include <set>
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
const int mod = 998244353;
const int inf = 1 << 30;
const int maxn = 250000 + 5;

struct Heap {
multiset<ll> st; ll sum;
void insert(ll v) {
st.insert(v); sum += v;
}
void del(ll v) {
st.erase(st.find(v)); sum -= v;
}
void resize(int sz) {
while (!st.empty() && (int)st.size() > sz) {
sum -= *st.rbegin();
st.erase(--st.end());
}
}
void resize(int sz, vector<ll>& v) {
while (!st.empty() && (int)st.size() > sz) {
sum -= *st.rbegin();
v.push_back(*st.rbegin());
st.erase(--st.end());
}
}
} h[maxn];

int n, lim, deg[maxn], ptr[maxn], vis[maxn];
vector<PII> edge[maxn];
vector<int> bag[maxn], key[maxn];

ll dp[maxn][2]; // lim/lim+1
void dfs(int u, int ff) {
vis[u] = lim;
int sz = deg[u] - lim;
h[u].resize(sz);
vector<ll> add, sub;
ll tot = 0; int& p = ptr[u];
while (p < deg[u] && deg[edge[u][p].first] <= lim) p++;
for (int i = p; i < deg[u]; i++) {
int v = edge[u][i].first, w = edge[u][i].second;
if (v == ff) continue;
dfs(v, u);
if (dp[v][1] + w < dp[v][0]) {
tot += dp[v][1] + w; sz--;
} else {
tot += dp[v][0];
add.push_back(dp[v][1] + w - dp[v][0]);
h[u].insert(dp[v][1] + w - dp[v][0]);
}
}
h[u].resize(sz, sub); dp[u][0] = tot + h[u].sum;
h[u].resize(sz - 1, sub); dp[u][1] = tot + h[u].sum;
for (auto& x: sub) h[u].insert(x);
for (auto& x: add) h[u].del(x);
}

int main() {
scanf("%d", &n);
ll total = 0;
for (int i = 2, u, v, w; i <= n; i++) {
scanf("%d%d%d", &u, &v, &w);
edge[u].push_back({v, w});
edge[v].push_back({u, w});
deg[u]++; deg[v]++;
total += w;
}
for (int i = 1; i <= n; i++) {
bag[deg[i]].push_back(i);
sort(edge[i].begin(), edge[i].end(), [&](PII a, PII b) {
return deg[a.first] < deg[b.first];
});
for (int j = 1; j < deg[i]; j++) key[j].push_back(i);
}
printf("%I64d", total);
for (lim = 1; lim < n; lim++) {
for (auto u: bag[lim]) {
for (auto x: edge[u]) {
int v = x.first, w = x.second;
h[v].insert(w);
}
}
ll ans = 0;
for (auto& u: key[lim]) {
if (vis[u] == lim) continue;
dfs(u, 0); ans += dp[u][0];
}
printf(" %I64d", ans);
}
return 0;
}