hdu5307题解

Nov 26, 2017 14:17 · 227 words · 2 minutes read 题解 hdu FFT

我的思路似乎跟标算不太一样?

我的想法是分治,让a[i]表示第i段的长度,对于a[l]…a[r],先分别求出a[l]..a[mid],a[mid+1]…a[r]的答案,再考虑跨越中点的答案。

重点在于计算跨越中点的答案。如果要求的是长度和为s的方案数,自然而然地就会想到生成函数+FFT,但是这道题要求的东西有一点不一样……

还是考虑用多项式解决,用数对(a,b)表示包含a段,长度和为b的区间。仿照求长度和为s的方案数的方法,定义多项式$A_l=\sum_{区间[l,mid]的(a,b)} ax^b$,类似定义$A_r$。只需求出$A=\sum_{区间[l,r]的(a,b)} ax^b$。考虑区间[mid+1,b]的每对(a,b),它对$A$的贡献是($A_l$每项系数加a)*$x^b$。每项系数加a这个操作不容易直接快速地算出来,所以需要额外再定义一个多项式$B_l=\sum_{区间[l,mid]的(a,b)} x^b$,那么贡献就可以写成$x^b(A_l+aB_l)$。于是我们就可以写出: $$\begin{align*}A&=\sum_{区间[mid+1,r]的(a,b)}x^b(A_l+aB_l)\\&=(\sum_{区间[mid+1,r]的(a,b)}x^b)A_l+(\sum_{区间[mid+1,r]的(a,b)}ax^b)B_l\\&=A_lB_r+A_rB_l\end{align*}$$

然后就可以用FFT来算了。我觉得复杂度是$O(n\log^2 n)$,常数大得飞起……

据传double会爆精度要用long double

不知道为什么直接%.0Lf输出会WA……我已经避免了输出-0的情况了……改成先强制转化成long long再输出就能过……那我还不如直接用long long数组存答案……

代码:

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define ele int
#define ll long long
using namespace std;
#define maxn 100010
const long double pi=acos(-1.0);
struct cd{
	long double a,b;
};
inline cd operator+(cd a,cd b){
	return (cd){a.a+b.a,a.b+b.b};
}
inline cd operator-(cd a,cd b){
	return (cd){a.a-b.a,a.b-b.b};
}
inline cd operator*(cd a,cd b){
	return (cd){a.a*b.a-a.b*b.b,a.a*b.b+a.b*b.a};
}
inline cd& operator*=(cd&a,cd b){
	return a=a*b;
}
ele n,a[maxn];
ll res[maxn];
cd b1[maxn],b2[maxn],b3[maxn],b4[maxn],t1[maxn],t2[maxn],t3[maxn],t4[maxn];
inline void FFT(ele K,ele n,cd *a,cd *y){
	static ele f[maxn];
	f[0]=0; y[0]=a[0];
	for (int i=1; i<n; ++i){
		f[i]=f[i>>1]>>1;
		if (i&1) f[i]+=n>>1;
		y[i]=a[f[i]];
	}
	for (int p=1; p<n; p<<=1){
		cd o=(cd){cos(pi/p*K),sin(pi/p*K)};
		for (int i=0; i<n; i+=(p<<1)){
			cd o1=(cd){1,0};
			for (int j=i; j<i+p; ++j,o1*=o){
				cd u=y[j],v=y[j+p]*o1;
				y[j]=u+v;
				y[j+p]=u-v;
			}
		}
	}
	if (!~K)
		for (int i=0; i<n; ++i)
			y[i].a/=n,y[i].b/=n;
}
inline ele solve(ele l,ele r){
	if (l==r){
		++res[a[l]];
		return a[l];
	}
	ele mid=(l+r)>>1;
	ele s=solve(l,mid)+solve(mid+1,r);
	ele tmp=1;
	while (tmp<=s) tmp<<=1;
	for (int i=mid,s1=a[mid]; i>=l; --i,s1+=a[i])
		b1[s1].a+=mid-i+1,b2[s1].a+=1;
	for (int i=mid+1,s1=a[mid+1]; i<=r; ++i,s1+=a[i])
		b3[s1].a+=i-mid,b4[s1].a+=1;
	FFT(1,tmp,b1,t1); FFT(1,tmp,b2,t2); FFT(1,tmp,b3,t3); FFT(1,tmp,b4,t4);
	for (int i=0; i<tmp; ++i)
		t1[i]=t1[i]*t4[i]+t2[i]*t3[i];
	FFT(-1,tmp,t1,b1);
	for (int i=0; i<=s; ++i) res[i]+=(ll)(b1[i].a+0.5);
	for (int i=0; i<tmp; ++i) b1[i]=b2[i]=b3[i]=b4[i]=(cd){0,0};
	return s;
}
int main(){
	ele T;
	scanf("%d",&T);
	while (T--){
		scanf("%d",&n);
		for (int i=0; i<n; ++i) scanf("%d",a+i);
		memset(res,0,sizeof(res));
		ele s=solve(0,n-1);
		for (int i=0; i<=s; ++i) printf("%lld\n",res[i]);
	}
	return 0;
}