正文
hdu 1402(FFT乘法 || NTT乘法)
小程序:扫一扫查出行
【扫一扫了解最新限行尾号】
复制小程序
【扫一扫了解最新限行尾号】
复制小程序
A * B Problem Plus
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 K (Java/Others)
Total Submission(s): 9413 Accepted Submission(s): 1468
Problem Description
Calculate A * B.
Input
Each line will contain two integers A and B. Process to end of file.
Note: the length of each integer will not exceed 50000.
Output
For each case, output A * B in one line.
Sample Input
1 2 1000 2
Sample Output
2 2000
Author
DOOM III
Recommend
DOOM III
就一个高精度乘法 FFT加速。
最近正好要捡起fft,就顺便整理了模板。
FFT的原理还是算法导论靠谱,没有那么艰深难懂,就涉及怎么进行FFT和FFT需要的原理和定理。
看看算法导论里FFT的部分,一定要读到迭代实现那部分!!
看了好久求和引理,才发觉他是为了保证$w_n^k$与$w_n^{k+2/h}$的对称性(即$w_n^{k+2/h}=-w_n^k$)的,这个引理是必要的。
对于多项式序列,我们可以用两个O(nlgn)(n>max(len1,len2)*2)的FFT将其系数表示转化为点值表示(DFT),然后用O(n) 相乘,接着用FFT把结果的点值表示变为系数表示(IDFT),总体算起来是3O(nlgn)+O(n),即O(nlgn)的时间复杂度。比O(n^2)好多了。
以下是学习的两个版本。
#include<bits/stdc++.h>
#define clr(x) memset(x,0,sizeof(x))
#define clr_1(x) memset(x,-1,sizeof(x))
#define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
#define LL long long
#define mod 1000000007
#define PI 3.1415926535
using namespace std;
char s1[],s2[];
int a[],b[];
//复数序列结构体
struct complexed
{
double r,i;
complexed(double _r=0.0,double _i=0.0)
{
r=_r;
i=_i;
}
complexed operator +(complexed b)
{
return complexed(r+b.r,i+b.i);
}
complexed operator -(complexed b)
{
return complexed(r-b.r,i-b.i);
}
complexed operator *(complexed b)
{
return complexed(r*b.r-i*b.i,i*b.r+r*b.i);
}
}num1,num2;
vector<complexed> multi1,multi2;
inline int max(int a,int b)
{
return a>b?a:b;
}
//并将长度变为2…^(k+1)
void changelen(int &len)
{
int mul=;
while(mul<len)
mul<<=;
mul<<=;
len=mul;
return ;
}
//将整数序列复制到复数序列中
void copyed(int *a,vector<complexed> &multi,int len)
{
multi.resize(len);
for(int i=;i<len;i++)
multi[i]=(complexed){a[i],};
return;
}
//DFT的话on=1,IDFT on=-1;
void fft(vector<complexed> &multi,int len,int on)
{
complexed wn,w,u,t;
//wn,w,u,t如算法导论中所示
vector<complexed> ans;
ans.resize(len);
//ans存每次操作计算后的y,最后再作为下次的multi。
for(int h=len/;h>=;h>>=)
{
wn=(complexed){cos(*on*PI/(len/h)),sin(*on*PI/(len/h))};
for(int i=;i<h;i++)
{
w=(complexed){,};
for(int j=;j<len/h/;j++)
{
//蝴蝶操作
u=multi[i+*h*j];
t=multi[i+*h*j+h]*w;
ans[i+h*j]=u+t;
ans[i+h*j+len/]=u-t;
w=w*wn;
}
}
//ans作为下次计算的multi
multi=ans;
}
//IDFT每个元素都得除以n
if(on==-)
for(int i=;i<len;i++)
multi[i].r/=len;
return ;
}
int main()
{
int len1,len2,len;
while(scanf("%s%s",s1,s2)!=EOF)
{
len1=strlen(s1);
len2=strlen(s2);
clr(a);
clr(b);
for(int i=;i<len1;i++)
{
a[len1-i-]=s1[i]-'';
}
for(int i=;i<len2;i++)
{
b[len2-i-]=s2[i]-'';
}
len=max(len1,len2);
//取长度较长者作为长度,并将长度变为2…^(k+1)
changelen(len);
//将两个整数序列复制到复数序列中
copyed(a,multi1,len);
copyed(b,multi2,len);
//对两个复数序列进行DFT,变为点值表示
fft(multi1,len,);
fft(multi2,len,);
//对应点点值相乘
for(int i=;i<len;i++)
multi1[i]=multi1[i]*multi2[i];
//将的出来的点值表示进行IDFT变为系数表示
fft(multi1,len,-);
//四舍五入减小损失精度
for(int i=;i<len;i++)
{
a[i]=(int)(multi1[i].r+0.5);
}
//进位
for(int i=;i<len;i++)
{
a[i+]=a[i+]+a[i]/;
a[i]%=;
}
len=len1+len2-;
//去掉前导0
while(a[len]<= && len>) len--;
for(int i=len;i>=;i--)
printf("%d",a[i]);
printf("\n");
}
return ;
}
无位逆序置换的步长实现
#include<bits/stdc++.h>
#define clr(x) memset(x,0,sizeof(x))
#define clr_1(x) memset(x,-1,sizeof(x))
#define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
#define LL long long
#define mod 1000000007
#define PI 3.1415926535
using namespace std;
char s1[],s2[];
int a[],b[];
struct complexed
{
double r,i;
complexed(double _r=0.0,double _i=0.0)
{
r=_r;
i=_i;
}
complexed operator +(complexed b)
{
return complexed(r+b.r,i+b.i);
}
complexed operator -(complexed b)
{
return complexed(r-b.r,i-b.i);
}
complexed operator *(complexed b)
{
return complexed(r*b.r-i*b.i,i*b.r+r*b.i);
}
}num1,num2;
complexed multi1[<<],multi2[<<];
inline int max(int a,int b)
{
return a>b?a:b;
}
void changelen(int &len)
{
int mul=;
while(mul<len)
mul<<=;
mul<<=;
len=mul;
return ;
}
//将整数序列复制到复数序列中
void copyed(int *a,complexed *multi,int len)
{
for(int i=;i<len;i++)
multi[i]=(complexed){a[i],};
return;
}
//位逆序变换
void bitchange(complexed *multi,int len)
{
int i,j,k;
for(i = , j = len/;i < len-; i++)
{
if(i < j)swap(multi[i],multi[j]);
k = len/;
while( j >= k)
{
j -= k;
k /= ;
}
if(j < k) j += k;
}
return ;
}
//DFT的话on=1,IDFT on=-1;
void fft(complexed *multi,int len,int on)
{
bitchange(multi,len);//位逆序置换
complexed wn,w,u,t;//如算法导论所示
for(int h=;h<=len;h<<=)
{
wn=(complexed){cos(*on*PI/h),sin(*on*PI/h)};
for(int i=;i<len;i+=h)
{
//蝴蝶操作
w=(complexed){,};
for(int j=i;j<i+h/;j++)
{
u=multi[j];
t=multi[j+h/]*w;
multi[j]=u+t;
multi[j+h/]=u-t;
w=w*wn;
}
}
}
//IDFT每个元素都得除以n
if(on==-)
for(int i=;i<len;i++)
multi[i].r/=len;
return ;
}
void mul(int *a,int *b,int &len1,int &len2)
{
int len=max(len1,len2);
//取长度较长者作为长度,并将长度变为2…^(k+1)
changelen(len);
//将两个整数序列复制到复数序列中
copyed(a,multi1,len);
copyed(b,multi2,len);
//对两个复数序列进行DFT,变为点值表示
fft(multi1,len,);
fft(multi2,len,);
//对应点点值相乘
for(int i=;i<len;i++)
multi1[i]=multi1[i]*multi2[i];
//将的出来的点值表示进行IDFT变为系数表示
fft(multi1,len,-);
//四舍五入减小损失精度
for(int i=;i<len;i++)
{
a[i]=(int)(multi1[i].r+0.5);
}
while(len-> && a[len-]==)
len--;
len1=len;
return ;
}
int main()
{
int len1,len2,len;
while(scanf("%s%s",s1,s2)!=EOF)
{
len1=strlen(s1);
len2=strlen(s2);
clr(a);
clr(b);
for(int i=;i<len1;i++)
{
a[len1-i]=s1[i]-'';
}
for(int i=;i<len2;i++)
{
b[len2-i]=s2[i]-'';
}
mul(a+,b+,len1,len2);
//进位
len=len1;
for(int i=;i<len;i++)
{
a[i+]=a[i+]+a[i]/;
a[i]%=;
}
while(a[len]>)
{
a[len+]=a[len+]+a[len]/;
a[len]%=;
len++;
}
for(int i=len;i>=;i--)
printf("%d",a[i]);
printf("\n");
}
return ;
}
位逆序置换的迭代实现
后来看了ntt,小改了下原迭代实现的模板,实现了迭代实现的NTT模板:
#include<bits/stdc++.h>
#define clr(x) memset(x,0,sizeof(x))
#define clr_1(x) memset(x,-1,sizeof(x))
#define clrmax(x) memset(x,0x3f3f3f3f,sizeof(x))
#define LL long long
#define mod 1004535809
#define PI 3.1415926535
#define P 1004535809
#define G 3
using namespace std;
char s1[],s2[];
LL a[],b[],c[];
LL quick_pow(LL mul,LL n)
{
LL res=;
mul=(mul%mod+mod)%mod;
while(n)
{
if(n%)
res=res*mul%mod;
mul=mul*mul%mod;
n/=;
}
return res;
}
inline int max(int a,int b)
{
return a>b?a:b;
}
void bitchange(LL *a,int len)
{
int i,j,k;
for(i = , j = len>>;i < len-; i++)
{
if(i < j)swap(a[i],a[j]);
k = len>>;
while( j >= k)
{
j -= k;
k >>= ;
}
if(j < k) j += k;
}
return ;
}
void changelen(int &len)
{
int mul=;
while(mul<len)
mul<<=;
mul<<=;
len=mul;
return ;
}
//DFT的话on=1,IDFT on=-1;
void ntt(LL *a,int len,LL on)
{
bitchange(a,len);//位逆序置换
LL wn,w,u,t;//如算法导论所示
for(int h=;h<=len;h<<=)
{
wn=quick_pow(G,(P-)/h)%mod;
for(int i=;i<len;i+=h)
{
//蝴蝶操作
w=;
for(int j=i;j<i+h/;j++)
{
u=a[j]%mod;
t=a[j+h/]*w%mod;
a[j]=(u+t)%mod;
a[j+h/]=(u-t+mod)%mod;
w=w*wn%mod;
}
}
}
//IDFT调换次序实现wn^-1的情况,并且乘以len的逆元
if(on==-)
{
//k^0显然不调换次序,但是k^1与k^-1,k^2与k^-2.... k^len/2与k^-len/2 要调换次序
for(int i=;i<len/;i++)
swap(a[i],a[len-i]);
LL re=quick_pow(len,P-);
for(int i=;i<len;i++)
a[i]=a[i]*re%mod;
}
return ;
}
void mul(LL *a,LL *b,int &len1,int &len2)
{
int len=max(len1,len2);
//取长度较长者作为长度,并将长度变为2…^(k+1)
changelen(len);
//对两个整数序列进行DFT,变为点值表示
ntt(a,len,);
ntt(b,len,);
//对应点点值相乘
for(int i=;i<len;i++)
a[i]=b[i]*a[i]%mod;
//将的出来的点值表示进行IDFT变为系数表示
ntt(a,len,-);
while(len-> && a[len-]==)
len--;
len1=len;
return ;
}
int main()
{
int len1,len2,len;
while(scanf("%s%s",s1,s2)!=EOF)
{
len1=strlen(s1);
len2=strlen(s2);
clr(a);
clr(b);
for(int i=;i<len1;i++)
{
a[len1-i]=s1[i]-'';
}
for(int i=;i<len2;i++)
{
b[len2-i]=s2[i]-'';
}
mul(a+,b+,len1,len2);
//进位
len=len1;
for(int i=;i<len;i++)
{
a[i+]=a[i+]+a[i]/;
a[i]%=;
}
while(a[len]>)
{
a[len+]=a[len+]+a[len]/;
a[len]%=;
len++;
}
for(int i=len;i>=;i--)
printf("%lld",a[i]);
printf("\n");
}
return ;
}
NTT的迭代实现
NTT需要爆搜下找到该质数的原根(这部分一般不写到代码里,一般是自己找出来以后再直接作为常量放在程序里,建议分解完P-1的质因数后去搜索快点,一般原根都不太大比较好搜)。在比赛中一般给出的质数P,P-1后一般是C*2^k的形式,才能支持2^k的分治。
学习资料推荐:http://blog.sina.com.cn/s/blog_7c4c33190102wht6.html 这个看下原理一类的,包括FFT的。其中笔者把(P-1)*2^m写错写成了P*2^m了。
代码以及等价性参考ACdreamer的代码:http://blog.csdn.net/acdreamers/article/details/39026505