青春猪头少年不会梦到兔女郎学姐

好的现在开始番剧介绍(bushi

我们先来考虑一个前置问题。假设有 $n$ 种球,第 $i$ 种球有 $a_i$ 个,你需要放在序列上,满足相邻的两个位置球不能相同,问方案数量。这是一个经典容斥题,我们考虑假设第 $i$ 种球出现了 $j$ 组相邻,那么就把这些球缩在一起。现在就变成了有 $i-j$ 个球进行排列,乘上容斥系数 $(-1)^{j}{i-1 \choose j}$ 。

回到这个题,我们现在希望确定的是每种颜色的球的分段情况。一旦我们确定了分段情况,就变成了上面那个问题。

我们先不扩展到环上,定义 $f(a,b)$ 表示把 $a$ 个球分成 $b$ 段,所有方案的权值的和。

我们考虑枚举最后形成了多少段,然后容斥计算放这么多段的方案数量,枚举这么多段又缩成了多少个球,那么某一种球的 EGF 为:

我们考虑如果把所有球的 EGF 乘了起来,那么 $x^i$ 的系数就是最终缩成了 $i$ 个球的容斥系数。

然后问题是 $f(a,b)$ 如何求?

我们可以用这篇文章 里面提到的那种转化,也就是解

这个方程。

插板法一下,就是 $a+b-1\choose 2b-1$ 。

推一波式子:

这个式子明显是个差卷积,做两次 NTT 就完事了。

由于 $\sum a_i = O(n)$ ,所以算出所有的复杂度是 $O(n\log n)$。

现在我们完全明白了链上的情况,那么环上怎么做?

我们可以钦定以环上某个 $1$ 的位置开头,然后从这里断环为链,也就是说可以统计开头是第一种球,结尾不是的方案。

对于每一种这样的方案,我们可以把这个开头位置任意转,所以方案得乘上 $S = \sum a_i$。

同时我们发现,对于我们统计出来的任何一个方案,它都会被算重 $1$ 的连续段个数次。也就是说如果第一种球被分成了 $k$ 段,那么这种方案将被计算 $k$ 次。

我们把第一种球的 EGF $\hat A_1$ 单独拿出来,如果钦定开头是第一种球,那么 EGF 中 $\frac{x^i}{i!}$ 应当变为 $\frac{x^{i-1}}{(i-1)!}$ 。

我们还得除去开头结尾都是第一种球的方案,这种时候应当变为 $\frac{x^{i-2}}{(i-2)!}$。

然后,有两种实现,一种是老老实实地做 $x^i$ 变成 $x^{i+1}$ 再变成 $x^{i+2}$ ,都拿去乘一下剩下的多项式,最后相减。

但是还有一种做法,可以直接让 $x^i$ 的系数变成 $x^{i+1},x^{i+2}$ 系数做差。乘一次就好了。

明显第二种更优秀。

第一种:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 600006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
#define P 998244353
typedef long long ll;
int n , s;
int A[MAXN];

int Pow( int x , int a ) {
int ret = 1;
while( a ) {
if( a & 1 ) ret = 1ll * ret * x % P;
x = 1ll * x * x % P , a >>= 1;
}
return ret;
}

int Wn[2][MAXN];
void getwn( int len ) {
for( int mid = 1 ; mid < len ; mid <<= 1 ) {
int w0 = Pow( 3 , ( P - 1 ) / ( mid << 1 ) ) , w1 = Pow( 3 , P - 1 - ( P - 1 ) / ( mid << 1 ) );
Wn[0][mid] = Wn[1][mid] = 1;
rep( i , 1 , mid - 1 )
Wn[0][mid + i] = 1ll * Wn[0][mid + i - 1] * w0 % P,
Wn[1][mid + i] = 1ll * Wn[1][mid + i - 1] * w1 % P;
}
}
int rev[MAXN];
void getr( int len ) {
int t = __builtin_ctz( len );
for( int i = 1 ; i < len ; ++ i ) rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << t - 1 );
}
void NTT( vi& A , int len , int f ) {
for( int i = 0 ; i < len ; ++ i ) if( i < rev[i] ) A[i] ^= A[rev[i]] , A[rev[i]] ^= A[i] , A[i] ^= A[rev[i]];
for( int mid = 1 ; mid < len ; mid <<= 1 )
for( int i = 0 ; i < len ; i += ( mid << 1 ) )
for( int k = 0 ; k < mid ; ++ k ) {
int t1 = A[i + k] , t2 = 1ll * A[i + k + mid] * Wn[f][mid + k] % P;
A[i + k] = ( t1 + t2 ) % P , A[i + k + mid] = ( t1 + P - t2 ) % P;
}
if( f ) for( int i = 0 , iv = Pow( len , P - 2 ) ; i < len ; ++ i ) A[i] = 1ll * A[i] * iv % P;
}

