0%

dp选讲

7.11dp by 一休哥777 orz

P2051

在一个 n×m 的棋盘上放置若干个互不攻击的炮,求摆放的方案数。n,m100

首先发现炮不能互相攻击等价于每行每列最多有 2 个炮。

考虑设 dp 状态 dpi,j,k 表示当前到达第 i 行,且有 j 列放置两个炮、k 列放置一个炮。那么相应的放置 0 个炮的列数就是 mjk

在考虑 i 行的时候考虑这一行放置几个炮:

  • 放置 0 个,dpi,j,kdpi1,j,k
  • 放置 1 个,那么这一个炮可能放在原来的 0 炮列,也可以放在 1 炮列,那么dpi,j,k(mjk+1)dpi1,j,k1+(k+1)dpi1,j1,k+1
  • 放置两个,相似地我们可以选择放置的列炮的数量。
    • 放在一个有炮、一个没有炮的列,那么dpi,j,k(mjk+1)jdpi1,j1,k
    • 都放在没有炮的列,此时要在空列里选择两个:dpi,j,k(mjk+22)dpi1,j,k2
    • 都放在有一个炮的列里,此时要在有一个炮的列里选择两个:dpi,j,k(k+22)dpi1,j2,k+2

然后,就做完了。时间复杂度 O(n3)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 105
#define mod 9999973
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll n,m,dp[N][N][N];
il ll C(ll x) {
return x*(x-1)/2%mod;
}
int main() {
n=read(),m=read();
dp[0][0][0]=1;
for(int i=1;i<=n;i++) {
for(int j=0;j<=m;j++) {
for(int k=0;k<=m-j;k++) {
dp[i][j][k]=dp[i-1][j][k];
if(k>=1) dp[i][j][k]=(dp[i][j][k]+(m-j-k+1)*dp[i-1][j][k-1]%mod)%mod;
if(j>=1) dp[i][j][k]=(dp[i][j][k]+(k+1)*dp[i-1][j-1][k+1]%mod)%mod;
if(j>=1) dp[i][j][k]=(dp[i][j][k]+(m-j-k+1)*k%mod*dp[i-1][j-1][k]%mod)%mod;
if(k>=2) dp[i][j][k]=(dp[i][j][k]+C(m-j-k+2)*dp[i-1][j][k-2]%mod)%mod;
if(j>=2) dp[i][j][k]=(dp[i][j][k]+C(k+2)*dp[i-1][j-2][k+2]%mod)%mod;
}
}
}
ll ans=0;
for(int j=0;j<=m;j++) for(int k=0;k<=m;k++) ans=(ans+dp[n][j][k])%mod;
printf("%lld\n",ans);
return 0;
}

P7961

给定 n,m,k 和长度为 m+1 的正整数数组 v0,v1,,vm。对于一个长度为 n,下标从 1 开始且每个元素均不超过 m 的非负整数序列 {ai},我们定义它的权值为 va1×va2××van。当这样的序列 {ai} 满足整数 S=2a1+2a2++2an 的二进制表示中 1 的个数不超过 k 时,我们认为 {ai} 是一个合法序列。计算所有合法序列 {ai} 的权值和对 998244353 取模的结果。n,k30,m100

首先我们想一下怎么求出 popcount(S)。我们记 cntia 数组中 i 出现的次数,我们从低到高枚举 i,让 ans 加上 cntimod2,并且让 cnti+1 加上 cnti2(其实就是模拟二进制进位)。这样就可以在 O(m) 的复杂度内求出 popcount(S)

这个顺序启示我们从 0m 的顺序 dp。设 dpi,j,k,l 表示考虑 S0i 位且填了 ja 数组的元素,S 总共有 k1 且要向下一位进位 l 的总权值。

接下来我们枚举有 tap=i,那么就会有 t+l1 会进到 S 的下一位,根据上面求 popcount(S) 的方法我们可以得知

?dpi,j,k,ldpi+1,j+t,k+(t+l)mod2,t+l2

考虑那个 ? 是什么,显然就是 i 这个数的贡献 vit 乘上选择 t 个点放置的方案数 (njt)。因此

(njt)vitdpi,j,k,ldpi+1,j+t,k+(t+l)mod2,t+l2

这就是最后的转移方程。

注意:在 dp 到 m 之后可能会有剩余的数要进位,因此要考虑能不能统计入答案。

P4516

给定一棵树,可以对恰好 k 个点标记,要求每个点都需要有至少一个被标记的点与它相邻(自己不算)。问可行的方案数。n105,k100

设状态 dpu,i,0/1,0/1 表示考虑 u 子树内恰有 i 个点被标记、u 是/否被标记、u 是/否已有标记点与它相邻。

