0%

后缀数组学习笔记

SA 数组

后缀数组维护了字符串每个后缀的大小关系。利用这个数组可以维护很多厉害的信息。

但是首先要求出 $sa,rk$ 数组。

$sa_i$ 表示 $s[i\dots n]$ 这个后缀的大小排名,而 $rk_i$ 代表排名为 $i$ 的开头位置。显然有 $sa_{rk_i}=rk_{sa_i}=i$。

首先一个暴力算法是显然的:将所有后缀字符串排序。时间复杂度为 $O(n^2\log n)$。这个速度肯定不能被接受,需要优化,可以使用倍增算法来优化。

首先对长度为 $1$ 的子串排序,这个东西很简单,于是得到了所有长度为 $1$ 的子串大小关系 $p_i$。

接下来再对长度为 $2$ 的子串排序,方法是将相邻的两个长度 $1$ 子串合并起来成为 $c_i=(p_i,p_{i+1})$。接下来对 $c_i$ 进行双关键字排序即可得到长度为 $2$ 的子串大小关系。

为什么是对的呢?其实很简单,先比较字符串的前半段,因为已经得到了大小关系所以可以直接比较;如果前半段相同就比较后半段。

接下来如法炮制,得到 $4,8,16,\dots,2^w$ 直到长度比 $n$ 大。这样就得到了所有后缀的大小关系。这样时间复杂度是 $O(n\log^2n)$。

一个发现是上面的做法中使用了 sort 进行排序,是 $O(n\log n)$ 的,再乘上倍增的复杂度就是 $O(n\log^2n)$ 的。

一个简单的思路就是换一个排序,比如基数排序或者计数排序。这样的复杂度就是 $O(n\log n)$ 的了。

此时根据上面 $sa,rk$ 的关系可以求出 $rk$ 数组。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
il void solve() {
int m=127;
for(int i=1;i<=n;i++) x[i]=s[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) y[++tot]=sa[i]-w;
//第二关键字可以不进行基数排序。只需要在排序前按照排序后的顺序放进去就可以了。
//基数排序是稳定的排序算法。
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);tot=x[sa[1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w]) x[sa[i]]=tot;
else x[sa[i]]=++tot;
}
if(tot==n) break;
m=tot; //优化值域、提前退出来减小常数
}
for(int i=1;i<=n;i++) rk[sa[i]]=i;
}

height 数组

为啥要叫 height 捏

$\text{height}$ 数组记录了排序后相邻后缀的最长公共前缀,lcp 的长度。也就是说 $\text{height}_i=\operatorname{LCP}(sa_i,sa_{i-1})$。

如何求 $\text{height}$ 数组?一个显然的思路是暴力匹配,但是这东西是 $O(n^2)$ 的,太慢了。

而 $O(n)$ 求 $\text{height}$ 需要一个引理:

证明:当 $\text{height}_{rk_{i-1}}\le1$ 时显然成立。

如果 $\text{height}_{rk_{i-1}}>1$,则根据定义有 $\operatorname{LCP}(sa_{rk_{i-1}},sa_{rk_{i-1}-1})>1$,即 $\operatorname{LCP}(i-1,sa_{rk_{i-1}-1})>1$。

那么不妨设这个公共前缀为 $cA$,那么可以设后缀 $i-1=cAB$,$sa_{rk_{i-1}-1}=cAD$。这样 $i=AB$,并且有一个后缀 $sa_{rk_{i-1}-1}+1=AD$。

根据后缀数组的定义可以发现 $sa_{rk_i-1}$ 只比 $i$ 靠前一个,并且 $AD<AB$,所以 $AD\le sa_{rk_i-1}<AB$。

所以

可以得到 $\text{height}_{rk_i}\ge\text{height}_{rk_{i-1}}-1$。

得到这个引理之后就可以利用它求解 $\text{height}$ 了。

1
2
3
4
5
6
for(int i=1,k=0;i<=n;i++) {
if(rk[i]==0) continue;
if(k) k--;
while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
height[rk[i]]=k;
}

