问题描述
给定 $n$ 个长度均为 $m$ 的字符串,现在你开始投骰子,有 $P_c$ 的概率投出字符 $c$,求第一次出现的串是第 $i$ 个的概率。
一般做法
对这个字符串集合建出 Trie 图后,问题等价于给定一些终点的无向图随机游走问题。
令结点数为 $s$,建出概率转移矩阵后,实际上只有 $s-1$ 条有效的方程,因为无法转移到根节点。但是,随机游走必定会在某个时刻终结,也就是走到所有终止结点的概率和为 $1$。添加这么一条方程后,高斯消元,即可解出在每个终止结点结束的概率。
时间复杂度:$O(nm|\Sigma|+n^3m^3)$。
加强做法
令概率生成函数 $[x^k]G(x)$ 表示在 $k$ 时刻尚未停止的概率,$[x^k]F_i(x)$ 表示在 $k$ 时刻末尾是第 $i$ 个串的概率。
我们要求的就是 $F_i(1),F_i(2),\dots, F_i(n)$。
考虑在每个尚未停止的时刻,往后走一步的情况,即 $xG(x)$,再加上初始时的概率 $1$,这时要么在此处迎来某个串的结束,要么尚未停止,有下式。
$$
1+xG(x)=G(x)+\sum_{i=1}^n F_i(x)
$$
记 $P(S)$ 表示一个串的出现概率,即 $P(S)=\prod_{i=1}^{|S|} P_{S_i}$
对于每个串 $i$,在一个未终止时刻,往后加第 $i$ 个串,即 $G(x)P(s_i)x^m$。这时情况很多,因为加这个串的时候可能中途出现已经终止的情况,设第 $j$ 个串在新加串的 $k$ 位置终止,因此 $s_i[1 \dots k]=s_j[ m - k + 1 \dots m]$,即第 $i$ 个串的前缀等于长度为 $k$ 的第 $j$ 个串的后缀,此时的概率生成函数就是 $F_j(x)P(s_i[k+1 \dots m])x^{m-k}$。
我们枚举每个串 $j$ 和它的前缀长度 $k$,有下式:
$$
G(x)P(s_i)x^m = \sum_{j=1}^n \sum_{k=1}^m [ s_i[1 \dots k]=s_j[ m - k + 1 \dots m] ] F_j(x)P(s_i[k+1 \dots m])x^{m-k}
$$
结合我们要求的,带入 $x=1$,可以得到 $n$ 个方程,但是 $G(1)$ 也是未知量,因此有 $n+1$ 个未知量,联立 $\sum_{i=1} F_i(1)=1$,即可得到 $n+1$ 个方程,高斯消元。注意这里我们的变量数是 $n+1$,不是一般做法的最大 $nm$。
时间复杂度:$O(n^2m+n^3)$。
代码
BZOJ 1444
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
| #include <iostream> #include <cstdio> #include <cassert> #include <cstring> #include <cmath> #include <functional> #include <algorithm> #include <utility> #include <vector> #include <string> #include <map> #include <set> #include <queue> #ifdef XLor #define dbg(args...) cout << "\033[32;1m" << #args << " -> ", err(args) void err() { std::cout << "\033[39;0m" << std::endl; } template<typename T, typename...Args> void err(T a, Args...args) { std::cout << a << ' '; err(args...); } #else #define dbg(...) #endif #define ms(a,b) memset(a,b,sizeof(a)) using namespace std; using ll = long long; using PII = pair<int,int>; const int mod = 998244353; const int inf = 1 << 30; const int maxn = 1000 + 5;
const double eps = 1e-5;
int n, l, S, fin[maxn]; double p[20]; char s[maxn];
namespace acam { static const int maxp = 100000 + 5; static const int S = 26; static const int Base = 'A'; int sz, ch[maxp][10], fail[maxp], val[maxp];; int node() { ms(ch[++sz], 0); fail[sz] = val[sz] = 0; return sz; } void clear() { sz = 0; node(); for (int i = 0; i < S; i++) ch[0][i] = 1; } int insert(char* s, int i) { int u = 1; for (int i = 0; s[i]; i++) { int v = s[i] - Base; // ! if (!ch[u][v]) ch[u][v] = node(); u = ch[u][v]; } val[u]++; return u; } void build() { queue<int> q; q.push(1); while (!q.empty()) { int t = q.front(); q.pop(); for (int i = 0; i < S; i++) { if (ch[t][i]) { fail[ch[t][i]] = ch[fail[t]][i]; q.push(ch[t][i]); } else { ch[t][i] = ch[fail[t]][i]; } } } } } using namespace acam;
double a[maxn][maxn], ans[maxn]; bool solve(int n) { for (int i = 1; i <= n; i++) { int r = i; for (int j = i + 1; j <= n; j++) if (abs(a[j][i]) > abs(a[r][i])) r = j; if (abs(a[r][i]) < eps) return false; swap(a[r], a[i]); double inv = a[i][i]; for (int j = i; j <= n + 1; j++) a[i][j] /= inv; for (int j = i + 1; j <= n; j++) { double inv = a[j][i]; for (int k = i; k <= n + 1; k++) a[j][k] -= inv * a[i][k]; } } for (int i = n; i >= 1; i--) { ans[i] = a[i][n + 1]; for (int j = i + 1; j <= n; j++) ans[i] -= a[i][j] * ans[j]; } return true; }
int main() { scanf("%d%d%d", &n, &l, &S); acam::clear(); for (int i = 0, a, b; i < S; i++) { scanf("%d%d", &a, &b); p[i] = double(a) / b; } for (int i = 1; i <= n; i++) { scanf("%s", s); fin[i] = acam::insert(s, i); } acam::build(); for (int i = 1; i <= sz; i++) { a[i][i] = -1; } for (int i = 1; i <= sz; i++) { if (val[i]) continue; for (int j = 0; j < S; j++) { a[ch[i][j]][i] += p[j]; } } a[1][sz + 1] = 1; for (int i = 1; i <= sz; i++) { if (val[i]) { a[1][i] = 1; } else { a[1][i] = 0; } } bool x = solve(sz);
for (int i = 1; i <= n; i++) { double r = ans[fin[i]]; if (r > 0) { printf("%.2lf\n", r); } else { puts("0.00"); } } return 0; }
|
「SDOI2017」硬币游戏
注意:eps 的精度。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
| #include <iostream> #include <cstdio> #include <cassert> #include <cstring> #include <cmath> #include <functional> #include <algorithm> #include <utility> #include <vector> #include <queue> #define ms(a,b) memset(a,b,sizeof(a)) using namespace std; using ll = long long; using PII = pair<int,int>; const int mod = 998244353; const int inf = 1 << 30; const int maxn = 500 + 5; const long double eps = 1e-100;
int n, m, fin[maxn]; char s[maxn]; long double two[maxn], a[maxn][maxn], ans[maxn];
namespace acam { static const int maxp = 100000 + 5; static const int S = 2; int sz, ch[maxp][S], fail[maxp], len[maxp]; vector<int> nds[maxp]; int node() { ms(ch[++sz], 0); fail[sz] = len[sz] = 0; return sz; } void clear() { sz = 0; node(); for (int i = 0; i < S; i++) ch[0][i] = 1; } int insert(char* s, int p) { int u = 1; for (int i = 0; s[i]; i++) { int v = s[i] == 'H'; if (!ch[u][v]) ch[u][v] = node(); u = ch[u][v]; len[u] = i + 1; nds[u].push_back(p); } return u; } void build() { queue<int> q; q.push(1); while (!q.empty()) { int t = q.front(); q.pop(); for (int i = 0; i < S; i++) { if (ch[t][i]) { fail[ch[t][i]] = ch[fail[t]][i]; q.push(ch[t][i]); } else { ch[t][i] = ch[fail[t]][i]; } } } for (int i = 1; i <= n; i++) { a[i][n + 1] = -1.0; int x = fin[i]; while (x > 1) { for (int u: nds[x]) { a[u][i] += two[len[x]]; } x = fail[x]; } } a[n + 1][n + 2] = 1; for (int i = 1; i <= n; i++) { a[n + 1][i] = 1; } } }
bool solve(int n) { for (int i = 1; i <= n; i++) { int r = i; for (int j = i + 1; j <= n; j++) if (abs(a[j][i]) > abs(a[r][i])) r = j; if (abs(a[r][i]) < eps) return false; swap(a[r], a[i]); double inv = a[i][i]; for (int j = i; j <= n + 1; j++) a[i][j] /= inv; for (int j = i + 1; j <= n; j++) { double inv = a[j][i]; for (int k = i; k <= n + 1; k++) a[j][k] -= inv * a[i][k]; } } for (int i = n; i >= 1; i--) { ans[i] = a[i][n + 1]; for (int j = i + 1; j <= n; j++) ans[i] -= a[i][j] * ans[j]; } return true; }
int main() { acam::clear(); scanf("%d%d", &n, &m); two[0] = 1.0; for (int i = 1; i <= m; i++) { two[i] = two[i - 1] * 2.0; } for (int i = 1; i <= n; i++) { scanf("%s", s); fin[i] = acam::insert(s, i); } acam::build(); bool f = solve(n + 1); for (int i = 1; i <= n; i++) { printf("%.6Lf\n", ans[i]); } return 0; }
|