给定长为m的数列a,求递推数列\(f(n)=\sum _{i=1}^{m}a_{i}f(n-i)\)的第n项,\(n \leq 10^9,m \leq 3000\)。
如果m小一些的话,我们可以用O(m^3logn)的时间复杂度利用矩阵快速幂来解决,但是显然我们需要更优秀的做法。
假设有矩阵A,那么它的特征向量\(x\)和特征值\(\lambda\)定义如下:
$$(A-\lambda E) x=0$$
特征多项式定义如下(就是行列式的值):
$$char(A) = | A - \lambda E |$$
设矩阵\(A\)的特征多项式为\(\phi(A)\),满足\(\phi(A) = 0\)。
那么我们可以这样变形:\(A^n = f(A)g(A)+h(A)\),然后将特征多项式带入g的位置,则\(A^n =h(A)\)。注意到这其实就是进行多项式取模,而h的次数小于等于m,也就是说A^n可以表示成A^i(0<= i <m)的线性组合。
如果这样直接做,复杂度是O(m^4)的,所以我们还需要优化。如果我们有前2m项的值,那么我们就可以凑一凑转移用的列向量,然后可以写出来递推式,之后就可以O(m)算出f(n)的值了。
如果暴力进行多项式取模,那么我们需要算的其实就是x^n这个多项式关于一个低次多项式取模的结果,我们可以从x^1开始倍增,然后在倍增过程中取模,这样的复杂度是\(O(m^2logn)\)的,实测常数比较小,卡卡常的话三四千的数据范围还是比较合理的。
如果用FFT优化的话复杂度就是\(O(mlogmlogn)\)的,但是常数比较大,最多跑到50000以下。但是,如果递推数列没有特殊性质的话,我们在算f的前2m项的时候还是m^2的,需要用多项式取模来优化。
放一份代码,BZOJ的模板题。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
#pragma GCC optimize("-O2") #include <bits/stdc++.h> #define mod 1000000007 #define ll long long #define mxn 5005 using namespace std; char ibuf[65536],*ih,*it; #define getc() ((ih==it)&&((it=(ih=ibuf)+fread(ibuf,1,65536,stdin)),ih==it)?0:*ih++) inline int read(){ int s=0,f=1; char ch=getc(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1; ch=getc();} while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getc(); return s*f; } int n,m,a[mxn],h[mxn],f[mxn],g[mxn]; ll tp[mxn]; inline void mul(int *A,int *B,int *C){ for(int i=0;i<2*m;++i) tp[i]=0; for(int i=0;i<m;++i) for(int j=0;j<m;++j) tp[i+j]+=1ll*A[i]*B[j]%mod; for(int i=2*m-2;i>=m;--i) { tp[i]=tp[i]%mod; for(int j=m-1;j>=0;--j) tp[i-m+j]-=1ll*tp[i]*f[j]%mod; } for(int i=0;i<m;++i) C[i]=(tp[i]%mod+mod)%mod; } int tmp[mxn]; inline void qpow(int *ret,int b){ ret[0]=tmp[1]=1; while(b){ if(b&1) mul(ret,tmp,ret); mul(tmp,tmp,tmp),b>>=1; } } int main(){ n=read(),m=read(); f[m]=1; for(int i=1;i<=m;++i) a[i]=(read()%mod+mod)%mod,f[m-i]=(mod-a[i])%mod; for(int i=0;i<m;++i) h[i]=(read()%mod+mod)%mod; for(int i=m;i<=2*m-1;++i) { for(int j=1;j<=m;++j) (h[i]+=1ll*h[i-j]*a[j]%mod)%=mod; } qpow(g,n-m); int ans=0; for(int i=0;i<m;++i) (ans+=1ll*g[i]*h[i+m]%mod)%=mod; printf("%d",ans); } |
叨叨几句... NOTHING