HDu5320 Fan Li 题解

题面

给定一个序列,求选出一堆不相交的子区间,每个子区间的 $\gcd$ 相同,要求最大化选出的个数,并求出对应方案数。

其中 $1 \le n \le 10^5, 1 \le a_i \le 2333333$。

分析

首先我们知道一个结论,对于固定的左端点,它后面所有区间的 $\gcd$ 最多变化 $\log$ 次。

考虑四元组 $(x,i,l,r)$ 表示左端点为 $i$,右端点在区间 $[l,r]$,所有区间的 $\gcd$ 为 $x$,枚举端点二分即可。

将相同 $x$ 并在一起按顺序处理,令 $dp(i)$ 表示以 $i$ 结尾的最大值和方案数。

加一个四元组也就是从 $[1, i - 1]$ 中选一个最优决策(并且维护最优决策的方案数),覆盖(越往后越优)到 $[l,r]$ 中的每个位置。

线段树需要支持区间覆盖,维护区间和(最大值),并且打时间戳标记删除之前枚举的 $\gcd$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include <iostream>
#include <cstdio>
#include <cassert>
#include <cstring>
#include <algorithm>
#include <utility>
#include <vector>
#ifdef XLor
#define dbg(args...) cout << "\033[32;1m" << #args << " -> ", err(args)
void err() { std::cout << "\033[39;0m" << std::endl; }
template<typename T, typename...Args>
void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
#define dbg(...)
#endif
#define ms(a,b) memset(a,b,sizeof(a))
#define lson l, m, rt << 1
#define rson m + 1, r, rt << 1 | 1
using namespace std;
using ll = long long;
using PII = pair<int,int>;
const int mod = 998244353;
const int inf = 1 << 30;
const int maxn = 100000 + 5;

int add(int x, int y) {
x += y; return x >= mod ? x - mod : x;
}
int mul(int x, int y) {
return 1ll * x * y % mod;
}

int n, g[18][maxn];
int qgcd(int l, int r) {
int k = __lg(r - l + 1);
return __gcd(g[k][l], g[k][r - (1 << k) + 1]);
}

struct Node {
int x, i, l, r;
bool operator< (const Node& b) const {
if (x == b.x) return i < b.i;
return x < b.x;
}
} vec[maxn * 30];

struct Info {
int val, sum;
void clear() { val = sum = 0; }
Info(int a = 0, int b = 0): val(a), sum(b) {}
Info operator+(const Info& b) {
if (val < b.val) return Info(b.val, b.sum);
else if (val > b.val) return Info(val, sum);
else return Info(val, add(sum, b.sum));
}
Info operator+=(const Info& b) {
return *this = *this + b;
}
} t[maxn * 4], tag[maxn * 4];
int time, pos[maxn * 4];

void build(int l = 1, int r = n, int rt = 1) {
pos[rt] = -1; t[rt].clear(); tag[rt].clear();
if (l == r) return ;
int m = (l + r) / 2;
build(lson); build(rson);
}
void clear(int rt) {
if (pos[rt] < time) {
pos[rt] = time;
t[rt].clear(); tag[rt].clear();
}
}
void upd(int rt, Info x, int len) {
clear(rt);
t[rt] += Info(x.val, mul(x.sum, len));
tag[rt] += x;
}
void pushdown(int rt, int ln, int rn) {
if (!tag[rt].val) return ;
upd(rt << 1, tag[rt], ln);
upd(rt << 1 | 1, tag[rt], rn);
tag[rt].clear();
}
void pushup(int rt) {
clear(rt << 1); clear(rt << 1 | 1);
t[rt] = t[rt << 1] + t[rt << 1 | 1];
}
Info query(int L, int R, int l = 1, int r = n, int rt = 1) {
clear(rt);
if (L <= l && r <= R) return t[rt];
int m = (l + r) / 2; pushdown(rt, m - l + 1, r - m);
Info ans;
if (R <= m) ans = query(L, R, lson);
else if (L > m) ans = query(L, R, rson);
else ans = query(L, R, lson) + query(L, R, rson);
return ans;
}
void update(int L, int R, const Info& x, int l = 1, int r = n, int rt = 1) {
clear(rt);
if (L <= l && r <= R) {
upd(rt, x, r - l + 1); return ;
}
int m = (l + r) / 2; pushdown(rt, m - l + 1, r - m);
if (L <= m) update(L, R, x, lson);
if (R > m) update(L, R, x, rson);
pushup(rt);
}

int main() {
while (scanf("%d", &n) == 1) {
for (int i = 1; i <= n; i++) {
scanf("%d", &g[0][i]);
}
for (int j = 1; j < 18; j++) {
for (int i = 1; i + (1 << j) <= n + 1; i++) {
g[j][i] = __gcd(g[j - 1][i], g[j - 1][i + (1 << (j - 1))]);
}
}
build();
int sz = 0;
for (int i = 1; i <= n; i++) {
for (int j = i; j <= n; j++) {
int l = j, r = n, ans = j, g = qgcd(i, j);
while (l <= r) {
int m = (l + r) / 2;
if (qgcd(i, m) == g) {
ans = m; l = m + 1;
} else {
r = m - 1;
}
}
vec[sz++] = (Node){ g, i, j, ans };
j = ans;
}
}
sort(vec, vec + sz);
Info ans;
for (int i = 0, j; i < sz; i = j) {
j = i; time = vec[i].x;
while (j < sz && vec[j].x == vec[i].x) {
Info r;
if (vec[j].i > 1) {
r = query(1, vec[j].i - 1);
}
if (r.val == 0) {
r.val = r.sum = 1;
} else {
r.val++;
}
update(vec[j].l, vec[j].r, r);
r.sum = mul(r.sum, vec[j].r - vec[j].l + 1);
ans += r; j++;
}
}
printf("%d %d\n", ans.val, ans.sum);
}
return 0;
}