学军信友队趣味网络邀请赛 C 病毒研究 题解

题面

给你一个病毒,他的强度为 $x$,你不知道 $x$ 的值,但是你知道它在 $n$ 个区间中的某一个 $[1, a_1],[a_1+1, a_2], \dots, [ a_{n-1} + 1, a_n ]$。

你现在有 $m$ 种操作,第 $i$ 个操作是强度减少 $w_i$,花费 $v_i$。你需要通过操作将病毒强度控制到区间 $[1, a_1]$ 中。

你现在要求 $x$ 是 $[1, a_n]$ 中等概率取一个的最小的花费期望(答案乘 $a_n$ 变成整数)。

其中 $1 \le a_n, m \le 2000$。

分析

显然病毒通过一系列操作病毒的强度一直是数轴上的某一段区间,我们令状态是 $f(l,r)$ 表示病毒强度在 $[l,r]$ 范围内的最小的花费期望。

如果 $[l,r]$ 在第一个区间内,则花费为 $0$。

如果 $[l,r]$ 没有跨越任何区间,则枚举下一次操作的花费进行转移。

如果 $[l,r]$ 跨越了区间,则用原来的分段点,将区间分段,每段单独算再加起来。

显然一共有 $n^2$ 种状态,转移是 $O(m)$ 的,时间复杂度是 $O(n^2m)$。

但是,我们考虑 $[l,r]$ 跨越区间的情况,它是由一堆连续的完整区间,开头的一个完整区间的后缀和结尾的一个完整区间的前缀组成。我们转移时实际上是将当前区间向左平移了一段距离,如果该区间平移后没有发生跨越,那么这个区间的转移还要继续重复枚举所有操作,产生了冗余。当一个区间产生了跨越后,这个区间的计算就会被拆解成上面的段,不再是简单地枚举操作。

因此,我们对 $[l,r]$ 区间进行转移时,强制钦定他转移到跨越的情况(或者某个区间的前缀或后缀,或者到状态一)。两头的区间我们使用记忆化搜索进行转移,而中间的完整区间可以用前缀和直接算出来。所以总体状态数就是 $O(a_n)$。

时间复杂度:$O(n^2)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <iostream>
#include <cstdio>
#include <cassert>
#include <cstring>
#include <cmath>
#include <functional>
#include <algorithm>
#include <utility>
#include <vector>
#include <string>
#include <map>
#include <set>
#ifdef XLor
#define dbg(args...) cout << "\033[32;1m" << #args << " -> ", err(args)
void err() { std::cout << "\033[39;0m" << std::endl; }
template<typename T, typename...Args>
void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
#define dbg(...)
#endif
#define ms(a,b) memset(a,b,sizeof(a))
using namespace std;
using ll = long long;
using PII = pair<int,int>;
const int mod = 998244353;
const ll inf = 1ll << 60;
const int maxn = 2000 + 5;

int n, m, a[maxn], id[maxn], v[maxn], w[maxn];
ll cost[maxn], pre[maxn], dp[maxn][maxn];

ll solve(int l, int r) {
if (dp[l][r] != -1) return dp[l][r];
if (r <= a[1]) return 0;
ll mn = inf;
for (int c = 1; c < l; c++) {
if (cost[c] == inf) continue;
int lid = id[l - c], rid = id[r - c];
ll val = cost[c] * (r - l + 1);
if (lid != rid) {
val += solve(l - c, a[lid]) + solve(a[rid - 1] + 1, r - c);
val += pre[rid - 1] - pre[lid];
} else if (l - c == a[lid - 1] + 1) {
val += solve(l - c, a[lid]);
} else if (r - c == a[rid]) {
val += solve(a[rid - 1] + 1, r - c);
} else if (r - c > a[1]) {
continue;
}
mn = min(mn, val);
}
return dp[l][r] = mn;
}

int main() {
int T; scanf("%d", &T);
while (T--) {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", a + i);
for (int j = a[i - 1] + 1; j <= a[i]; j++) {
id[j] = i; cost[j] = inf;
}
}
for (int i = 1; i <= a[n]; i++) {
for (int j = i; j <= a[n]; j++) {
dp[i][j] = -1;
}
}
for (int i = 1; i <= m; i++) {
scanf("%d%d", v + i, w + i);
}
for (int i = 1; i <= m; i++) {
for (int j = w[i]; j <= a[n]; j++) {
cost[j] = min(cost[j], cost[j - w[i]] + v[i]);
}
}
for (int i = 2; i <= n; i++) {
pre[i] = pre[i - 1] + solve(a[i - 1] + 1, a[i]);
}
if (pre[n] == inf) puts("-1");
else printf("%lld\n", pre[n]);
}
return 0;
}