1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
// Reads lines "a b c d" from stdin until EOF and, for each, prints
//     sum_{i=1..n} (-1)^(n-i) * i! * [x^i] F_a(x) F_b(x) F_c(x) F_d(x)   (mod P)
// where n = a+b+c+d and F_m(x) = sum_{k=1..m} C(m-1, k-1) x^k / k!.
// By inclusion-exclusion over runs of equal items this counts sequences
// holding a, b, c, d items of four kinds (items of one kind interchangeable)
// in which no two adjacent items share a kind.
// NOTE(review): the task statement is not in this file; the interpretation
// above is reconstructed from the formula -- confirm against the problem.
#include <algorithm>
#include <cstring>
#include <iostream>
using namespace std;

#define MAXN 200010          // factorial tables cover indices 0 .. MAXN-1
#define MAXLEN ( 1 << 18 )   // NTT buffer size: a power of two >= MAXN
#define P 998244353          // modulus; 3 is used as its primitive root below
#define clr( a ) memset( a , 0 , sizeof a )
typedef long long ll;

// main() only accepts n = a+b+c+d <= MAXN-1 and picks the smallest power of
// two > n as the transform length, so that length never exceeds MAXLEN.
// Sizing the NTT buffers by MAXN instead would overflow once n >= 131072.
static_assert( MAXLEN >= MAXN && ( MAXLEN & ( MAXLEN - 1 ) ) == 0 ,
               "NTT buffers must hold every transform length main() can pick" );

// Returns x^y mod P by binary exponentiation (expects y >= 0).
int Pow( int x , int y )
{
    int res = 1;
    while( y )
    {
        if( y & 1 ) res = res * (ll)x % P;
        x = x * (ll)x % P , y >>= 1;
    }
    return res;
}

int wn[2][MAXLEN];
// Precomputes twiddle factors for transforms of length up to 2^l.
// For every half-length i (a power of two below 2^l) and 0 <= j < i:
//   wn[0][i + j] = w^j    with w = 3^((P-1)/(2i))   -- forward transform
//   wn[1][i + j] = w^(-j)                           -- inverse transform
void getwn( int l )
{
    for( int i = 1 ; i < ( 1 << l ) ; i <<= 1 )
    {
        int w0 = Pow( 3 , ( P - 1 ) / ( i << 1 ) );
        int w1 = Pow( 3 , P - 1 - ( P - 1 ) / ( i << 1 ) );   // inverse of w0
        wn[0][i] = wn[1][i] = 1;
        for( int j = 1 ; j < i ; ++ j )
        {
            wn[0][i + j] = wn[0][i + j - 1] * (ll)w0 % P;
            wn[1][i + j] = wn[1][i + j - 1] * (ll)w1 % P;
        }
    }
}

int rev[MAXLEN];
// Fills rev[i] with the l-bit bit reversal of i for 0 < i < 2^l.
// rev[0] is never written and relies on the global zero-initialisation.
void getr( int l )
{
    for( int i = 1 ; i < ( 1 << l ) ; ++ i )
        rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << ( l - 1 ) );
}

// In-place iterative number-theoretic transform of A[0 .. len-1].
// f == 0: forward transform; f == 1: inverse transform including the 1/len
// scaling. len must equal 1 << l for the l last passed to getwn()/getr().
void NTT( int *A , int len , int f )
{
    for( int i = 0 ; i < len ; ++ i )
        if( rev[i] < i ) swap( A[i] , A[rev[i]] );
    for( int l = 1 ; l < len ; l <<= 1 )
        for( int i = 0 ; i < len ; i += ( l << 1 ) )
            for( int k = 0 ; k < l ; ++ k )
            {
                int t1 = A[i + k] , t2 = A[i + l + k] * (ll)wn[f][l + k] % P;
                A[i + k] = ( t1 + t2 ) % P;
                A[i + l + k] = ( t1 - t2 + P ) % P;
            }
    if( f == 1 )
        for( int invLen = Pow( len , P - 2 ) , i = 0 ; i < len ; ++ i )
            A[i] = A[i] * (ll)invLen % P;
}

int J[MAXN] , invJ[MAXN] , inv[MAXN];   // k!, 1/k!, 1/k  (mod P)

// Binomial coefficient C(m, k) mod P; 0 when k > m. Needs 0 <= k, m < MAXN.
int cc( int m , int k )
{
    if( k > m ) return 0;
    return 1ll * J[m] * invJ[k] % P * invJ[m - k] % P;
}

int A[MAXLEN] , B[MAXLEN] , C[MAXLEN] , D[MAXLEN];

// Writes into the zeroed array F the factor for one kind with cnt items:
// F[k] = C(cnt-1, k-1) / k!  for 1 <= k <= cnt, where C(cnt-1, k-1) counts the
// ways to cut cnt items into k non-empty runs. A kind with no items must give
// the constant polynomial 1 (its single empty split); leaving it as the zero
// polynomial would zero the whole product and hence the answer.
static void buildFactor( int *F , int cnt )
{
    if( cnt == 0 )
    {
        F[0] = 1;
        return;
    }
    for( int k = 1 ; k <= cnt ; ++ k )
        F[k] = 1ll * cc( cnt - 1 , k - 1 ) * invJ[k] % P;
}

int a , b , c , d;

int main()
{
    // Factorials, inverses and inverse factorials modulo P.
    J[0] = inv[1] = invJ[0] = J[1] = invJ[1] = 1;
    for( int i = 2 ; i < MAXN ; ++ i )
    {
        inv[i] = 1ll * ( P - P / i ) * inv[P % i] % P;
        J[i] = 1ll * J[i - 1] * i % P;
        invJ[i] = 1ll * invJ[i - 1] * inv[i] % P;
    }
    while( cin >> a >> b >> c >> d )
    {
        // Sum in 64 bits so huge inputs cannot overflow before the check.
        ll total = (ll)a + b + c + d;
        if( a < 0 || b < 0 || c < 0 || d < 0 || total >= MAXN )
        {
            // Out-of-range input would index the tables out of bounds.
            cerr << "counts must be non-negative with a+b+c+d < " << MAXN << endl;
            return 1;
        }
        int n = (int)total;
        clr( A ) , clr( B ) , clr( C ) , clr( D );
        buildFactor( A , a );
        buildFactor( B , b );
        buildFactor( C , c );
        buildFactor( D , d );

        // Smallest power of two strictly greater than the product degree n.
        int len = 1 , l = 0;
        while( len <= n ) len <<= 1 , ++ l;
        getwn( l ) , getr( l );
        NTT( A , len , 0 ) , NTT( B , len , 0 ) , NTT( C , len , 0 ) , NTT( D , len , 0 );
        for( int i = 0 ; i < len ; ++ i )
            A[i] = 1ll * A[i] * B[i] % P * C[i] % P * D[i] % P;
        NTT( A , len , 1 );

        // Inclusion-exclusion: a total of i runs contributes (-1)^(n-i) * i! * [x^i].
        ll res = 0;
        for( int i = 1 ; i <= n ; ++ i )
        {
            ll term = 1ll * J[i] * A[i] % P;
            res = ( ( n - i ) & 1 ) ? ( res - term + P ) % P : ( res + term ) % P;
        }
        cout << res << endl;
    }
}
|