可以发现 $k$ 减少 $n$ 次,而 $k$ 还不会超过 $n$,因此这个算法时间复杂度为 $O(n)$。

题目

Milk Patterns G:出现至少 $k$ 次的子串的最大长度

给定一个数组,求一个最长的子数组使得它在大数组中出现了至少 $k$ 次。

事实上,如果一个子数组作为整个数组的 $k$ 后缀的公共前缀,那么它就出现了至少 $k$ 次,而可以将数组后缀排序后求出 $\text{height}$ 数组。

显然,离得更近的后缀的 lcp 长度会更大,所以最优的选择是选取 $\text{height}$ 数组的长度 $k-1$ 的区间的最小值。

于是使用单调队列取其最大值即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 20005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,k,m=1000005;
int a[N],cnt[1000005],x[N],y[N],sa[N],rk[N],height[N];
int que[N],head,tail;
il void solve() {
for(int i=1;i<=n;i++) x[i]=a[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) y[++tot]=sa[i]-w;
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y),tot=x[sa[1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w]) x[sa[i]]=tot;
else x[sa[i]]=++tot;
}
if(tot==n) return;
m=tot;
}
}
int main() {
n=read(),k=read();
for(int i=1;i<=n;i++) a[i]=read();
solve();
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1,k=0;i<=n;i++) {
if(rk[i]==0) continue;
if(k) k--;
while(a[i+k]==a[sa[rk[i]-1]+k]) k++;
height[rk[i]]=k;
}
head=tail=1,que[1]=1;
ll ans=0;
for(int i=2;i<=n;i++) {
while(head<=tail&&i-que[head]>=k-1) head++;
while(head<=tail&&height[i]<height[que[tail]]) tail--;
que[++tail]=i;
if(i>=k-1&&head<=tail) ans=max(ans,1ll*height[que[head]]);
}
printf("%lld\n",ans);
return 0;
}

Best Cow Line G:快速比较子串大小

给出一个含有 $n$ 个字符的双端队列,每次可以从队首或者队尾出队,求可能的出队队列中字典序最小的一个。

首先可以发现如果队头队尾的字符不同,那么一定先出小的那个。

考虑如果队头队尾的字符相同是什么情况。显然出队时还是要按照“让更小的字符更快出队”的原则。可以发现如果剩余字符串正着读字典序要比反着读的要小,那么就出头;否则出尾。

至于怎么快速比较字符串的大小,可以对字符串的正序和倒序分别建立后缀数组和 $\text{height}$ 数组。对于两个字符串 $A,B$(不妨设 $|A|\le|B|$),设 $l=\operatorname{LCP}(A,B)$,可以发现如果 $l\ge|A|$,那么 $A$ 一定是 $B$ 的前缀,显然 $A$ 更小;否则就比较对应的后缀大小(因为比较的不同点一定在两个字符串内)。这样的比较时间复杂度是 $O(1)$ 的。

总时间复杂度就是 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 1000005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m=127;
int x[N],y[N],cnt[N],sa[N],rk[N],height[N];
char s[N];
il void solve() {
for(int i=1;i<=n;i++) x[i]=s[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) y[++tot]=sa[i]-w;
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);tot=x[sa[1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w]) x[sa[i]]=tot;
else x[sa[i]]=++tot;
}
if(tot==n) break;
m=tot;
}
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1,k=0;i<=n;i++) {
if(rk[i]==0) continue;
if(k) k--;
while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
height[rk[i]]=k;
}
}
int st[N][21],Log[N];
il void build() {
for(int i=1;i<=n;i++) st[i][0]=height[i];
for(int j=1;j<=20;j++) {
for(int i=1;i+(1<<j)<=n;i++) st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
}
}
il int RMQ(int l,int r) {
int k=Log[r-l+1];
return min(st[l][k],st[r-(1<<k)][k]);
}
int main() {
n=read();
for(int i=1;i<=n;i++) {
scanf("%s",&s[i]);
s[2*n-i+2]=s[i];
}
s[n+1]='#',n=2*n+1;
for(int i=2;i<=n;i++) Log[i]=Log[i>>1]+1;
solve();
int p=1,q=n/2,cnt=0;
build();
for(int i=1;i<=n/2;i++) {
if(s[p]!=s[q]) {
if(s[p]<s[q]) putchar(s[p++]);
else putchar(s[q--]);
cnt++;
if(cnt==80) putchar('\n'),cnt=0;
}
else {
//比较[p,q]与[q,p]大小关系。
int pp=n-p+1,qq=n-q+1;
int lcp=RMQ(min(rk[qq],rk[p]),max(rk[qq],rk[p])),pd;
if(lcp==q-p+1) pd=0;
else if(rk[qq]<rk[p]) pd=1;
else pd=-1;
if(pd>=0) putchar(s[q--]);
else putchar(s[p++]);
cnt++;
if(cnt==80) putchar('\n'),cnt=0;
}
}
return 0;
}

