0%

dp选讲

7.11dp by 一休哥777 orz

P2051

在一个 $n\times m$ 的棋盘上放置若干个互不攻击的炮,求摆放的方案数。$n,m\le100$。

首先发现炮不能互相攻击等价于每行每列最多有 $2$ 个炮。

考虑设 dp 状态 $dp_{i,j,k}$ 表示当前到达第 $i$ 行,且有 $j$ 列放置两个炮、$k$ 列放置一个炮。那么相应的放置 $0$ 个炮的列数就是 $m-j-k$。

在考虑 $i$ 行的时候考虑这一行放置几个炮:

  • 放置 $0$ 个,$dp_{i,j,k}\leftarrow dp_{i-1,j,k}$;
  • 放置 $1$ 个,那么这一个炮可能放在原来的 $0$ 炮列,也可以放在 $1$ 炮列,那么
  • 放置两个,相似地我们可以选择放置的列炮的数量。
    • 放在一个有炮、一个没有炮的列,那么
    • 都放在没有炮的列,此时要在空列里选择两个:
    • 都放在有一个炮的列里,此时要在有一个炮的列里选择两个:

然后,就做完了。时间复杂度 $O(n^3)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 105
#define mod 9999973
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll n,m,dp[N][N][N];
il ll C(ll x) {
return x*(x-1)/2%mod;
}
int main() {
n=read(),m=read();
dp[0][0][0]=1;
for(int i=1;i<=n;i++) {
for(int j=0;j<=m;j++) {
for(int k=0;k<=m-j;k++) {
dp[i][j][k]=dp[i-1][j][k];
if(k>=1) dp[i][j][k]=(dp[i][j][k]+(m-j-k+1)*dp[i-1][j][k-1]%mod)%mod;
if(j>=1) dp[i][j][k]=(dp[i][j][k]+(k+1)*dp[i-1][j-1][k+1]%mod)%mod;
if(j>=1) dp[i][j][k]=(dp[i][j][k]+(m-j-k+1)*k%mod*dp[i-1][j-1][k]%mod)%mod;
if(k>=2) dp[i][j][k]=(dp[i][j][k]+C(m-j-k+2)*dp[i-1][j][k-2]%mod)%mod;
if(j>=2) dp[i][j][k]=(dp[i][j][k]+C(k+2)*dp[i-1][j-2][k+2]%mod)%mod;
}
}
}
ll ans=0;
for(int j=0;j<=m;j++) for(int k=0;k<=m;k++) ans=(ans+dp[n][j][k])%mod;
printf("%lld\n",ans);
return 0;
}

P7961

给定 $n,m,k$ 和长度为 $m+1$ 的正整数数组 $v_0,v_1,\dots,v_m$。对于一个长度为 $n$,下标从 $1$ 开始且每个元素均不超过 $m$ 的非负整数序列 $\{a_i\}$,我们定义它的权值为 $v_{a_1} \times v_{a_2} \times \cdots \times v_{a_n}$。当这样的序列 $\{a_i\}$ 满足整数 $S = 2^{a_1} + 2^{a_2} + \cdots + 2^{a_n}$ 的二进制表示中 $1$ 的个数不超过 $k$ 时,我们认为 $\{a_i\}$ 是一个合法序列。计算所有合法序列 $\{a_i\}$ 的权值和对 $998244353$ 取模的结果。$n,k\le30,m\le100$。

首先我们想一下怎么求出 $\operatorname{popcount}(S)$。我们记 $cnt_i$ 为 $a$ 数组中 $i$ 出现的次数,我们从低到高枚举 $i$,让 $\mathrm{ans}$ 加上 $cnt_i\bmod 2$,并且让 $cnt_{i+1}$ 加上 $\left\lfloor\dfrac{cnt_i}2\right\rfloor$(其实就是模拟二进制进位)。这样就可以在 $O(m)$ 的复杂度内求出 $\operatorname{popcount}(S)$。

这个顺序启示我们从 $0\sim m$ 的顺序 dp。设 $dp_{i,j,k,l}$ 表示考虑 $S$ 的 $0\sim i$ 位且填了 $j$ 个 $a$ 数组的元素,$S$ 总共有 $k$ 个 $1$ 且要向下一位进位 $l$ 的总权值。

接下来我们枚举有 $t$ 个 $a_p=i$,那么就会有 $t+l$ 个 $1$ 会进到 $S$ 的下一位,根据上面求 $\operatorname{popcount}(S)$ 的方法我们可以得知

考虑那个 $?$ 是什么,显然就是 $i$ 这个数的贡献 $v_i^t$ 乘上选择 $t$ 个点放置的方案数 $\dbinom{n-j}{t}$。因此

这就是最后的转移方程。

注意:在 dp 到 $m$ 之后可能会有剩余的数要进位,因此要考虑能不能统计入答案。

