因为其实和一元多项式的各种运算相差不大,所以这里就简单写了。。
引入
二元多项式
二元多项式乘法
计算
类比一元多项式,如果我们对
但是注意到 DFT 是线性变换,所以对这一堆级数的 DFT 可以枚举
所以整体来看,就是先对每列做 DFT,再对每行做 DFT,然后点乘,然后对每行做 IDFT,最后对每列做 IDFT。实际上行和列的顺序都是无所谓的。复杂度
二元多项式乘法逆
给定
采用倍增的思想。但是同时倍增两维是不可取的,因为
-
边界:当
n=1 时,b_0(y)\equiv a_0(y)^{-1}\pmod{y^m} 。 - 设
A*G\equiv 1(\bmod x^n\bmod y^m),A*B\equiv 1(\bmod x^{2n}\bmod y^m) 。则B-G\equiv 0(\bmod x^n\bmod y^m) ,那么有(B-G)^2\equiv 0(\bmod x^{2n}\bmod y^m) 。将其展开后两边同时成A(x,y) 得B=2G-AG^2(\bmod x^{2n}\bmod y^m) 。
二元多项式 ln
求
两边同时求导(对
注意到
二元多项式 exp
这个东西就是牛顿迭代啦,由于泰勒展开中满足
常规操作先就写这几种,剩下的靠脑补吧(
关于多元多项式的扩展,也可以考虑类似的做法。但通常元数越多题目限制的长度也就越短。可以考虑类比一下异或运算的 DWT,发现
感觉这篇博客的信息量好少
最后放上我写的丑常数又大的代码吧
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN=4096;
const int P=998244353, G=3;
int qpow(int n, int k)
{
int r=1;
while (k)
{
if (k&1) r=1ll*r*n%P;
n=1ll*n*n%P, k>>=1;
}
return r;
}
struct Reimu
{
int l, r[MAXN], w[MAXN];
void getl(int n)
{
int d=0; l=1; r[0]=0;
while (l<=n) l<<=1, d++;
for (int i=1; i<l; i++)
r[i]=r[i>>1]>>1|(i&1)<<(d-1);
int wn=qpow(G, (P-1)/l);
w[l>>1]=1;
for (int i=(l>>1)+1; i<l; i++) w[i]=1ll*w[i-1]*wn%P;
for (int i=(l>>1)-1; i>=0; i--) w[i]=w[i<<1];
}
void NTT(int* a, int ty)
{
if (ty==-1)
{
NTT(a, 1);
reverse(a+1, a+l);
int t=P-(P-1)/l;
for (int i=0; i<l; i++) a[i]=1ll*a[i]*t%P;
return;
}
for (int i=0; i<l; i++)
if (i<r[i]) swap(a[i], a[r[i]]);
for (int k=1; k<l; k<<=1)
{
for (int i=0; i<l; i+=k<<1)
for (int j=0; j<k; j++)
{
int x=a[i+j], y=1ll*w[k+j]*a[i+k+j]%P;
a[i+j]=x+y<P?x+y:x+y-P;
a[i+k+j]=x-y<0?x-y+P:x-y;
}
}
}
} R, C;
void NTT2D(int a[][MAXN], int ty)
{
static int t[MAXN];
for (int i=0; i<C.l; i++) R.NTT(a[i], ty);
for (int j=0; j<R.l; j++)
{
for (int i=0; i<C.l; i++) t[i]=a[i][j];
C.NTT(t, ty);
for (int i=0; i<C.l; i++) a[i][j]=t[i];
}
}
void Inv(int* a, int* b, int n)
{
static int t[MAXN];
if (n==1) return b[0]=qpow(a[0], P-2), void();
int k=(n+1)>>1; Inv(a, b, k);
for (int i=0; i<n; i++) t[i]=a[i];
R.getl(2*n);
for (int i=n; i<R.l; i++) t[i]=0;
for (int i=k; i<R.l; i++) b[i]=0;
R.NTT(t, 1); R.NTT(b, 1);
for (int i=0; i<R.l; i++) b[i]=(2-1ll*t[i]*b[i]%P+P)*b[i]%P;
R.NTT(b, -1);
}
void Inv2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
static int t[MAXN][MAXN];
if (n==1) return Inv(a[0], b[0], m);
int k=(n+1)>>1; Inv2D(a, b, k, m);
C.getl(2*n); R.getl(3*m);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
{
t[i][j]=(i<n&&j<m)?a[i][j]:0;
if (i>=k||j>=m) b[i][j]=0;
}
NTT2D(t, 1), NTT2D(b, 1);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
b[i][j]=(2-1ll*t[i][j]*b[i][j]%P+P)*b[i][j]%P;
NTT2D(b, -1);
}
void Diff(int* a, int* b, int n)
{
for (int i=1; i<n; i++) b[i-1]=1ll*i*a[i]%P; b[n-1]=0;
}
void Inte(int* a, int* b, int n)
{
for (int i=1; i<n; i++) b[i]=1ll*qpow(i, P-2)*a[i-1]%P; b[0]=0;
}
void Ln(int* a, int* b, int n)
{
static int f[MAXN], g[MAXN];
Diff(a, f, n); Inv(a, g, n);
R.getl(2*n);
for (int i=n; i<R.l; i++) f[i]=g[i]=0;
R.NTT(f, 1); R.NTT(g, 1);
for (int i=0; i<R.l; i++) f[i]=1ll*f[i]*g[i]%P;
R.NTT(f, -1);
Inte(f, b, n);
}
void Ln2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
static int f[MAXN][MAXN], g[MAXN][MAXN];
for (int i=0; i<n; i++) Diff(a[i], f[i], m);
Inv2D(a, g, n, m);
C.getl(2*n); R.getl(2*m);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
if (i>=n||j>=m) f[i][j]=g[i][j]=0;
NTT2D(f, 1); NTT2D(g, 1);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
f[i][j]=1ll*f[i][j]*g[i][j]%P;
NTT2D(f, -1);
for (int i=0; i<n; i++) Inte(f[i], b[i], m);
for (int i=0; i<n; i++) f[0][i]=a[i][0];
Ln(f[0], g[0], n);
for (int i=0; i<n; i++) b[i][0]=g[0][i];
}
void Exp(int* a, int* b, int n)
{
static int t[MAXN];
if (n==1) return b[0]=1, void();
int k=(n+1)>>1; Exp(a, b, k);
for (int i=k; i<n; i++) b[i]=0;
Ln(b, t, n);
for (int i=0; i<n; i++) t[i]=(a[i]-t[i]+P)%P;
t[0]=(t[0]+1)%P; R.getl(2*n);
for (int i=n; i<R.l; i++) b[i]=t[i]=0;
R.NTT(b, 1); R.NTT(t, 1);
for (int i=0; i<R.l; i++) b[i]=1ll*b[i]*t[i]%P;
R.NTT(b, -1);
}
void Exp2D(int a[][MAXN], int b[][MAXN], int n, int m)
{
static int t[MAXN][MAXN];
if (n==1) return Exp(a[0], b[0], m), void();
int k=(n+1)>>1; Exp2D(a, b, k, m);
for (int i=k; i<n; i++)
for (int j=0; j<m; j++) b[i][j]=0;
Ln2D(b, t, n, m);
for (int i=0; i<n; i++)
for (int j=0; j<m; j++)
t[i][j]=(a[i][j]-t[i][j]+P)%P;
t[0][0]=(t[0][0]+1)%P;
C.getl(2*n); R.getl(2*m);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
if (i>=n||j>=m) t[i][j]=b[i][j]=0;
NTT2D(t, 1); NTT2D(b, 1);
for (int i=0; i<C.l; i++)
for (int j=0; j<R.l; j++)
b[i][j]=1ll*b[i][j]*t[i][j]%P;
NTT2D(b, -1);
}