冷滟泽的个人博客冷滟泽的个人博客

二元多项式全家桶

因为其实和一元多项式的各种运算相差不大,所以这里就简单写了。。

引入

二元多项式 A(x,y),可以看作一个每一项都是一个关于 y 的级数的关于 x 的多项式,即

A(x,y)=\sum_{i=0}^{n-1}a_i(y)x^i

二元多项式乘法

计算 A(x,y)B(x,y)

类比一元多项式,如果我们对 x 那一元进行 DFT,然后对于每个 ia_i(y)b_i(y) 做一元多项式乘法,再 IDFT 回去就行了。问题是 a_i(y) 是个级数,所以看起来不太好搞。

但是注意到 DFT 是线性变换,所以对这一堆级数的 DFT 可以枚举 y 的幂次来分别做,也就是

\operatorname{DFT}_x[A(x,y)]=\sum_{j=0}^{m-1}y^j\operatorname{DFT}\left(\sum_{i=0}^{n-1}[y^j]a_i(y)x^i\right)

所以整体来看,就是先对每列做 DFT,再对每行做 DFT,然后点乘,然后对每行做 IDFT,最后对每列做 IDFT。实际上行和列的顺序都是无所谓的。复杂度 O(nm\log (n+m))

二元多项式乘法逆

给定 A(x,y),求 B(x,y) 满足 A(x,y)B(x,y)\equiv 1(\bmod x^n\bmod y^m)

采用倍增的思想。但是同时倍增两维是不可取的,因为 B(x,y)-G(x,y)\equiv 0(\bmod x^n\bmod y^m) 并不意味着 [B(x,y)-G(x,y)]^2\equiv 0(\bmod x^{2n}\bmod y^{2m}) 。当然你也可以换一种定义,即每次求出所有 x^iy^j(i+j<n) 项的系数,这样的系数看起来是个阶梯状,每次倍增时阶梯的边长也会翻倍。不过这么做较难扩展,且常数也比较大,所以这里采取只倍增一维的做法。

  • 边界:当 n=1 时,b_0(y)\equiv a_0(y)^{-1}\pmod{y^m}

  • A*G\equiv 1(\bmod x^n\bmod y^m),A*B\equiv 1(\bmod x^{2n}\bmod y^m)。则 B-G\equiv 0(\bmod x^n\bmod y^m),那么有 (B-G)^2\equiv 0(\bmod x^{2n}\bmod y^m)。将其展开后两边同时成 A(x,y)B=2G-AG^2(\bmod x^{2n}\bmod y^m)

T(n,m)=T(n/2,m)+nm\log(n+m),解得 T(n,m)=O(nm\log (n+m))

二元多项式 ln

B(x,y)=\ln A(x,y)

两边同时求导(对 xy 皆可,这里钦定为对 x 求导),再积回去得

B(x,y)=\int\frac{\partial A}{\partial x}(x,y)\frac{1}{A(x,y)} \mathbb{d}x+c(y)

注意到 c(y) 被当作常数吃掉了。于是我们还需要把所有不带 x 的项单独拿出来做一次一元 ln。

二元多项式 exp

这个东西就是牛顿迭代啦,由于泰勒展开中满足 B-G\equiv 0(\bmod x^n\bmod y^m)\Rightarrow (B-G)^k(\bmod x^{2n}\bmod y^m),k>1 这种性质,所以可以直接倍增。过程类似求逆,就不多说了。