P4516

给定一棵树,可以对恰好 $k$ 个点标记,要求每个点都需要有至少一个被标记的点与它相邻(自己不算)。问可行的方案数。$n\le10^5,k\le100$。

设状态 $dp_{u,i,0/1,0/1}$ 表示考虑 $u$ 子树内恰有 $i$ 个点被标记、$u$ 是/否被标记、$u$ 是/否已有标记点与它相邻。

转移类似一个树形背包的形式,即将 $v$ 合并到 $u$ 上。接下来开始分讨

  • $dp_{u,i+j,0,0}$。那么 $v$ 必须不能被标记、而且已经被覆盖。从 $dp_{u,i,0,0}\cdot dp_{v,j,0,1}$ 转移。
  • $dp_{u,i+j,0,1}$。分两种情况:
    • 若 $u$ 之前已经被覆盖,那么 $v$ 无论被不被标记都可以。
    • 若 $u$ 之前没被覆盖,那么 $v$ 必须被标记。
  • $dp_{u,i+j,1,0}$,那么 $v$ 不能被标记,从 $dp_{u,i,1,0}\cdot(dp_{v,j,0,0}+dp_{v,j,0,1})$ 转移。
  • $dp_{u,i+j,1,1}$,又从两种情况讨论:
    • 若 $u$ 没被覆盖,和上面类似;
    • 若 $u$ 已经被覆盖,那么 $v$ 没有任何限制。

想明白了代码也非常不好写。时间复杂度是树上背包的复杂度,可以证明是 $O(nk)$。

