题意
给出一个长度为
思路
- 求后缀之间的 LCP 之和,自然地想到使用后缀数组;
- 对字符串
s 后缀排序并求出height 数组,以下提供三种统计答案的方式:
一. 分块
- 分块思路简单,代码好写,用来混部分分十分容易;
- 考虑对于集合
B 中的每一个元素向集合A 统计答案; - 首先把集合
A 按rank 排序后分为约\sqrt n 个大小为\sqrt n 的块: - 对于集合
B 中的每一个元素到A 中的每一个块,计算它对答案的贡献: - 接下来要用到一个定理:
lcp(i,j)=min_{k={rank[i]+1}}^{rank[j]} height[k]\;(i<j) 。这个公式的意思是,两个后缀i,j\;(i<j) 的 LCP 等于从rank[i]+1 到rank[j] 的height 的最小值; - 有了这个定理,我们可以用一个ST表来维护区间
height 最小值,定义函数lcp(i,j) 表示后缀i,j 之间的 LCP 长度,则lcp(i,j) 可以 O(1) 求出; - 然后对于每个块,分别求出块中每一个元素到块的左右端点的 LCP :
- 对上面求的东西做一个前缀和,用途见下文;
- 统计集合
B 中的某一元素到A 的一个块的答案分以下三种情况:- 当前块左端点的
rank \leq 当前元素的rank \leq 当前块右端点的rank : 遇到这种情况,就暴力统计当前元素到块中每一个元素的 LCP ,并将其加到答案中。这是最简单的一种情况; - 当前元素的
rank < 当前块左端点的rank : 如上图,根据上文提到的定理,可以得出块的左端点到块中元素的 LCP 是递减的。先求出当前元素到块的左端点的 LCP 为d ,那么由定理得,该元素这个块的总答案就是\sum_{i\in block} min(d, lcp(block.left,i)) 。可以在块中二分查找出最后一个与左端点的 LCP 不小于d 的元素p ,则在这个块中,以p 为分界,左边部分(即紫色部分)的每一个元素的答案都是d ,答案总和就是d 乘上这部分的长度;右边部分(即橙色部分)的答案就是它与左端点的 LCP ,使用上文中我们预处理的前缀和即可快速统计; - 当前元素的
rank > 当前块右端点的rank : 同情况 2 ,预处理出块中每个元素到这个块的右端点的 LCP 的前缀和即可用类似方法统计答案。
- 当前块左端点的
复杂度分析
- 倍增法求后缀数组,时间复杂度
O(nlogn) ; - 集合
B 中有O(n) 个元素,集合A 分成O(\sqrt n) 个块,并对块做了一次二分查找,时间复杂度为O(n \sqrt n log \sqrt n) ; - 综上,分块方法的时间复杂度为
O(n \sqrt n log \sqrt n) ,不能通过本题 。
代码
// 后缀数组+分块 O(n sqrt(n) log sqrt(n))
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <functional>
using namespace std;
const int MAXN=220000;
const int MAXQ=500;
const int MAXS=256;
const int MAXB=20;
int n, q; char s[MAXN];
int a[MAXN], b[MAXN];
int sa[MAXN], rnk[2*MAXN], rnk1[2*MAXN];
int ct[MAXS], cnt[MAXN], tmp[MAXN];
int height[MAXN], lg[MAXN], h[MAXN][MAXB];
int lef[MAXQ], righ[MAXQ];
int pos[MAXN], cl[MAXN], cr[MAXN];
long long sl[MAXN], sr[MAXN];
inline bool cmp(int a, int b)
{
return rnk[a]<rnk[b];
}
inline int query(int l, int r)
{
int k=lg[r-l+1];
return min(h[l][k], h[r-(1<<k)+1][k]);
}
inline int lcp(int i, int j)
{
if (i==j) return n-i+1;
if (rnk[i]>rnk[j]) swap(i, j);
return query(rnk[i]+1, rnk[j]);
}
int main()
{
// freopen("CF1073G.in", "r", stdin);
// freopen("CF1073G.out", "w", stdout);
scanf("%d%d%s", &n, &q, s+1);
memset(ct, 0, sizeof ct);
memset(rnk, 0, sizeof rnk);
for (int i=1; i<=n; i++) ct[s[i]]=1;
for (int i=1; i<MAXS; i++) ct[i]+=ct[i-1];
for (int i=1; i<=n; i++) rnk[i]=ct[s[i]];
for (int k=0, p=1; k!=n; p<<=1)
{
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i+p]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) tmp[cnt[rnk[i+p]]--]=i;
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
memcpy(rnk1, rnk, sizeof rnk1);
rnk[sa[1]]=k=1;
for (int i=2; i<=n; i++)
{
if (rnk1[sa[i]]!=rnk1[sa[i-1]]||rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])
k++;
rnk[sa[i]]=k;
}
}
for (int i=1, k=0; i<=n; i++)
{
if (rnk[i]==1)
{
height[rnk[i]]=k=0;
continue;
}
if (--k<0) k=0;
while (s[i+k]==s[sa[rnk[i]-1]+k]) k++;
height[rnk[i]]=k;
}
lg[0]=0;
for (int i=1; i<=n; i++)
lg[i]=lg[i-1]+(1<<lg[i-1]+1==i);
for (int i=1; i<=n; i++) h[i][0]=height[i];
for (int j=1; 1<<j<=n; j++)
for (int i=1; i+(1<<j)-1<=n; i++)
h[i][j]=min(h[i][j-1], h[i+(1<<j-1)][j-1]);
while (q--)
{
int k, l;
scanf("%d%d", &k, &l);
for (int i=1; i<=k; i++) scanf("%d", &a[i]);
for (int i=1; i<=l; i++) scanf("%d", &b[i]);
sort(a+1, a+k+1, cmp);
int sz=sqrt(k), m=(k-1)/sz+1;
for (int i=1; i<=m; i++)
lef[i]=sz*(i-1)+1, righ[i]=min(sz*i, k);
for (int i=1; i<=k; i++) pos[i]=(i-1)/sz+1;
for (int i=1; i<=k; i++)
{
cl[i]=lcp(a[i], a[lef[pos[i]]]);
cr[i]=lcp(a[i], a[righ[pos[i]]]);
}
sl[0]=sr[0]=0;
for (int i=1; i<=k; i++) sl[i]=sl[i-1]+cl[i];
for (int i=1; i<=k; i++) sr[i]=sr[i-1]+cr[i];
long long ans=0;
for (int i=1; i<=l; i++)
for (int j=1; j<=m; j++)
{
if (rnk[b[i]]<rnk[a[lef[j]]])
{
int d=lcp(b[i], a[lef[j]]);
int p=upper_bound(cl+lef[j], cl+righ[j]+1, d, greater<int>())-cl;
ans+=1ll*d*(p-lef[j])+(sl[righ[j]]-sl[p-1]);
}
else if (rnk[b[i]]>rnk[a[righ[j]]])
{
int d=lcp(b[i], a[righ[j]]);
int p=lower_bound(cr+lef[j], cr+righ[j]+1, d)-cr-1;
ans+=1ll*d*(righ[j]-p)+(sr[p]-sr[lef[j]-1]);
}
else
for (int t=lef[j]; t<=righ[j]; t++)
ans+=lcp(b[i], a[t]);
}
printf("%I64d\n", ans);
}
return 0;
}
提交记录
二. 线段树
- 我们尝试对以上分块的做法进行一些改进。上面的方法之所以要分块,是因为要统计
B 中的元素到A 中rank 小于/大于它的rank 的元素的 LCP 的前缀和。那么我们可不可以做到动态维护这个东西呢? - 不妨把
B 中的元素也按rank 排序。对于B 排序后的第i 个元素,设它与A 中rank 小于等于它的元素k 的 LCP 为d[i][j] (j 表示k 按rank 排序在A 中排第j 名)。那么,根据上文提到的定理,从d[i] 转移到d[i+1] 时,只需要把所d[i] 中所有大于lcp(B[i],B[i+1]) 的都改成lcp(B[i],B[i+1]) (这一部分和分块中二分查找的那段实现的是同一个目的),再将d[i+1][j] 中加入A 中rank[B[i]]<rank[A[k]]\leq rank[B[i+1]] 的元素(j,k 含义见上文)就行了; - 我们发现以上维护
d[i] 的过程,需要支持两个操作:- 把所有大于
k 的数改为k ; - 单点修改一个元素。
- 把所有大于
- 容易发现,
d[i] 其实是一个单调的数组,也就是说,所有大于k 的数都会连在一起。也就是说,可以将其转化为区间修改操作; B[i] 的答案,就是d[i] 的和。因此,还需要支持一个区间求和的操作;- 综上,我们发现可以用线段树维护整个
d[i] 转移的操作。一颗线段树维护一层d[i] ,层层递推,即可统计所有B[i] 的答案; - 以上过程统计了
B[i] 与A[k]\;(rank[A[k]]<=rank[B[i]]) 的答案。再反向做一遍,即可类似地统计B[i] 与A[k]\;(rank[A[k]]>rank[B[i]]) 的答案。这两部分的答案相加,就是最终的答案。
复杂度分析
- 倍增法求后缀数组,时间复杂度
O(nlogn) ; - 线段树的区间修改和单点修改,都是
O(logn) 的。一共要修改n 次,所以总时间复杂度为O(nlogn) ,可以通过本题。
代码
// 后缀数组+线段树 O(nlogn)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN=220000;
const int MAXS=256;
const int MAXB=20;
namespace ST
{
struct Node
{
int l, r;
long long sum;
int maxx, minx, cov;
} tr[4*MAXN];
struct SegmentTree
{
#define lc (o<<1)
#define rc (o<<1|1)
inline void assign(int o, int k)
{
tr[o].sum=1ll*k*(tr[o].r-tr[o].l+1);
tr[o].maxx=tr[o].minx=k;
}
inline void pushup(int o)
{
tr[o].sum=tr[lc].sum+tr[rc].sum;
tr[o].maxx=max(tr[lc].maxx, tr[rc].maxx);
tr[o].minx=min(tr[lc].minx, tr[rc].minx);
}
inline void cover(int o, int k)
{
tr[o].cov=k;
assign(o, k);
}
inline void pushdown(int o)
{
if (tr[o].cov==-1) return;
cover(lc, tr[o].cov);
cover(rc, tr[o].cov);
tr[o].cov=-1;
}
void build(int o, int l, int r)
{
tr[o].l=l; tr[o].r=r; tr[o].cov=-1;
if (l==r) { assign(o, 0); return; }
int mid=l+r>>1;
build(lc, l, mid);
build(rc, mid+1, r);
pushup(o);
}
void change(int o, int x, int k)
{
if (tr[o].l==tr[o].r) { assign(o, k); return; }
pushdown(o);
int mid=tr[o].l+tr[o].r>>1;
if (x<=mid) change(lc, x, k);
else change(rc, x, k);
pushup(o);
}
void update(int o, int k)
{
if (tr[o].maxx<=k) return;
if (tr[o].minx>k) { cover(o, k); return; }
pushdown(o);
update(lc, k);
update(rc, k);
pushup(o);
}
inline long long query(int o)
{
return tr[o].sum;
}
#undef lc
#undef rc
};
}
int n, q; char s[MAXN];
int a[MAXN], b[MAXN];
int sa[MAXN], rnk[2*MAXN], rnk1[2*MAXN];
int ct[MAXS], cnt[MAXN], tmp[MAXN];
int height[MAXN], lg[MAXN], h[MAXN][MAXB];
ST::SegmentTree st;
inline bool cmp(int a, int b)
{
return rnk[a]<rnk[b];
}
inline int query(int l, int r)
{
int k=lg[r-l+1];
return min(h[l][k], h[r-(1<<k)+1][k]);
}
inline int lcp(int i, int j)
{
if (i==j) return n-i+1;
if (rnk[i]>rnk[j]) swap(i, j);
return query(rnk[i]+1, rnk[j]);
}
int main()
{
// freopen("CF1073G.in", "r", stdin);
// freopen("CF1073G.out", "w", stdout);
scanf("%d%d%s", &n, &q, s+1);
memset(ct, 0, sizeof ct);
memset(rnk, 0, sizeof rnk);
for (int i=1; i<=n; i++) ct[s[i]]=1;
for (int i=1; i<MAXS; i++) ct[i]+=ct[i-1];
for (int i=1; i<=n; i++) rnk[i]=ct[s[i]];
for (int k=0, p=1; k!=n; p<<=1)
{
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i+p]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) tmp[cnt[rnk[i+p]]--]=i;
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
memcpy(rnk1, rnk, sizeof rnk1);
rnk[sa[1]]=k=1;
for (int i=2; i<=n; i++)
{
if (rnk1[sa[i]]!=rnk1[sa[i-1]]||rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])
k++;
rnk[sa[i]]=k;
}
}
for (int i=1, k=0; i<=n; i++)
{
if (rnk[i]==1)
{
height[rnk[i]]=k=0;
continue;
}
if (--k<0) k=0;
while (s[i+k]==s[sa[rnk[i]-1]+k]) k++;
height[rnk[i]]=k;
}
lg[0]=0;
for (int i=1; i<=n; i++)
lg[i]=lg[i-1]+(1<<lg[i-1]+1==i);
for (int i=1; i<=n; i++) h[i][0]=height[i];
for (int j=1; 1<<j<=n; j++)
for (int i=1; i+(1<<j)-1<=n; i++)
h[i][j]=min(h[i][j-1], h[i+(1<<j-1)][j-1]);
while (q--)
{
int k, l;
scanf("%d%d", &k, &l);
for (int i=1; i<=k; i++) scanf("%d", &a[i]);
for (int i=1; i<=l; i++) scanf("%d", &b[i]);
sort(a+1, a+k+1, cmp);
sort(b+1, b+l+1, cmp);
long long ans=0;
st.build(1, 1, k);
for (int i=1, j=1; i<=l; i++)
{
st.update(1, lcp(b[i], b[i-1]));
while (j<=k&&rnk[a[j]]<=rnk[b[i]])
st.change(1, j, lcp(b[i], a[j])), j++;
ans+=st.query(1);
}
st.build(1, 1, k);
for (int i=l, j=k; i>=1; i--)
{
st.update(1, lcp(b[i], b[i+1]));
while (j>=1&&rnk[a[j]]>rnk[b[i]])
st.change(1, j, lcp(b[i], a[j])), j--;
ans+=st.query(1);
}
printf("%I64d\n", ans);
}
return 0;
}
提交记录
三. 单调栈
- 换一个思路,单调栈也可以解决本题;
- 本题单调栈的做法与 POJ3415 类似,可以先尝试一下这道题;
- 把集合
A 与集合B 放在一起得到多重集合C ,并将C 按rank 排序; - 为了方便,我们定义
hgt[i]=lcp(C[i], C[i-1]) 。其实这个hgt 数组的意义和height 数组是类似的; - 维护一个单调栈。栈中的每个元素是一个二元组
(i,j) ,表示这个元素是按rank 排序后的第i 名,最前面的hgt 值大于hgt[i] 的元素是C[j] 。这个单调栈中的元素的hgt[i] 的值是递增的。 - 若发现当前元素
C[i] 的hgt 值小于栈顶元素的hgt 值,则将栈顶元素弹出。设栈顶元素为C[top] ,则[C[top].j-1,C[top].i-1] 中每一个元素与[C[top].i,i-1] 中的每一个元素的 LCP 都是hgt[C[top].i] 。根据乘法原理,两个区间的长度相乘再乘上 LCP 就是这个区间对答案的贡献。最后,将当前元素入栈,它的j 就是这次弹出的元素的j 的最小值; - 形象一点,单调栈其实维护的是这个东西:
- 在上图中,最后一个元素把第五个元素弹出栈,第五个元素维护的就是红色的那个区间。还是根据那个定理,底下绿色区间中的每一个元素与蓝色区间的每一个元素的 LCP 都是第五个元素的高度(即
hgt[5] )。这个区间对答案的贡献就是3*3*hgt[5] ; - 最后不要忘记把栈中元素弹尽,把答案统计完整。
复杂度分析
- 倍增法求后缀数组,时间复杂度
O(nlogn) ; - 排序
O(nlogn) ; - 单调栈中每个元素最多进栈一次出栈一次,因此均摊复杂度为
O(n) ; - 综上,总时间复杂度为
O(nlogn) 。
代码
// 后缀数组+单调栈 O(nlogn)
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN=220000;
const int MAXS=256;
const int MAXB=20;
int n, q; char s[MAXN];
int a[MAXN], b[MAXN], c[2*MAXN]; bool d[2*MAXN];
int sa[MAXN], rnk[2*MAXN], rnk1[2*MAXN];
int ct[MAXS], cnt[MAXN], tmp[MAXN];
int height[MAXN], lg[MAXN], h[MAXN][MAXB];
long long sum[2*MAXN];
struct Node
{
int id, pos;
Node(int i=0, int p=0): id(i), pos(p) {}
} sta[2*MAXN];
int top;
inline bool cmp(int a, int b)
{
return rnk[a]<rnk[b];
}
inline int query(int l, int r)
{
int k=lg[r-l+1];
return min(h[l][k], h[r-(1<<k)+1][k]);
}
inline int lcp(int i, int j)
{
if (i==j) return n-i+1;
if (rnk[i]>rnk[j]) swap(i, j);
return query(rnk[i]+1, rnk[j]);
}
int main()
{
// freopen("CF1073G.in", "r", stdin);
// freopen("CF1073G.out", "w", stdout);
scanf("%d%d%s", &n, &q, s+1);
memset(ct, 0, sizeof ct);
memset(rnk, 0, sizeof rnk);
for (int i=1; i<=n; i++) ct[s[i]]=1;
for (int i=1; i<MAXS; i++) ct[i]+=ct[i-1];
for (int i=1; i<=n; i++) rnk[i]=ct[s[i]];
for (int k=0, p=1; k!=n; p<<=1)
{
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i+p]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) tmp[cnt[rnk[i+p]]--]=i;
memset(cnt, 0, sizeof cnt);
for (int i=1; i<=n; i++) cnt[rnk[i]]++;
for (int i=1; i<=n; i++) cnt[i]+=cnt[i-1];
for (int i=n; i>=1; i--) sa[cnt[rnk[tmp[i]]]--]=tmp[i];
memcpy(rnk1, rnk, sizeof rnk1);
rnk[sa[1]]=k=1;
for (int i=2; i<=n; i++)
{
if (rnk1[sa[i]]!=rnk1[sa[i-1]]||rnk1[sa[i]+p]!=rnk1[sa[i-1]+p])
k++;
rnk[sa[i]]=k;
}
}
for (int i=1, k=0; i<=n; i++)
{
if (rnk[i]==1)
{
height[rnk[i]]=k=0;
continue;
}
if (--k<0) k=0;
while (s[i+k]==s[sa[rnk[i]-1]+k]) k++;
height[rnk[i]]=k;
}
lg[0]=0;
for (int i=1; i<=n; i++)
lg[i]=lg[i-1]+(1<<lg[i-1]+1==i);
for (int i=1; i<=n; i++) h[i][0]=height[i];
for (int j=1; 1<<j<=n; j++)
for (int i=1; i+(1<<j)-1<=n; i++)
h[i][j]=min(h[i][j-1], h[i+(1<<j-1)][j-1]);
while (q--)
{
int k, l;
scanf("%d%d", &k, &l);
for (int i=1; i<=k; i++) scanf("%d", &a[i]);
for (int i=1; i<=l; i++) scanf("%d", &b[i]);
sort(a+1, a+k+1, cmp);
sort(b+1, b+l+1, cmp);
for (int i=1, j=1, t=1; i<=k||j<=l; t++)
{
if (i<=k&&(j>l||rnk[a[i]]<=rnk[b[j]]))
c[t]=a[i++], d[t]=0;
else if (j<=l&&(i>k||rnk[a[i]]>rnk[b[j]]))
c[t]=b[j++], d[t]=1;
}
long long ans=0;
sum[0]=0; top=0;
for (int i=1; i<=k+l; i++) sum[i]=sum[i-1]+d[i];
for (int i=1; i<=k+l; i++)
{
int t=i;
while (top>0&&lcp(c[i], c[i-1])<lcp(c[sta[top-1].id], c[sta[top-1].id-1]))
{
t=min(t, sta[--top].pos);
ans+=(1ll*(sum[sta[top].id-1]-sum[sta[top].pos-2])
*((i-sta[top].id)-(sum[i-1]-sum[sta[top].id-1]))
+1ll*((sta[top].id-sta[top].pos+1)-(sum[sta[top].id-1]-sum[sta[top].pos-2]))
*(sum[i-1]-sum[sta[top].id-1]))*lcp(c[sta[top].id], c[sta[top].id-1]);
}
sta[top++]=Node(i, t);
}
while (top>0)
{
top--;
ans+=(1ll*(sum[sta[top].id-1]-sum[sta[top].pos-2])
*((k+l-sta[top].id+1)-(sum[k+l]-sum[sta[top].id-1]))
+1ll*((sta[top].id-sta[top].pos+1)-(sum[sta[top].id-1]-sum[sta[top].pos-2]))
*(sum[k+l]-sum[sta[top].id-1]))*lcp(c[sta[top].id], c[sta[top].id-1]);
}
printf("%I64d\n", ans);
}
return 0;
}
提交记录
总结
后缀数组可以与多种算法/数据结构结合,比如
- 二分
- 分块
- 单调队列/单调栈
- RMQ数据结构(ST表)
- 树状数组/线段树
灵活应用算法/数据结构和