优秀的拆分:连续的两个相同子串

对一个字符串的所有子串求 $\text{AABB}$ 拆分的个数。

首先转化一下,设 $f_i$ 为从字符串第 $i$ 位开始向前的“$AA$”类型字符串个数,$g_i$ 为向后的个数,那么答案就是

考虑枚举 $A$ 的长度为 $l$,并且设所有 $i$ 使得 $l|i$ 为关键点,那么一个合法的 $AA$ 一定能经过两个关键点。

比如说这种情况:

这上面有一个 $AA$ 能够匹配:

那么可以发现,其 lcp 和 lcs 有这样的关系:

可以发现,这一段的 lcp 必须等于下一段的 lcp,并且这一段的 lcs 必须等于上一段的 lcs。

不妨设这一段与下一段的 lcp 为 $\text{lcp}$,并且这一段与上一段的 lcs 为 $\text{lcs}$。那么显然 $\text{lcp}+\text{lcs}=l$ 时,有一组 $AA$ 匹配。事实上,如果 $\text{lcp}+\text{lcs}<l$ 那么不存在合法的匹配,并且如果 $\text{lcp}+\text{lcs}\ge l$,那么就有 $t=\text{lcp}+\text{lcs}-l+1$ 组合法的匹配。

注意到这里 $AA$ 出现的位置是一段区间,可以使用差分来区间加。

就做完了,时间复杂度为 $O(n\log n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 50005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,m,x[N],y[N],cnt[N],sa[2][N],rk[2][N],hei[2][N];
int st[2][N][17],Log[N];
char s[N];
il void solve(int id) {
for(int i=1;i<=n;i++) x[i]=s[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[id][cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[id][i]>w) y[++tot]=sa[id][i]-w;
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[id][cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);tot=x[sa[id][1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[id][i]]==y[sa[id][i-1]]&&y[sa[id][i]+w]==y[sa[id][i-1]+w]) x[sa[id][i]]=tot;
else x[sa[id][i]]=++tot;
}
if(tot==n) break;
m=tot;
}
for(int i=1;i<=n;i++) rk[id][sa[id][i]]=i;
for(int i=1,k=0;i<=n;i++) {
if(rk[id][i]==0) continue;
if(k) k--;
while(s[i+k]==s[sa[id][rk[id][i]-1]+k]) k++;
hei[id][rk[id][i]]=k;
}
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++) x[i]=y[i]=0;
for(int i=1;i<=n;i++) st[id][i][0]=hei[id][i];
for(int j=1;j<=16;j++) {
for(int i=1;i<=n;i++) st[id][i][j]=min(st[id][i][j-1],st[id][i+(1<<(j-1))][j-1]);
}
}
il int getlcp(int id,int l,int r) {
int _l=l,_r=r;
l=min(rk[id][_l],rk[id][_r])+1,r=max(rk[id][_l],rk[id][_r]);
int k=Log[r-l+1];
return min(st[id][l][k],st[id][r-(1<<k)+1][k]);
}
ll f[N],g[N];
void mian() {
memset(st,0x3f,sizeof(st));
scanf("%s",s+1);
n=strlen(s+1),m=127;
solve(0);
for(int i=1;i*2<=n;i++) swap(s[i],s[n-i+1]);
m=127;
solve(1);
for(int i=1;i*2<=n;i++) swap(s[i],s[n-i+1]);
for(int l=1;l<=n/2;l++) {
for(int i=l;i+l<=n;i+=l) {
int j=i+l;
int lcp=min(getlcp(0,i,j),l),lcs=min(getlcp(1,n-i+2,n-j+2),l-1);
if(lcp+lcs<l) continue;
int t=lcp+lcs-l+1;
g[i-lcs]++,g[i-lcs+t]--;
f[j+lcp-t]++,f[j+lcp]--;
}
}
for(int i=1;i<=n+1;i++) f[i]+=f[i-1],g[i]+=g[i-1];
ll ans=0;
for(int i=1;i<=n;i++) ans+=f[i]*g[i+1];
printf("%lld\n",ans);
}
il void clean() {
memset(f,0,sizeof(f));memset(g,0,sizeof(g));
memset(hei,0,sizeof(hei));memset(rk,0,sizeof(rk));
memset(sa,0,sizeof(sa));memset(st,0x3f,sizeof(st));
}
int main() {
for(int i=2;i<=30000;i++) Log[i]=Log[i>>1]+1;
int t=read();
while(t--) {
mian();clean();
}
return 0;
}

