给定长度不超过 15 ,字符集大小为 4 的字符串 $S$ ,对所有正整数 $0\le i\le|S|$ 询问有多少长度为 $n$ 的同样字符集的字符串与 $S$ 的最长公共子序列长度为 $i$ 。
5 组数据, $n\le 10^3$
dp套dp是2015年陈立杰在WC上计数问题选讲中提出的。思想是将内层DP的结果作为外层DP的状态。
2021年IOI集训队杭州学军中学徐哲安的论文《浅谈有限状态自动机及其应用》中再次提到了dp套dp问题,但我没有看懂。
设 $\mathrm{LCS}[i][j]$ 表示我们构造的串 $T$ 的前 $i$ 位与给定串 $S$ 的前 $j$ 位的最长公共子序列。
容易看出, $\mathrm{LCS}[i][j+1]-\mathrm{LCS}[i][j]$ 只可能是 0 或 1。由于 $j$ 最大为 15 ,可以用一个最大 15 位的二进制数 $s[i]$ 来表示一整行 $\mathrm{LCS}[i]$ 。
同时发现 $s[i]$ 只与 $s[i-1]$ 、 $S$ 、$T[i]$ 有关。所以在给定的 $S$ 下,只要给定 $T[i]$ ,就能由 $s[i-1]$ 推出 $s[i]$ 。用 $\mathrm{trans}[\mathrm{state}][c]$ 表示 $s[i-1]=\mathrm{state}, T[i]=c$ 时的 $s[i]$ 。
计算 $\mathrm{trans}$ 需要用到传统的 $\mathrm{LCS}$ ,也就是内层 DP。
设 $f[i][s]$ 表示长度为 $i$ 的字符串与 $S$ 的 $\mathrm{LCS}$ 是 $s$ 的方案数。有转移方程:
计算 $f$ 就是外层 DP。
可以使用刷表法实现。
#include<cstdio>
#include<cstring>
#include<algorithm>
const int MOD = 1e9 + 7;
int cnt_ones[1 << 15];
int trans[1 << 15][4];
int state[2][1 << 15];
void solve() {
int n;
char str[20];
scanf("%s%d", str + 1, &n);
const int slen = strlen(str + 1);
const int state_number = 1 << slen;
// 内层DP
for(int s = 0; s < state_number; ++s) {
int state[2][20];
state[0][0] = 0;
for(int i = 1; i <= slen; ++i) {
state[0][i] = state[0][i - 1] + (s >> (i - 1) & 1);
}
char sigma[] = "AGCT";
for(int k = 0; k < 4; ++k) {
char c = sigma[k];
state[1][0] = 0;
int ns = 0;
for(int j = 1; j <= slen; ++j) {
state[1][j] = c == str[j] ?
state[0][j - 1] + 1 :
std::max(state[0][j], state[1][j - 1]);
ns |= (state[1][j] - state[1][j - 1]) << (j - 1);
}
trans[s][k] = ns;
}
}
// 外层DP
state[0][0] = 1;
for(int s = 1; s < state_number; ++s) {
state[0][s] = 0;
}
for(int i = 0; i < n; ++i) {
for(int s = 0; s < state_number; ++s) {
state[i & 1 ^ 1][s] = 0;
}
for(int s = 0; s < state_number; ++s) {
for(int k = 0; k < 4; ++k) {
state[i & 1 ^ 1][trans[s][k]] += state[i & 1][s];
state[i & 1 ^ 1][trans[s][k]] %= MOD;
}
}
}
// 统计答案
int ans[20];
for(int i = 0; i <= slen; ++i) ans[i] = 0;
for(int s = 0; s < state_number; ++s) {
ans[cnt_ones[s]] += state[n & 1][s];
ans[cnt_ones[s]] %= MOD;
}
for(int i = 0; i <= slen; ++i) {
printf("%d\n", ans[i]);
}
}
int main() {
// 预处理
for(int s = 0; s < (1 << 15); ++s) {
cnt_ones[s] = cnt_ones[s >> 1] + (s & 1);
}
int cas;
scanf("%d", &cas);
while(cas--) {
solve();
}
return 0;
}