题目链接
题意
给定引号字符集大小 k,多次询问有多少个长度为 n 的引号序列。
思路
一个引号序列对应 n 个元素进出栈的过程。
特别地,规定栈顶元素等于当前字符时, 选择弹栈而非压栈,这样引号序列和进出栈过程可以一一对应。
设 fi 为长度为 2i 的引号序列个数。
考虑最后一个出栈的元素,设它是第 j 个入栈的元素,则它有 k 种取值,且它前后是各自独立的,前面的过程有 fj−1 种方案, 后面的过程需保证该元素不在中途被弹出。
设 gi 为长度为 2i 的引号序列个数,且它保证在它的操作过程中,栈底始终有一个字符 c。
仍考虑最后一个出栈的元素,设它是第 j 个入栈的元素,则它有 k−1 种取值(不能与 c 相同),且它前后是各自独立的,前面的过程有 gj−1 种方案, 后面的过程需保证该元素不在中途被弹出。
列出转移式:
figi=kj=1∑ifj−1gi−j=(k−1)j=1∑igj−1gi−j
显然此为卷积形式,于是考虑生成函数。
设 F=∑i=0+∞fixi,G=∑i=0+∞gixi,
则 F=kxFG+1,G=(k−1)xG2+1。
解得 G=2(k−1)x1±1−4(k−1)x,由于 g0=1,即 limx→+0G(x)=1,所以 G=2(k−1)x1−1−4(k−1)x。
由此解得 F:
F=kxFG+1=1−kxG1=1−k⋅2(k−1)1−1−4(k−1)x1=2(k−1)−k+k1−4(k−1)x2(k−1)=k−2+k1−4(k−1)x2(k−1)=k2(1−4(k−1)x)−(k−2)22(k−1)(k1−4(k−1)x−(k−2))=−4(k2x−1)(k−1)2(k−1)(k1−4(k−1)x−(k−2))=−2(k2x−1)k1−4(k−1)x−(k−2)=−2(k2x−1)k∑i=0+∞((i21)(−4(k−1))ixi)−(k−2)
分母移至左侧,得出第 n 项与第 n−1 项的关系式:
fn−k2fn−1=2k(n21)(−4(k−1))n
最后得出递推式:
fn=k2fn−1+2k⋅n!∏i=0n−1(21−i)⋅(−4(k−1))n=k2fn−1+2k⋅n!∏i=0n−1(2i−1)⋅(2(k−1))n=k2fn−1+2k⋅n!−(2n−3)!!⋅(2(k−1))n=k2fn−1+2k⋅n!⋅2n−1(n−1)!(2n−2)!⋅(2(k−1))n=k2fn−1+nk⋅(n−1)!(n−1)!(2n−2)!⋅(k−1)n=k2fn−1+nk⋅(k−1)n(n−12n−2)
实现
预处理逆元、阶乘以及阶乘逆元,递推 f,O(1) 回答每个询问即可。
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| #include<cstdio> #include<cctype> using namespace std; const int N=1e7,mod=1e9+7; int t,k,n; int iv[N+1],fac[N*2-1],ifac[N]; int f[N+1]; void read(int &x) { char c; x=0; do { c=getchar(); } while(isspace(c)); do { x=x*10+c-'0'; c=getchar(); } while(isdigit(c)); return; } void write(int x) { static const int L=10; static int t,s[L]; do { s[t++]=x%10; } while(x/=10); while(t) { putchar(s[--t]+'0'); } return; } int pow(int x,int t) { int r=1; while(t) { if(t&1) { r=1ll*r*x%mod; } t>>=1,x=1ll*x*x%mod; } return r; } inline int inv(int x) { return pow(x,mod-2); } inline int C(int n,int m) { return 1ll*fac[n]*ifac[n-m]%mod*ifac[m]%mod; } int main() { freopen("quote.in","r",stdin); freopen("quote.out","w",stdout); read(t),read(k); iv[1]=1; for(int i=2;i<=N;++i) { iv[i]=-1ll*(mod/i)*iv[mod%i]%mod; } fac[0]=1; for(int i=1;i<=N*2-2;++i) { fac[i]=1ll*fac[i-1]*i%mod; } ifac[0]=1; for(int i=1;i<=N-1;++i) { ifac[i]=1ll*ifac[i-1]*iv[i]%mod; } f[0]=1; for(int i=1,p=1;p=1ll*p*(k-1)%mod,i<=N;++i) { f[i]=(1ll*k*k%mod*f[i-1]-1ll*k*p%mod*iv[i]%mod*C(i*2-2,i-1))%mod; } while(t--) { read(n); write((f[n]+mod)%mod),putchar('\n'); } return 0; }
|