题意
给出两个字符串
思路
- 后缀数组 + 单调栈;
- 把两个字符串连接起来,中间用分隔符隔开,做后缀排序;
- 把后缀分为两类,
sa[i]\leq n (即A 串的后缀)分为一类,其余(即B 串的后缀)分为一类; - 维护单调递增的单调栈。单调栈中的每一个元素都表示了一段长度为自身
height 的公共前缀; - 设这个区间是
[l, r] ,栈顶元素的位置为mid ,它的height 为h ,则区间[l, mid) 中的任意后缀与区间[mid, r] 中的任意后缀的 LCP 为h 。每个不同类的后缀都会对答案做出h-k+1 的贡献,所以这段区间的总贡献为:
- 其中一段区间内
A 串与B 串的后缀个数可以用前缀和预处理; - 答案和统计贡献时的乘法都要开 long long 。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN=220000;
const int MAXS=256;
int n, m, l; char a[MAXN];
int sa[MAXN], rnk[2*MAXN], rnk1[2*MAXN];
int st[MAXN], cnt[MAXN], tmp[MAXN];
int height[MAXN], sum[MAXN], top;
struct Node
{
int id, pos;
Node(int i=0, int p=0): id(i), pos(p) {}
} sta[MAXN];
int main()
{
// freopen("poj3415.in", "r", stdin);
// freopen("poj3415.out", "w", stdout);
while (~scanf("%d", &m)&&m)
{
scanf("%s", a+1); l=n=strlen(a+1); a[l+1]='#';
scanf("%s", a+l+2); l=strlen(a+1);
memset(st, 0, sizeof st);
memset(rnk, 0, sizeof rnk);
for (int i=1; i<=l; i++) st[a[i]]=1;
for (int i=1; i<=MAXS; i++) st[i]+=st[i-1];
for (int i=1; i<=l; i++) rnk[i]=st[a[i]];
for (int k=0, p=1; k!=l; p<<=1)
{
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=l; i++) cnt[rnk[i+p]]++;
for (int i=1; i<=l; i++) cnt[i]+=cnt[i-1];
for (int i=l; i>=1; i--) tmp[cnt[rnk[i+p]]--]=i;
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=l; i++) cnt[rnk[i]]++;
for (int i=1; i<=l; i++) cnt[i]+=cnt[i-1];
for (int i=l; i>=1; i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
memcpy(rnk1, rnk, sizeof rnk1);
rnk[sa[1]]=k=1;
for (int i=2; i<=l; i++)
{
if (rnk1[sa[i]]!=rnk1[sa[i-1]]||rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])
k++;
rnk[sa[i]]=k;
}
}
height[l+1]=0;
for (int i=1, k=0; i<=l; i++)
{
if (rnk[i]==1)
{
height[rnk[i]]=k=0;
continue;
}
if (--k<0) k=0;
while (a[i+k]==a[sa[rnk[i]-1]+k]) k++;
height[rnk[i]]=k;
}
sum[0]=0;
for (int i=1; i<=l; i++) sum[i]=sum[i-1]+(sa[i]<=n);
long long ans=0;
top=0;
for (int i=1; i<=l+1; i++)
{
int t=i;
while (top>0&&height[i]<height[sta[top-1].id])
{
t=min(t, sta[--top].pos);
if (height[sta[top].id]>=m)
ans+=(1ll*(sum[sta[top].id-1]-sum[sta[top].pos-2])
*((i-sta[top].id)-(sum[i-1]-sum[sta[top].id-1]))
+(1ll*(sta[top].id-sta[top].pos+1)-(sum[sta[top].id-1]-sum[sta[top].pos-2]))
*(sum[i-1]-sum[sta[top].id-1]))*(height[sta[top].id]-m+1);
}
sta[top++]=Node(i, t);
}
printf("%lld\n", ans);
}
return 0;
}