2022 CCPC 绵阳现场赛 K Pattern Matching in A Minor Low Space 题解

题面

求模板串 $S$ 在 $T$ 中的出现次数。

其中 $1 \le |S|, |T| \le 10^7$.

时间限制 $5$ 秒, 内存限制 $\mathbf{1}$ MB, 且输入只能读取一次.

题解

KMP 模板题

题目让你在一个空间受限的背景下, 进行流式字符串匹配. 本题无论模板串还是询问串都无法放进内存, 你必须流式的一个一个读取.

首先, 回忆 Rabin-Karp 算法是怎么样的. 它对模板串 $S$ 求了一个 hash 值, 然后使用滑动窗口 维护询问串所有长度 $S$ 的子串的 hash 值, 这里时间复杂度是 $O(|T|)$ 的, 但是空间复杂度是 $O(|S|)$ 的 (需要存下整个窗口, 以供删除开头的字符). 显然无法满足本题的要求.

官方题解的做法:

  1. 读入模板串 $S$, 计算出长度为 $\lfloor \sqrt{n} \rfloor$ 的前缀 hash 值和完整串的 hash 值, 时间复杂度 $O(|S|)$, 空间复杂度 $O(1)$;
  2. 读入询问串 $T$, 维护 $\lfloor \sqrt{n} \rfloor$ 大小的滑动窗口, 可以求出询问串中所有和这个根号前缀匹配的位置. 相当于, 筛出去了一些必不可能匹配上的位置, 留下来一些位置需要求出相应的全长度的 hash 值.
    • 注意到, 直到这里实际上还没有本质的改善, 一是匹配上的位置仍然很多, 二是难以搞出相应的全串 hash 值.
    • 根据 (Weak) Periodicity Lemm, 我们可以把匹配上的位置分成至多 $\lceil \sqrt{n} \rceil$ 组等差数列. 具体的, 每组中所有匹配的子串是一个顺次有重叠的列表, 相邻出现位置的距离差相等 (也就是所谓的构成等差数列, 进一步, 重叠部分其实就是根号前缀的 border). 于是, 记录匹配位置的空间被压缩到了 $O(\lfloor \sqrt{n} \rfloor)$.
    • 下一个问题是, 我们不仅需要记录匹配的位置, 还需要记录匹配处的前缀 hash 值, 等到处理到该次可能匹配的结束位置时来进行全串的比对. 同样根据上述引理的结论, 同一组等差数列内部, 相邻 2 个匹配位置之间的 hash 值是固定的, 我们只需要记录等差数列的起点, 公差, 终点, 起点的 hash 值, 公差对应字符串部分的 hash 值, 就能表示出这一个等差数列的所有位置信息和 hash 值信息.

最终, 时间复杂度 $O(n)$, 空间复杂度 $O(\lfloor \sqrt{n} \rfloor)$, 一是维护滑动窗口, 二是维护所有等差数列.

周期相关理论的参考文献:

  • 金策, 《字符串算法选讲》
  • 2019 年集训队论文, 陈孙立, 《子串周期查询问题的相关算法及其应用》

官方题解

official tutorial

分块 KMP

一个不要 period 的做法, 大概是把模式串分成 sqrt 个字符一块, 每块合并成一个大字符,然后关于大字符跑 sqrt 个并排的 kmp

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

using namespace std;

const int base = 131;
const int SZ = 3162;
const int cap = 4096;

namespace {
// biv = base^{-1}, bsziv = base^{-SZ+1}
int mod, biv, bsizv;

random_device rd;
mt19937 rnd(rd());
uniform_int_distribution<> gen(100000000, 900000000);

bool isPrime(int n) {
if (n % 2 == 0) return false;
int sq = (int) sqrt(n) + 1;
for (int i = 3; i <= sq; i += 2) {
if (n % i == 0) return false;
}
return true;
}

int add(int a, int b) {
a += b;
if (a >= mod) a -= mod;
return a;
}

int sub(int a, int b) {
a -= b;
if (a < 0) a += mod;
return a;
}

int mul(int a, int b) {
return 1ll * a * b % mod;
}

int qpow(int x, int n) {
int r = 1;
while (n) {
if (n & 1) r = mul(r, x);
n >>= 1;
x = mul(x, x);
}
return r;
}

void init() {
while (true) {
mod = gen(rnd);
if (isPrime(mod)) {
break;
}
}
biv = qpow(base, mod - 2);
bsizv = qpow(qpow(base, SZ - 1), mod - 2);
}
}

int n, m, preh = 0, allh = 0;

struct Ring {
char buf[cap];
int head = 0, tail = 0, size = 0, len;
int hsh = 0, xp = 1;

void init(int n) {
len = n;
hsh = 0;
head = tail = size = 0;
xp = 1;
}

char pop() {
size--;
char x = buf[(head++) % cap];
return x;
}

void push(char x) {
size++;
buf[(tail++) % cap] = x;
}

void append(char c) {
push(c);
hsh = add(mul(c, xp), hsh);
if (size > len) {
hsh = sub(hsh, pop());
hsh = mul(hsh, biv);
} else {
xp = mul(xp, base);
}
}
} f;

struct Per {
int start, delta = -1, end = -1;
int xp, hsh, dhsh = -1;

Per(int p, int x, int h) : start(p + 1 - SZ) {
end = start;
xp = mul(x, bsizv);
hsh = sub(h, mul(preh, xp));
}

bool next(int p, int curx, int curh) {
p = p + 1 - SZ;
if (delta == -1) {
if (p - start >= SZ) {
// it should have overlap part
return false;
} else {
// set delta
end = p;
delta = end - start;
curh = sub(curh, mul(preh, mul(curx, bsizv)));
dhsh = sub(curh, hsh);
return true;
}
} else {
if (p - end == delta) {
end = p;
return true;
} else {
return false;
}
}
}

bool match(int pos, int curv) {
if (start + n - 1 == pos) {
int target = sub(curv, hsh);
bool ok = target == mul(allh, xp);
if (delta != -1) {
start += delta;
int dxp = qpow(base, delta);
xp = mul(xp, dxp);
hsh = add(hsh, dhsh);
dhsh = mul(dhsh, dxp);
}
return ok;
}
return false;
}
};

int main() {
init();

scanf("%d%d", &n, &m);
getchar(); // end of line

for (int i = 1, xp = 1; i <= n; i++, xp = mul(xp, base)) {
char c = getchar();
int val = mul(c, xp);
if (i <= SZ) {
preh = add(preh, val);
}
allh = add(allh, val);
}
getchar(); // end of line

f.init(n <= SZ ? n : SZ);
int ans = 0, curv = 0, matched = 0;
vector<Per> ps;
for (int i = 1, xp = 1; i <= m; i++, xp = mul(xp, base)) {
char c = getchar();
f.append(c);
if (n <= SZ) {
ans += f.size == n && f.hsh == allh;
} else {
curv = add(curv, mul(c, xp));
if (f.size == SZ && f.hsh == preh) {
// match the sqrt prefix
// extend the last group or create a new group
if (ps.empty() || !ps.back().next(i, xp, curv)) {
ps.emplace_back(i, xp, curv);
}
}
if (matched < ps.size()) {
ans += ps[matched].match(i, curv);
if (ps[matched].end + n - 1 == i) {
matched++;
}
}
}
}

printf("%d\n", ans);

return 0;
}