0%

「AHOI/HNOI2017」礼物

题目链接

题意

给定两个可旋转但不可翻转的环 x,yx,y,支持整个环同时增加一个整数,最小化 i=1n(xiyi)2\sum_{i=1}^n(x_i-y_i)^2

思路

把所有条件整合到一起,即求 i=1n(xiyil+k)2 (l,kZ)\sum_{i=1}^n(x_i-y_{i-l}+k)^2\ (l,k\in\mathbb Z) 的最小值。

将其展开:

i=1n(xiyil+k)2=i=1n((xiyil)2+2(xiyil)k+k2)=i=1n(xi2+yi2)2i=1n(xiyil)+nk2+2i=1n(xiyi)k\begin{aligned}\sum_{i=1}^n(x_i-y_{i-l}+k)^2&=\sum_{i=1}^n((x_i-y_{i-l})^2+2\cdot(x_i-y_{i-l})\cdot k+k^2)\\&=\sum_{i=1}^n(x_i^2+y_i^2)-2\cdot\sum_{i=1}^n(x_i\cdot y_{i-l})+n\cdot k^2+2\cdot\sum_{i=1}^n(x_i-y_i)\cdot k\end{aligned}

其中,第一项可以直接求出,第三项与第四项组成了一个关于 kk 的一元二次方程组,求出对称轴 k=i=1n(xiyi)nk=-\frac{\sum_{i=1}^n(x_i-y_i)}{n},将 k\lfloor k\rfloork\lceil k\rceil 代入原式,取个 min\min 值即可。

第二项需要用到 FFT 求出。

xx 序列复制一遍(因为是环),yy 序列反转,则 xiyi+lx_i\cdot y_{i+l} 变为 xiyni+lx_i\cdot y_{n-i+l},这个单项式的次数是 n+ln+l 的。

我们只要对 xxyy 进行 FFT 卷积,然后第二项的值减去 2max(fi) (n+1i2n)2\cdot\max(f_i)\ (n+1\leq i\leq2\cdot n) 即可。

实现

注意 FFT 的精度误差。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include<cstdio>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=50000,S=1<<18;
const double Pi2=acos(-1)*2;
int n,m,a[N+1],b[N+1],suma,sumb,sqra,sqrb,tot,len=1,rev[S],maxn,ans;
struct com {
double x,y;
com operator+(com ano)const
{
return com{x+ano.x,y+ano.y};
}
com operator-(com ano)const
{
return com{x-ano.x,y-ano.y};
}
com operator*(com ano)const
{
return com{x*ano.x-y*ano.y,x*ano.y+y*ano.x};
}
}f[S],g[S];
void init_rev()
{
for(int i=0;i<len;++i) {
rev[i]=(rev[i>>1]>>1)|(i&1?(len>>1):0);
}
return;
}
void fft(com*f,bool flag)
{
for(int i=0;i<len;++i) {
if(i<rev[i]) {
swap(f[i],f[rev[i]]);
}
}
for(int i=2,l=1;i<=len;i<<=1,l<<=1) {
com ide=com{cos(Pi2/i),flag?sin(Pi2/i):-sin(Pi2/i)};
for(int j=0;j<len;j+=i) {
com now=com{1,0};
for(int k=j;k<j+l;++k) {
com tem=now*f[l+k];
f[l+k]=f[k]-tem;
f[k]=f[k]+tem;
now=now*ide;
}
}
}
return;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) {
scanf("%d",&a[i]);
suma+=a[i],sqra+=a[i]*a[i];
}
for(int i=1;i<=n;++i) {
scanf("%d",&b[i]);
sumb+=b[i],sqrb+=b[i]*b[i];
tot+=b[i]-a[i],ans+=a[i]*a[i]+b[i]*b[i];
}
int l=floor(1.0*tot/n),r=ceil(1.0*tot/n);
ans+=min(n*l*l-(tot*l<<1),n*r*r-(tot*r<<1));
while(len<=n*3) {
len<<=1;
}
init_rev();
for(int i=1;i<=n;++i) {
f[i].x=f[n+i].x=a[i];
g[i].x=b[n-i+1];
}
fft(f,1),fft(g,1);
for(int i=0;i<len;++i) {
f[i]=f[i]*g[i];
}
fft(f,0);
for(int i=n+1;i<=n<<1;++i) {
maxn=max(maxn,int(f[i].x/len+0.5));
}
ans-=maxn<<1;
printf("%d\n",ans);
return 0;
}