常规操作先就写这几种,剩下的靠脑补吧(

关于多元多项式的扩展,也可以考虑类似的做法。但通常元数越多题目限制的长度也就越短。可以考虑类比一下异或运算的 DWT,发现 d 元长度为 n 的 DFT 等价于 dn 进制的异或 DWT。由此可知对于指定 d,n,这两个操作可以做到 O(dn^{d+1}) / O(dn^d\log n) 的时间复杂度。

感觉这篇博客的信息量好少

最后放上我写的丑常数又大的代码吧

#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN=4096;
const int P=998244353, G=3;
int qpow(int n, int k)
{
    int r=1;
    while (k)
    {
        if (k&1) r=1ll*r*n%P;
        n=1ll*n*n%P, k>>=1;
    }
    return r;
}
struct Reimu
{
    int l, r[MAXN], w[MAXN];
    void getl(int n)
    {
        int d=0; l=1; r[0]=0;
        while (l<=n) l<<=1, d++;
        for (int i=1; i<l; i++)
            r[i]=r[i>>1]>>1|(i&1)<<(d-1);
        int wn=qpow(G, (P-1)/l);
        w[l>>1]=1;
        for (int i=(l>>1)+1; i<l; i++) w[i]=1ll*w[i-1]*wn%P;
        for (int i=(l>>1)-1; i>=0; i--) w[i]=w[i<<1];
    }
    void NTT(int* a, int ty)
    {
        if (ty==-1)
        {
            NTT(a, 1);
            reverse(a+1, a+l);
            int t=P-(P-1)/l;
            for (int i=0; i<l; i++) a[i]=1ll*a[i]*t%P;
            return;
        }
        for (int i=0; i<l; i++)
            if (i<r[i]) swap(a[i], a[r[i]]);
        for (int k=1; k<l; k<<=1)
        {
            for (int i=0; i<l; i+=k<<1)
                for (int j=0; j<k; j++)
                {
                    int x=a[i+j], y=1ll*w[k+j]*a[i+k+j]%P;
                    a[i+j]=x+y<P?x+y:x+y-P;
                    a[i+k+j]=x-y<0?x-y+P:x-y;
                }
        }
    }
} R, C;
void NTT2D(int a[][MAXN], int ty)
{
    static int t[MAXN];
    for (int i=0; i<C.l; i++) R.NTT(a[i], ty);
    for (int j=0; j<R.l; j++)
    {
        for (int i=0; i<C.l; i++) t[i]=a[i][j];
        C.NTT(t, ty);
        for (int i=0; i<C.l; i++) a[i][j]=t[i];
    }
}
void Inv(int* a, int* b, int n)
{
    static int t[MAXN];
    if (n==1) return b[0]=qpow(a[0], P-2), void();
    int k=(n+1)>>1; Inv(a, b, k);
    for (int i=0; i<n; i++) t[i]=a[i];
    R.getl(2*n);
    for (int i=n; i<R.l; i++) t[i]=0;
    for (int i=k; i<R.l; i++) b[i]=0;
    R.NTT(t, 1); R.NTT(b, 1);
    for (int i=0; i<R.l; i++) b[i]=(2-1ll*t[i]*b[i]%P+P)*b[i]%P;
    R.NTT(b, -1);
}
void Inv2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
    static int t[MAXN][MAXN];
    if (n==1) return Inv(a[0], b[0], m);
    int k=(n+1)>>1; Inv2D(a, b, k, m);
    C.getl(2*n); R.getl(3*m);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
        {
            t[i][j]=(i<n&&j<m)?a[i][j]:0;
            if (i>=k||j>=m) b[i][j]=0;
        }
    NTT2D(t, 1), NTT2D(b, 1);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
            b[i][j]=(2-1ll*t[i][j]*b[i][j]%P+P)*b[i][j]%P;
    NTT2D(b, -1);
}
void Diff(int* a, int* b, int n)
{
    for (int i=1; i<n; i++) b[i-1]=1ll*i*a[i]%P; b[n-1]=0;
}
void Inte(int* a, int* b, int n)
{
    for (int i=1; i<n; i++) b[i]=1ll*qpow(i, P-2)*a[i-1]%P; b[0]=0;
}
void Ln(int* a, int* b, int n)
{
    static int f[MAXN], g[MAXN];
    Diff(a, f, n); Inv(a, g, n);
    R.getl(2*n);
    for (int i=n; i<R.l; i++) f[i]=g[i]=0;
    R.NTT(f, 1); R.NTT(g, 1);
    for (int i=0; i<R.l; i++) f[i]=1ll*f[i]*g[i]%P;
    R.NTT(f, -1);
    Inte(f, b, n);
}
void Ln2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
    static int f[MAXN][MAXN], g[MAXN][MAXN];
    for (int i=0; i<n; i++) Diff(a[i], f[i], m);
    Inv2D(a, g, n, m);
    C.getl(2*n); R.getl(2*m);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
            if (i>=n||j>=m) f[i][j]=g[i][j]=0;
    NTT2D(f, 1); NTT2D(g, 1);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
            f[i][j]=1ll*f[i][j]*g[i][j]%P;
    NTT2D(f, -1);
    for (int i=0; i<n; i++) Inte(f[i], b[i], m);
    for (int i=0; i<n; i++) f[0][i]=a[i][0];
    Ln(f[0], g[0], n);
    for (int i=0; i<n; i++) b[i][0]=g[0][i];
}
void Exp(int* a, int* b, int n)
{
    static int t[MAXN];
    if (n==1) return b[0]=1, void();
    int k=(n+1)>>1; Exp(a, b, k);
    for (int i=k; i<n; i++) b[i]=0;
    Ln(b, t, n);
    for (int i=0; i<n; i++) t[i]=(a[i]-t[i]+P)%P;
    t[0]=(t[0]+1)%P; R.getl(2*n);
    for (int i=n; i<R.l; i++) b[i]=t[i]=0;
    R.NTT(b, 1); R.NTT(t, 1);
    for (int i=0; i<R.l; i++) b[i]=1ll*b[i]*t[i]%P;
    R.NTT(b, -1);
}
void Exp2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
    static int t[MAXN][MAXN];
    if (n==1) return Exp(a[0], b[0], m), void();
    int k=(n+1)>>1; Exp2D(a, b, k, m);
    for (int i=k; i<n; i++)
        for (int j=0; j<m; j++) b[i][j]=0;
    Ln2D(b, t, n, m);
    for (int i=0; i<n; i++)
        for (int j=0; j<m; j++)
            t[i][j]=(a[i][j]-t[i][j]+P)%P;
    t[0][0]=(t[0][0]+1)%P;
    C.getl(2*n); R.getl(2*m);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
            if (i>=n||j>=m) t[i][j]=b[i][j]=0;
    NTT2D(t, 1); NTT2D(b, 1);
    for (int i=0; i<C.l; i++)
        for (int j=0; j<R.l; j++)
            b[i][j]=1ll*b[i][j]*t[i][j]%P;
    NTT2D(b, -1);
}
未经允许不得转载:冷滟泽的个人博客 » 二元多项式全家桶

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址