int J[MAXN] , iJ[MAXN];

int C( int a , int b ) {
if( a < b || a < 0 || b < 0 ) return 0;
return J[a] * 1ll * iJ[b] % P * iJ[a - b] % P;
}

vi& mul( vi& a , vi& b , int fuck = 0 ) { // a = a * b
int len = 1 , sz = ( a.size() + b.size() - 2 );
while( len <= sz ) len <<= 1;
getr( len ) , getwn( len );
a.resize( len ) , b.resize( len );
NTT( a , len , 0 ) , NTT( b , len , 0 );
rep( i , 0 , len - 1 ) a[i] = 1ll * a[i] * b[i] % P;
NTT( a , len , 1 );
if( fuck ) NTT( b , len , 1 );
a.resize( sz + 1 );
return a;
}

int flg;
vi tmp;
void getit( int m , vi& fun ) {
tmp.clear();
tmp.resize( m + 1 ) , fun.resize( m + 1 );
rep( i , 1 , m ) {
tmp[i] = C( m + i - 1 , 2 * i - 1 ) * 1ll * ( ( i & 1 ) ? P - 1 : 1 ) % P * J[i - 1] % P;
if( flg ) tmp[i] = 1ll * tmp[i] * Pow( i , P - 2 ) % P;
fun[m - i] = iJ[i];
}
fun[m] = iJ[0];
mul( fun , tmp );
fun[0] = 0;
rep( i , 1 , m ) {
fun[i] = fun[i + m] * 1ll * iJ[i - 1] % P * ((i & 1) ? P - 1 : 1) % P * iJ[i] % P;
}
fun.resize( m + 1 );
}

vi fs[MAXN];

vi& solve( int l , int r ) {
if( l == r ) return fs[l];
int mid = l + r >> 1;
return mul( solve( l , mid ) , solve( mid + 1 , r ) );
}

void solve() {
J[0] = iJ[0] = 1;
rep( i , 1 , MAXN - 1 ) J[i] = J[i - 1] * 1ll * i % P , iJ[i] = Pow( J[i] , P - 2 );
cin >> n;
flg = 1;
rep( i , 1 , n ) scanf("%d",A + i) , s += A[i] , getit( A[i] , fs[i] ) , flg = 0;
// rep( i , 0 , A[1] - 1 ) fs[1][i] = ( fs[1][i + 1] + P - ( i + 2 <= A[1] ? fs[1][i + 2] : 0 ) ) % P * 1ll * iJ[i] % P;
vi& re = solve( 2 , n );
vi t( A[1] , 0 );
rep( i , 0 , A[1] - 1 ) t[i] = ( fs[1][i + 1] * 1ll * ( i + 1 ) ) % P;
mul( t , re , 1 );
int ans = 0;
rep( i , 0 , t.size() - 1 )
ans = ( ans + t[i] * 1ll * J[i] % P ) % P;
if( A[1] == 1 ) return void( cout << ans * 1ll * s % P << endl );
// cout << ans << endl;
t.clear();
t.resize( A[1] - 1 );
rep( i , 0 , A[1] - 2 ) t[i] = ( fs[1][i + 2] * 1ll * ( i + 2 ) % P * ( i + 1 ) % P ) % P;
mul( t , re );
rep( i , 0 , t.size() - 1 ) ans = ( ans + P - t[i] * 1ll * J[i] % P ) % P;
cout << 1ll * ans * s % P << endl;
}

signed main() {
// freopen("in1.in","r",stdin);
// freopen("fuckout","w",stdout);
// int T;cin >> T;while( T-- ) solve();
solve();
}

第二种

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 600006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
#define P 998244353
typedef long long ll;
int n , s;
int A[MAXN];

int Pow( int x , int a ) {
int ret = 1;
while( a ) {
if( a & 1 ) ret = 1ll * ret * x % P;
x = 1ll * x * x % P , a >>= 1;
}
return ret;
}