转移类似一个树形背包的形式,即将 v 合并到 u 上。接下来开始分讨

  • dpu,i+j,0,0。那么 v 必须不能被标记、而且已经被覆盖。从 dpu,i,0,0dpv,j,0,1 转移。
  • dpu,i+j,0,1。分两种情况:
    • u 之前已经被覆盖,那么 v 无论被不被标记都可以。
    • u 之前没被覆盖,那么 v 必须被标记。
  • dpu,i+j,1,0,那么 v 不能被标记,从 dpu,i,1,0(dpv,j,0,0+dpv,j,0,1) 转移。
  • dpu,i+j,1,1,又从两种情况讨论:
    • u 没被覆盖,和上面类似;
    • u 已经被覆盖,那么 v 没有任何限制。

想明白了代码也非常不好写。时间复杂度是树上背包的复杂度,可以证明是 O(nk)

这个题如果 dp 数组开 long long 会 MLE。另外要存一个 cpy 数组保存之前数组的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 100005
#define K 105
#define mod 1000000007
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,k,siz[N];
int dp[N][K][2][2];ll cpy[K][2][2];
vector<int> g[N];
il void dfs(int u,int fa) {
siz[u]=dp[u][0][0][0]=dp[u][1][1][0]=1;
for(auto v:g[u]) {
if(v==fa) continue;
dfs(v,u);
for(int i=0;i<=min(siz[u],k);i++) {
cpy[i][0][0]=dp[u][i][0][0];
cpy[i][0][1]=dp[u][i][0][1];
cpy[i][1][0]=dp[u][i][1][0];
cpy[i][1][1]=dp[u][i][1][1];
dp[u][i][0][0]=dp[u][i][0][1]=dp[u][i][1][0]=dp[u][i][1][1]=0;
}
for(int i=0;i<=min(siz[u],k);i++) {
for(int j=0;j<=min(siz[v],k-i);j++) {
dp[u][i+j][0][0]=(1ll*dp[u][i+j][0][0]+1ll*cpy[i][0][0]*dp[v][j][0][1]%mod)%mod;

dp[u][i+j][0][1]=(1ll*dp[u][i+j][0][1]+1ll*cpy[i][0][1]*((1ll*dp[v][j][0][1]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
dp[u][i+j][0][1]=(1ll*dp[u][i+j][0][1]+1ll*cpy[i][0][0]*dp[v][j][1][1]%mod)%mod;

dp[u][i+j][1][0]=(1ll*dp[u][i+j][1][0]+1ll*cpy[i][1][0]*((1ll*dp[v][j][0][0]+1ll*dp[v][j][0][1])%mod)%mod)%mod;

dp[u][i+j][1][1]=(1ll*dp[u][i+j][1][1]+1ll*cpy[i][1][0]*((1ll*dp[v][j][1][0]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
dp[u][i+j][1][1]=(1ll*dp[u][i+j][1][1]+1ll*cpy[i][1][1]*((1ll*dp[v][j][0][0]+1ll*dp[v][j][0][1]+1ll*dp[v][j][1][0]+1ll*dp[v][j][1][1])%mod)%mod)%mod;
}
}
siz[u]+=siz[v];
}
}
int main() {
n=read(),k=read();
for(int i=1,u,v;i<n;i++) {
u=read(),v=read();
g[u].push_back(v),g[v].push_back(u);
}
dfs(1,0);
ll ans=(1ll*dp[1][k][0][1]+1ll*dp[1][k][1][1])%mod;
printf("%lld\n",ans);
return 0;
}

P5643

听得最懂的一道黑题

给定一棵有根树,定义从 u 进行一步随机游走为等概率移动到与 u 相邻的所有 v

q 次询问,每次询问给出一个点集,求从根开始随机游走,经过点集中每个点至少一次的期望步数。n18,q5000

我们记 duu 的度。首先注意到对于一个已知的点集 S,记 E(u) 为从根节点走到 u 的期望步数,那么答案就是 maxuSE(u)。但是这个 max 很不好做,可以利用 min-max 容斥 转换成求最小值:

maxxSf(x)=TS(1)|T|+1minxTf(x)

这个式子是线性的,因此对期望也同样适用,因此我们得知

maxuSE(u)=TS(1)|T|+1minuTE(u)

这样就转化成求到达点集中第一个点的期望步数。考虑写出 dp 方程之后消元求解。

发现到达点集之后就无法再移动,因此传统的设状态方法不再适用。我们设 dpu,S 表示从 u 出发到达 S 中最早一个点的期望时间。那么有转移方程

dpu,S={0,uS1+1duvdpv,S,uS

朴素的高斯消元是 O(n32n) 的,过不了。

一个结论是在树上进行高斯消元时,可以将转移方程写成与父亲相关的函数,且系数是和子树相关的。在这里函数是一个简单的一次函数 dpu,S=ku,Sdpf,S+bu,S

我们来求解一下这个东西:

dpu,S=1du(dpf,S+v is sondpv,S)+1dpu,S=1du(dpf,S+v is sonkv,Sdpu,S+bv,s)+1

我们记

Ku,S=v is sonkv,s,Bu,S=v is sonbv,s

那么

dudpu,S=dpf,S+Ku,Sdpu,S+Bu,S+dudpu,S=1duKu,Sdpf,S+du+Bu,SduKu,S

这样我们就得出了对于一个给定的 Su 和父亲 f 之间的转移关系。事实上我们可以由 dprt,S=brt,S 来反推出所有的 dpu,S,但是没有必要,因为我们只需要 dprt,S

现在我们可以通过 O(nlogn2n) 的预处理来求出对于所有的 SminuSE(u) 的值。(这里的 log 是求逆元的复杂度)现在我们再预处理对于所有的 SmaxuSE(u) 的值。

我们首先求出 pS 表示在 min-max 容斥中 S 集合的系数。由于这个过程是个求和,因此我们用高维前缀和来进行优化,这样复杂度仍然是 O(n2n) 的。

最后我们进行询问,直接查表即可。总复杂度是 O(n2n)(假设忽略掉求逆元的复杂度)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define mod 998244353
#define N 20
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll n,q,rt,deg[N],k[N],b[N],sumk[N],sumb[N];
vector<ll> g[N];
ll fac[1<<N],f[1<<N];
il ll qpow(ll a,ll b) {
ll ans=1;
while(b) {
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
il void dfs(int u,int fa,ll S) {
bool flg=(S>>u)&1;
sumk[u]=sumb[u]=k[u]=b[u]=0;
for(auto v:g[u]) {
if(v==fa) continue;
dfs(v,u,S);
if(!flg) sumk[u]=(sumk[u]+k[v])%mod,sumb[u]=(sumb[u]+b[v])%mod;
}
if(flg) return;
k[u]=qpow(((deg[u]-sumk[u])%mod+mod)%mod,mod-2);
b[u]=(deg[u]+sumb[u])%mod*k[u]%mod;
}
int main() {
n=read(),q=read(),rt=read()-1;
for(int i=1,u,v;i<n;i++) {
u=read()-1,v=read()-1;
g[u].push_back(v),g[v].push_back(u);
deg[u]++,deg[v]++;
}
fac[0]=mod-1;
for(ll S=1;S<(1<<n);S++) {
fac[S]=(fac[S>>1]*((S&1)?-1:1)%mod+mod)%mod;
dfs(rt,-1,S);
f[S]=b[rt]*fac[S]%mod;
}
for(int i=0;i<n;i++) {
for(ll S=1;S<(1<<n);S++) {
if((S>>i)&1) f[S]=(f[S]+f[S^(1<<i)])%mod;
}
}
while(q--) {
ll k=read(),S=0;
for(int i=1,u;i<=k;i++) {
u=read()-1;
S|=(1<<u);
}
printf("%lld\n",(f[S]%mod+mod)%mod);
}
return 0;
}

Trick: Stirling 反演

有些题里会有一些恶心的 k 次幂,这时候就可以用 Stirling 公式来求值:

nm=i=1m{mi}ni

例子:CF1278F

给定 n,m,k,令 p=1m,q=1m,求

i=1n(ni)piqniikmod998244353

后面那个 ik 很恶心,考虑变换:

i=0n(ni)piqniik=i=0n(ni)piqnij=1k{kj}ij=j=1k{kj}i=0n(ni)piqniij

考虑求后面那个式子

i=0n(ni)piqniij=i=jn(ni)piqni(ij)j!=j!i=jnn!i!(ni)!i!j!(ij)!piqni=j!(nj)i=jn(njij)piqni=nji=0nj(nji)pi+jqnij=nj(p+q)njpj=njpj

因此原式就是

j=1k{kj}njpj

暴力递推第二类 Stirling 数并且 O(n) 求下降幂,是 O(k2) 的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define mod 998244353
#define N 5005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll S[N][N],n,m,k;
il ll qpow(ll a,ll b) {
ll ans=1;
while(b) {
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
int main() {
n=read(),m=read(),k=read();
S[0][0]=1;
for(ll i=1;i<=k;i++) {
for(ll j=0;j<=i;j++) {
S[i][j]=(S[i-1][j-1]+j*S[i-1][j]%mod)%mod;
}
}
ll p=qpow(m,mod-2),ans=0;
for(ll i=1,dpow=1,ppow=1;i<=k;i++) {
ppow=ppow*p%mod;
dpow=dpow*(n-i+1)%mod;
ans=(ans+S[k][i]*dpow%mod*ppow%mod)%mod;
}
printf("%lld\n",ans);
return 0;
}