这个题如果 dp 数组开 long long 会 MLE。另外要存一个 cpy 数组保存之前数组的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 100005
#define K 105
#define mod 1000000007
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,k,siz[N];
int dp[N][K][2][2];ll cpy[K][2][2];
vector<int> g[N];
il void dfs(int u,int fa) {
siz[u]=dp[u][0][0][0]=dp[u][1][1][0]=1;
for(auto v:g[u]) {
if(v==fa) continue;
dfs(v,u);
for(int i=0;i<=min(siz[u],k);i++) {
cpy[i][0][0]=dp[u][i][0][0];
cpy[i][0][1]=dp[u][i][0][1];
cpy[i][1][0]=dp[u][i][1][0];
cpy[i][1][1]=dp[u][i][1][1];
dp[u][i][0][0]=dp[u][i][0][1]=dp[u][i][1][0]=dp[u][i][1][1]=0;
}
for(int i=0;i<=min(siz[u],k);i++) {
for(int j=0;j<=min(siz[v],k-i);j++) {
dp[u][i+j][0][0]=(1ll*dp[u][i+j][0][0]+1ll*cpy[i][0][0]*dp[v][j][0][1]%mod)%mod;

dp[u][i+j][0][1]=(1ll*dp[u][i+j][0][1]+1ll*cpy[i][0][1]*((1ll*dp[v][j][0][1]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
dp[u][i+j][0][1]=(1ll*dp[u][i+j][0][1]+1ll*cpy[i][0][0]*dp[v][j][1][1]%mod)%mod;

dp[u][i+j][1][0]=(1ll*dp[u][i+j][1][0]+1ll*cpy[i][1][0]*((1ll*dp[v][j][0][0]+1ll*dp[v][j][0][1])%mod)%mod)%mod;

dp[u][i+j][1][1]=(1ll*dp[u][i+j][1][1]+1ll*cpy[i][1][0]*((1ll*dp[v][j][1][0]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
dp[u][i+j][1][1]=(1ll*dp[u][i+j][1][1]+1ll*cpy[i][1][1]*((1ll*dp[v][j][0][0]+1ll*dp[v][j][0][1]+1ll*dp[v][j][1][0]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
}
}
siz[u]+=siz[v];
}
}
int main() {
n=read(),k=read();
for(int i=1,u,v;i<n;i++) {
u=read(),v=read();
g[u].push_back(v),g[v].push_back(u);
}
dfs(1,0);
ll ans=(1ll*dp[1][k][0][1]+1ll*dp[1][k][1][1])%mod;
printf("%lld\n",ans);
return 0;
}

P5643

听得最懂的一道黑题

给定一棵有根树,定义从 $u$ 进行一步随机游走为等概率移动到与 $u$ 相邻的所有 $v$。

有 $q$ 次询问,每次询问给出一个点集,求从根开始随机游走,经过点集中每个点至少一次的期望步数。$n\le18,q\le5000$。

我们记 $d_u$ 为 $u$ 的度。首先注意到对于一个已知的点集 $S$,记 $E(u)$ 为从根节点走到 $u$ 的期望步数,那么答案就是 $\displaystyle\max_{u\in S}E(u)$。但是这个 $\max$ 很不好做,可以利用 min-max 容斥 转换成求最小值:

这个式子是线性的,因此对期望也同样适用,因此我们得知

这样就转化成求到达点集中第一个点的期望步数。考虑写出 dp 方程之后消元求解。

发现到达点集之后就无法再移动,因此传统的设状态方法不再适用。我们设 $dp_{u,S}$ 表示从 $u$ 出发到达 $S$ 中最早一个点的期望时间。那么有转移方程

朴素的高斯消元是 $O(n^32^n)$ 的,过不了。

一个结论是在树上进行高斯消元时,可以将转移方程写成与父亲相关的函数,且系数是和子树相关的。在这里函数是一个简单的一次函数 $dp_{u,S}=k_{u,S}\cdot dp_{f,S}+b_{u,S}$。

我们来求解一下这个东西:

我们记

那么

这样我们就得出了对于一个给定的 $S$,$u$ 和父亲 $f$ 之间的转移关系。事实上我们可以由 $dp_{rt,S}=b_{rt,S}$ 来反推出所有的 $dp_{u,S}$,但是没有必要,因为我们只需要 $dp_{rt,S}$。

现在我们可以通过 $O(n\log n\cdot 2^n)$ 的预处理来求出对于所有的 $S$,$\min\limits_{u\in S}E(u)$ 的值。(这里的 $\log$ 是求逆元的复杂度)现在我们再预处理对于所有的 $S$,$\max\limits_{u\in S}E(u)$ 的值。

我们首先求出 $p_{S}$ 表示在 min-max 容斥中 $S$ 集合的系数。由于这个过程是个求和,因此我们用高维前缀和来进行优化,这样复杂度仍然是 $O(n2^n)$ 的。

最后我们进行询问,直接查表即可。总复杂度是 $O(n2^n)$(假设忽略掉求逆元的复杂度)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define mod 998244353
#define N 20
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll n,q,rt,deg[N],k[N],b[N],sumk[N],sumb[N];
vector<ll> g[N];
ll fac[1<<N],f[1<<N];
il ll qpow(ll a,ll b) {
ll ans=1;
while(b) {
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
il void dfs(int u,int fa,ll S) {
bool flg=(S>>u)&1;
sumk[u]=sumb[u]=k[u]=b[u]=0;
for(auto v:g[u]) {
if(v==fa) continue;
dfs(v,u,S);
if(!flg) sumk[u]=(sumk[u]+k[v])%mod,sumb[u]=(sumb[u]+b[v])%mod;
}
if(flg) return;
k[u]=qpow(((deg[u]-sumk[u])%mod+mod)%mod,mod-2);
b[u]=(deg[u]+sumb[u])%mod*k[u]%mod;
}
int main() {
n=read(),q=read(),rt=read()-1;
for(int i=1,u,v;i<n;i++) {
u=read()-1,v=read()-1;
g[u].push_back(v),g[v].push_back(u);
deg[u]++,deg[v]++;
}
fac[0]=mod-1;
for(ll S=1;S<(1<<n);S++) {
fac[S]=(fac[S>>1]*((S&1)?-1:1)%mod+mod)%mod;
dfs(rt,-1,S);
f[S]=b[rt]*fac[S]%mod;
}
for(int i=0;i<n;i++) {
for(ll S=1;S<(1<<n);S++) {
if((S>>i)&1) f[S]=(f[S]+f[S^(1<<i)])%mod;
}
}
while(q--) {
ll k=read(),S=0;
for(int i=1,u;i<=k;i++) {
u=read()-1;
S|=(1<<u);
}
printf("%lld\n",(f[S]%mod+mod)%mod);
}
return 0;
}

Trick: Stirling 反演

有些题里会有一些恶心的 $k$ 次幂,这时候就可以用 Stirling 公式来求值:

例子:CF1278F

给定 $n,m,k$,令 $p=\dfrac1m,q=1-m$,求

后面那个 $i^k$ 很恶心,考虑变换:

考虑求后面那个式子

因此原式就是

暴力递推第二类 Stirling 数并且 $O(n)$ 求下降幂,是 $O(k^2)$ 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define mod 998244353
#define N 5005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll S[N][N],n,m,k;
il ll qpow(ll a,ll b) {
ll ans=1;
while(b) {
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
int main() {
n=read(),m=read(),k=read();
S[0][0]=1;
for(ll i=1;i<=k;i++) {
for(ll j=0;j<=i;j++) {
S[i][j]=(S[i-1][j-1]+j*S[i-1][j]%mod)%mod;
}
}
ll p=qpow(m,mod-2),ans=0;
for(ll i=1,dpow=1,ppow=1;i<=k;i++) {
ppow=ppow*p%mod;
dpow=dpow*(n-i+1)%mod;
ans=(ans+S[k][i]*dpow%mod*ppow%mod)%mod;
}
printf("%lld\n",ans);
return 0;
}