题意
给出一个串
思路
- 后缀数组 + 单调栈;
- 把
A 和所有Bi 拼接在一起,用间隔符分隔后做后缀排序并求出height 数组; - 一个字符串的每个后缀的每个前缀都是这个字符串的一个子串。后缀排序后,相似的后缀就会排在一起,也就是说,相同的子串也会被排在一起;
- 先找出
A 所有的本质不同的子串。可以用单调栈维护这个东西,方法类似 CF123D 。每从栈中弹出一个元素,就会得到一个区间,这个区间内的最长公共前缀长度为这个元素的height 。如果这个区间内没有B 的后缀,就说明这些子串没有在B 中出现,要统计它们的答案;否则不能统计这段区间的答案。
代码
已注释一些容易写错的地方。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN=330000; // 数组要开到 3 倍
const int MAXS=256;
int t, n, m, l;
char s[MAXN], a[MAXN];
int sa[MAXN], rnk[2*MAXN], rnk1[2*MAXN];
int ct[MAXS], cnt[MAXN], tmp[MAXN];
int height[MAXN], sum[MAXN], top;
struct Node
{
int pos, val;
Node(int p=0, int v=0): pos(p), val(v) {}
} sta[MAXN];
int main()
{
// freopen("hdu4416.in", "r", stdin);
// freopen("hdu4416.out", "w", stdout);
scanf("%d", &t);
for (int c=1; c<=t; c++)
{
scanf("%d%s", &m, a+1);
n=l=strlen(a+1);
for (int i=1; i<=m; i++)
{
a[++n]='#'; scanf("%s", s);
for (int j=0; s[j]; j++) a[++n]=s[j];
// 不要每次 strlen ,会超时
}
a[n+1]=0; // 一定要把这里赋成 '\0' ,否则 height 会出错
memset(ct, 0, sizeof ct);
memset(rnk, 0, sizeof rnk);
for (int i=1; i<=n; i++) ct[a[i]]=1;
for (int i=1; i<MAXS; i++) ct[i]+=ct[i-1];
for (int i=1; i<=n; i++) rnk[i]=ct[a[i]];
for (int k=0, p=1; p<=n&&k!=n; p<<=1)
{
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i+p]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) tmp[cnt[rnk[i+p]]--]=i;
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
memcpy(rnk1, rnk, sizeof rnk1);
rnk[sa[1]]=k=1;
for (int i=2; i<=n; i++)
{
if (rnk1[sa[i]]!=rnk1[sa[i-1]]||rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])
k++;
rnk[sa[i]]=k;
}
}
height[n+1]=0;
for (int i=1, k=0; i<=n; i++)
{
if (rnk[i]==1)
{
height[rnk[i]]=k=0;
continue;
}
if (--k<0) k=0;
while (a[i+k]==a[sa[rnk[i]-1]+k]) k++;
height[rnk[i]]=k;
}
long long ans=0; // 1E6*1E6 需要开 long long
sum[0]=0; top=0;
// 下面是做了一个前缀和, sum[r]-sum[l-1] 表示 [l,r] 中 A 的后缀的数量
for (int i=1; i<=n; i++) sum[i]=sum[i-1]+(sa[i]<=l);
for (int i=1; i<=n+1; i++)
{
int t=i;
while (top>0&&height[i]<sta[top-1].val)
{
t=min(t, sta[--top].pos);
if (sum[i-1]-sum[sta[top].pos-2]==i-sta[top].pos+1)
ans+=sta[top].val-max(top>0?sta[top-1].val:0, height[i]);
// 减掉被大区间统计过的答案
}
sta[top++]=Node(t, height[i]);
}
for (int i=1; i<=n; i++)
if (sa[i]<=l) ans+=max(l-sa[i]-max(height[i], height[i+1])+1, 0);
printf("Case %d: %lld\n", c, ans);
}
return 0;
}