品酒大会:结合并查集

本质上是让你对每个 $r\in[0,n-1]$ 求长度为 $r$ 的相同子串对的个数以及两个相同子串开头权值乘积的最大值。

首先,因为如果两杯酒是 $r$ 相似可以推出它们是 $0,1,\dots,r-1$ 相似。所以只需求出 $\operatorname{LCP}(s,t)=r$ 的个数。

再次进行转化,对原字符串建立后缀数组与 $\text{height}$ 数组,因为 $\operatorname{LCP}(sa_i,sa_j)=\min\limits_{i<k\le j}\{\text{height}_k\}$,因此又转化为 $\text{height}$ 数组区间最小值为 $r$ 的区间个数。

最后答案的求解需要使用并查集。将 $\text{height}$ 数组排序后按照倒序依次填写。假设现在要填写 $i$ 位置的 $\text{height}$,可以发现 $i$ 左边、右边已经填写部分的 $\text{height}$ 值均大于等于 $\text{height}_i$,因此区间左右端点在左右边的区间全部有贡献。此时用并查集维护已被填写的区间长度即可。

这是第一问,第二问只需维护已被填写的区间最大最小值即可。(因为 $a_i$ 可能小于 $0$)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
#include<queue>
using namespace std;
#define ll long long
#define il inline
#define N 300005
#define PII pair<int,int>
#define mkpir make_pair
#define int long long
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,x[N],y[N],cnt[N];
int sa[N],rk[N],height[N];
char s[N];
il void solve() {
int m=127;
for(int i=1;i<=n;i++) x[i]=s[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) y[++tot]=sa[i]-w;
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);tot=x[sa[1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w]) x[sa[i]]=tot;
else x[sa[i]]=++tot;
}
if(tot==n) break;
m=tot;
}
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1,k=0;i<=n;i++) {
if(rk[i]==0) continue;
if(k) k--;
while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
height[rk[i]]=k;
}
}
vector<int> q[N];
int a[N],now,num=-1e18;
int ans[2][N];
struct node {
int fa,siz,maxn,minn;
};
node e[N];
il int findf(int x) {
while(x!=e[x].fa) x=e[x].fa=e[e[x].fa].fa;
return x;
}
il void Union(int x,int y) {
int fx=findf(x),fy=findf(y);
now+=e[fx].siz*e[fy].siz,num=max(num,max(e[fx].maxn*e[fy].maxn,e[fx].minn*e[fy].minn));
e[fy].fa=fx,e[fx].siz+=e[fy].siz,e[fx].maxn=max(e[fx].maxn,e[fy].maxn),e[fx].minn=min(e[fx].minn,e[fy].minn);
}
signed main() {
n=read();
scanf("%s",s+1);
for(int i=1;i<=n;i++) a[i]=read();
solve();
for(int i=1;i<=n;i++) e[i].fa=i,e[i].maxn=e[i].minn=a[sa[i]],e[i].siz=1,q[height[i]].push_back(i);
for(int i=n-1;i>=0;i--) {
for(auto p:q[i]) Union(p,p-1);
if(now) ans[0][i]=now,ans[1][i]=num;
}
for(int i=0;i<n;i++) printf("%lld %lld\n",ans[0][i],ans[1][i]);
return 0;
}

