分治|【模拟赛|ZROI】01串(容斥,分治FFT)

题面 【分治|【模拟赛|ZROI】01串(容斥,分治FFT)】分治|【模拟赛|ZROI】01串(容斥,分治FFT)
文章图片

分治|【模拟赛|ZROI】01串(容斥,分治FFT)
文章图片

分治|【模拟赛|ZROI】01串(容斥,分治FFT)
文章图片

题解 前面的转化不重要,我就直接贴了(其实是因为我怎么努力都想不明白)
分治|【模拟赛|ZROI】01串(容斥,分治FFT)
文章图片

然后我们将每两个数中间加分割线(两端还有两个,总共n + 1 n+1 n+1 个),每次选择了一个01 01 01 后就顺便把分割线也删了。分割线删除的时间就是一个排列,每个0 0 0 右边的分割线一定比左边的分割线早删, 1 1 1 相反, ? ? ? 随意。
所以我们就可以把01 01 01 转化成排列中相邻两个数的相对大小限制< >。然后就是个经典题了。
对于一个排列,相邻两个数有大于或小于的限制,怎么做?
我们的做法是容斥。先保留所有的<符号,去掉>符号的限制,计算总方案数。这时一个>符号的限制不被满足,等价于原先的位置放上了<符号。我们根据这点容斥,令d p [ i ] dp[i] dp[i] 表示考虑前i i i 个位置的方案数。
我们枚举排列 1~i 中最后一个逆序位置j ( p j > p j + 1 ) j(p_j>p_{j+1}) j(pj?>pj+1?) ,令p r o [ i ] = ( ? 1 ) i 之 前 > 符 号 的 个 数 pro[i]=(-1)^{i之前>符号的个数} pro[i]=(?1)i之前>符号的个数 , c [ i ] c[i] c[i] 表示i i i 和i + 1 i+1 i+1 之间的符号:
d p [ i ] = ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ? ( p r o [ j + 1 ] ? p r o [ i ] ) ? ( i j ) = i ! ? p r o [ i ] ∑ j < i , c [ j ] = ‘ > ’ d p [ j ] ? p r o [ j + 1 ] j ! ? 1 ( i ? j ) ! dp[i]=\sum_{j’} dp[j]\cdot (pro[j+1]\cdot pro[i])\cdot {i\choose j}\\ =i!\cdot pro[i]\sum_{j’} \frac{dp[j]\cdot pro[j+1]}{j!}\cdot \frac{1}{(i-j)!} dp[i]=j’∑?dp[j]?(pro[j+1]?pro[i])?(ji?)=i!?pro[i]j’∑?j!dp[j]?pro[j+1]??(i?j)!1?
我们用分治FFT(NTT)就好了,时间复杂度O ( n log ? 2 n ) O(n\log^2n) O(nlog2n) 。
CODE

#include #include #include #include #include #include #include using namespace std; #define MAXN 250005 #define LL long long #define DB double #define lowbit(x) ((-x) & (x)) #define ENDL putchar('\n') #define FI first #define SE second int xchar() { static const int mxn = 1000000; static char b[mxn]; static int pos = 0,len = 0; if(pos == len) pos = 0,len = fread(b,1,mxn,stdin); if(pos == len) return -1; return b[pos ++]; } //#define getchar() xchar() LL read() { LL f=1,x=0; int s = getchar(); while(s<'0' || s>'9') {if(s<0)return -1; if(s=='-')f=-f; s=getchar(); } while(s>='0'&&s<='9') {x = (x<<3)+(x<<1)+(s^48); s = getchar(); } return f*x; } void putpos(LL x) { if(!x) return ; putpos(x/10); putchar((x%10)^48); } void putnum(LL x) { if(!x) {putchar('0'); return ; } if(x<0) putchar('-'),x=-x; return putpos(x); } void AIput(LL x,int c) {putnum(x); putchar(c); }const int MOD = 998244353; int n,m,s,o,k; int fac[MAXN],inv[MAXN],invf[MAXN]; char ss[MAXN]; int om,xm[MAXN<<2],rev[MAXN<<2]; int qkpow(int a,int b) { int res = 1; while(b > 0) { if(b & 1) res = res *1ll* a % MOD; a = a *1ll* a % MOD; b >>= 1; } return res; } void NTT(int *s,int n,int op) { for(int i = 1; i < n; i ++) { rev[i] = (rev[i>>1]>>1) | ((i&1) ? (n>>1):0); if(rev[i] < i) swap(s[rev[i]],s[i]); } om = qkpow(3,(MOD-1)/n); xm[0] = 1; if(op < 0) om = qkpow(om,MOD-2); for(int i = 1; i <= n; i ++) xm[i] = xm[i-1] *1ll* om % MOD; for(int k = 2,t = n>>1; k <= n; k <<= 1,t >>= 1) { for(int j = 0; j < n; j += k) { for(int i = j,l = 0; i < j+(k>>1); i ++,l += t) { int A = s[i],B = s[i+(k>>1)]; s[i] = (A + xm[l] *1ll* B) % MOD; s[i+(k>>1)] = (A +MOD- xm[l]*1ll*B%MOD) % MOD; } } } if(op < 0) { int iv = qkpow(n,MOD-2); for(int i = 0; i < n; i ++) s[i] = s[i] *1ll* iv % MOD; }return ; } int A[MAXN<<2],B[MAXN<<2]; int pro[MAXN],dp[MAXN]; int ST; void solve(int l,int r) { if(l == r) return ; int md = (l + r) >> 1; solve(l,md); int le = 1; while(le <= (md-l)+(r-l)) le <<= 1; for(int i = 0; i < le; i ++) A[i] = B[i] = 0; for(int i = l; i <= md; i ++) { if(ss[i+1] == '1') A[i-l] = dp[i]*1ll*pro[i+1]%MOD*invf[i-ST+1]%MOD; } for(int i = 1; i <= r-l; i ++) B[i] = invf[i]; NTT(A,le,1); NTT(B,le,1); for(int i = 0; i < le; i ++) A[i] = A[i] *1ll* B[i] % MOD; NTT(A,le,-1); for(int i = 0; i < le; i ++) { if(i+l > md && i+l <= r) { (dp[i+l] += fac[i+l-ST+1]*1ll*pro[i+l]%MOD*A[i]%MOD) %= MOD; } } solve(md+1,r); return ; } int main() { freopen("a.in","r",stdin); freopen("a.out","w",stdout); n = read(); fac[0] = fac[1] = inv[0] = inv[1] = invf[0] = invf[1] = 1; for(int i = 2; i <= n+3; i ++) { fac[i] = fac[i-1] *1ll* i % MOD; inv[i] = (MOD - inv[MOD%i]) *1ll* (MOD/i) % MOD; invf[i] = invf[i-1] *1ll* inv[i] % MOD; } scanf("%s",ss + 1); pro[0] = 1; for(int i = 1; i <= n; i ++) { pro[i] = pro[i-1]; if(ss[i] == '1') pro[i] = MOD-pro[i]; } int ans = fac[n+1]; for(int i = 0; i <= n; i ++) { int r = i; while(r < n && ss[r+1] != '?') r ++; for(int j = i; j <= r; j ++) { dp[j] = pro[j]*1ll*pro[i]%MOD; } ST = i; solve(i,r); ans = ans *1ll* dp[r] % MOD; ans = ans *1ll* invf[r-i+1] % MOD; i = r; } AIput(ans,'\n'); return 0; }

    推荐阅读