int Wn[2][MAXN];
void getwn( int len ) {
for( int mid = 1 ; mid < len ; mid <<= 1 ) {
int w0 = Pow( 3 , ( P - 1 ) / ( mid << 1 ) ) , w1 = Pow( 3 , P - 1 - ( P - 1 ) / ( mid << 1 ) );
Wn[0][mid] = Wn[1][mid] = 1;
rep( i , 1 , mid - 1 )
Wn[0][mid + i] = 1ll * Wn[0][mid + i - 1] * w0 % P,
Wn[1][mid + i] = 1ll * Wn[1][mid + i - 1] * w1 % P;
}
}
int rev[MAXN];
void getr( int len ) {
int t = __builtin_ctz( len );
for( int i = 1 ; i < len ; ++ i ) rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << t - 1 );
}
void NTT( vi& A , int len , int f ) {
for( int i = 0 ; i < len ; ++ i ) if( i < rev[i] ) A[i] ^= A[rev[i]] , A[rev[i]] ^= A[i] , A[i] ^= A[rev[i]];
for( int mid = 1 ; mid < len ; mid <<= 1 )
for( int i = 0 ; i < len ; i += ( mid << 1 ) )
for( int k = 0 ; k < mid ; ++ k ) {
int t1 = A[i + k] , t2 = 1ll * A[i + k + mid] * Wn[f][mid + k] % P;
A[i + k] = ( t1 + t2 ) % P , A[i + k + mid] = ( t1 + P - t2 ) % P;
}
if( f ) for( int i = 0 , iv = Pow( len , P - 2 ) ; i < len ; ++ i ) A[i] = 1ll * A[i] * iv % P;
}

int J[MAXN] , iJ[MAXN];

int C( int a , int b ) {
if( a < b || a < 0 || b < 0 ) return 0;
return J[a] * 1ll * iJ[b] % P * iJ[a - b] % P;
}

vi& mul( vi& a , vi& b , int fuck = 0 ) { // a = a * b
int len = 1 , sz = ( a.size() + b.size() - 2 );
while( len <= sz ) len <<= 1;
getr( len ) , getwn( len );
a.resize( len ) , b.resize( len );
NTT( a , len , 0 ) , NTT( b , len , 0 );
rep( i , 0 , len - 1 ) a[i] = 1ll * a[i] * b[i] % P;
NTT( a , len , 1 );
if( fuck ) NTT( b , len , 1 );
a.resize( sz + 1 );
return a;
}

int flg;
vi tmp;
void getit( int m , vi& fun ) {
tmp.clear();
tmp.resize( m + 1 ) , fun.resize( m + 1 );
rep( i , 1 , m ) {
tmp[i] = C( m + i - 1 , 2 * i - 1 ) * 1ll * ( ( i & 1 ) ? P - 1 : 1 ) % P * J[i - 1] % P;
if( flg ) tmp[i] = 1ll * tmp[i] * Pow( i , P - 2 ) % P;
fun[m - i] = iJ[i];
}
fun[m] = iJ[0];
mul( fun , tmp );
fun[0] = 0;
rep( i , 1 , m ) {
fun[i] = fun[i + m] * 1ll * iJ[i - 1] % P * ((i & 1) ? P - 1 : 1) % P;
if( !flg ) fun[i] = 1ll * fun[i] * iJ[i] % P;
}
fun.resize( m + 1 );
}

vi fs[MAXN];

vi& solve( int l , int r ) {
if( l == r ) return fs[l];
int mid = l + r >> 1;
return mul( solve( l , mid ) , solve( mid + 1 , r ) );
}

void solve() {
J[0] = iJ[0] = 1;
rep( i , 1 , MAXN - 1 ) J[i] = J[i - 1] * 1ll * i % P , iJ[i] = Pow( J[i] , P - 2 );
cin >> n;
flg = 1;
rep( i , 1 , n ) scanf("%d",A + i) , s += A[i] , getit( A[i] , fs[i] ) , flg = 0;
rep( i , 0 , A[1] - 1 ) fs[1][i] = ( fs[1][i + 1] + P - ( i + 2 <= A[1] ? fs[1][i + 2] : 0 ) ) % P * 1ll * iJ[i] % P;
fs[1].erase( fs[1].begin() + A[1] , fs[1].end() );
vi& re = solve( 1 , n );
int ans = 0;
rep( i , 0 , re.size() - 1 )
ans = ( ans + re[i] * 1ll * J[i] % P ) % P;
cout << 1ll * ans * s % P << endl;
}

signed main() {
// freopen("in1.in","r",stdin);
// freopen("fuckout","w",stdout);
// int T;cin >> T;while( T-- ) solve();
solve();
}
文章作者: yijan
文章链接: https://yijan.co/qing-chun-zhu-tou-shao-nian-bu-hui-meng-dao-tu-nu-lang-xue-jie/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yijan's Blog