找相同字符差异:结合单调栈

差异那道题更简单一些。

差异

给定一个字符串 $S$,其长度为 $n$,定义 $T_i$ 为从第 $i$ 个字符开始的后缀,你需要求出

的值。

首先这个式子分成两部分,第一部分就是求 $\text{len}$ 的和,这一部分其实就是

所以只需要求出第二部分即可。

先考虑一个暴力做法:枚举后缀位置并且计算 $\text{lcp}$ 之和。利用 $\text{height}$ 数组可以将求解过程优化到 $O(n^2)$。

但是事实上,计算 $\text{lcp}$ 就是计算区间最小值,所以两两后缀 $\text{lcp}$ 之和其实就是 $\text{height}$ 数组所有区间最小值之和。

这个过程可以用单调栈维护,其时间复杂度为 $O(n)$。

找相同字符

本质上是类似的,但是这个题要求两个串来自不同的字符串,考虑使用容斥,先将两个字符串拼起来求答案,再拆开减去两个字符串各自的答案。

答案的求解仍然是后缀 $\text{lcp}$ 之和。

代码:(差异)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
#define ll long long
#define il inline
#define N 500005
il ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {if(c=='-') {f=-1;} c=getchar();}
while(c>='0'&&c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
ll n,m,x[N],y[N],cnt[N],sa[N],rk[N],height[N],be[N],ed[N],st[N],top;
char s[N];
il void solve() {
m=127;
for(int i=1;i<=n;i++) x[i]=s[i],cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[i]]--]=i;
for(int w=1;w<=n;w<<=1) {
int tot=0;
for(int i=n-w+1;i<=n;i++) y[++tot]=i;
for(int i=1;i<=n;i++) if(sa[i]>w) y[++tot]=sa[i]-w;
for(int i=1;i<=m;i++) cnt[i]=0;
for(int i=1;i<=n;i++) cnt[x[i]]++;
for(int i=2;i<=m;i++) cnt[i]+=cnt[i-1];
for(int i=n;i>=1;i--) sa[cnt[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);tot=x[sa[1]]=1;
for(int i=2;i<=n;i++) {
if(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+w]==y[sa[i-1]+w]) x[sa[i]]=tot;
else x[sa[i]]=++tot;
}
if(tot==n) break;
m=tot;
}
for(int i=1;i<=n;i++) rk[sa[i]]=i;
for(int i=1,k=0;i<=n;i++) {
if(rk[i]==0) continue;
if(k) k--;
while(s[i+k]==s[sa[rk[i]-1]+k]) k++;
height[rk[i]]=k;
}
}
il ll solve2() {
for(int i=1;i<=n;i++) be[i]=0,ed[i]=n+1;
top=0;
for(int i=1;i<=n;i++) {
while(top>0&&height[st[top]]>=height[i]) {
ed[st[top]]=i;
top--;
}
st[++top]=i;
}
top=0;
for(int i=n;i>=1;i--) {
while(top>0&&height[st[top]]>height[i]) {
be[st[top]]=i;
top--;
}
st[++top]=i;
}
ll ans=0;
for(int i=1;i<=n;i++) {
ll l=i-be[i],r=ed[i]-i;
ans+=l*r*height[i];
}
return ans;
}
int main() {
scanf("%s",s+1);n=strlen(s+1);
solve();
ll ans=solve2();
printf("%lld\n",n*(n-1)*(n+1)/2-2*ans);
return 0;
}