清华集训 2017 生成树计数

给定一张 $s$ 个点的图,存在 $s-n$ 个边,使得图中形成了 $n$ 个联通块,第 $i$ 个联通块内部有 $a_i$ 个点。

我们现在要再连 $n-1$ 个边,使得该图成为一个树。对于某一种连边方式,我们设原题第 $i$ 个联通块连出了 $d_i$ 条边,那么当前树的权值为

求所有生成树的权值和。


我们可以考虑把开始就连通的点缩起来,枚举一个联通块的生成树 $T$ ,那么

我们考虑一个联通块向外连,可能是这个联通块的任何一个点往外连,并且一个点可以连多个边出去。

我们把 $d_i^ma_i^{d_i}$ 提出到前面去,那么

但是生成树 $T$ 无疑是不容易枚举的,可以转而枚举 prufer 序。复习一下

prufer 序列的构造是,每次删除最小的叶子,把它的父亲加入序列。重复这个过程直到只剩下两个点,则构造出了一个长度为 $n-2$ 的序列。一个 prufer 序对应唯一一个无根树。

prufer 序有个很好的性质,一个数在 prufer 序列的出现次数就是它的度数减一。

如果我们钦定了每个点的出现次数 $c_i$ ,那么这样的 prufer 序列有正好 $\frac{(n-2)!}{\prod c_i!}$ 种。

我们考虑用枚举这样的出现次数 $c_i$ 序列替代枚举 $T$ ,那么

考虑这个 $\sum c_i = n - 2$ 的限制,我们可以用 $c_i$ 作为指数来构造 EGF。写出来就是

这个东西看起来完全不可优化?

我们的目标是,让 $A,B$ 能否从 $n$ 项化简到只需要 $m$ 项,乘上一个可以提出去的系数。

发现对于 $A,B$, $x$ 的指数都是 $k$ ,而前面竟然是 $k+1$ ,所以考虑先积分一下,提升 $x$ 的指数,化简后再求导

这个式子看起来又化简不动了,考虑后面那块 $a_i^k \frac{x^k}{k!}$ 看起来就非常 $\exp$ 。但是如果直接把 $\exp$ 提出去会发现又一次化不动了,所以考虑先把 $k^{2m}$ 用斯特林数展开,交换循环次序后看看能否化简去掉后面的 $\sum$

我们考虑提出一个 $a_i^jx^j$ 就可以把后面转化成 $\exp$ 了!

别忘了,我们还得给它求导回去。提出 $\exp(a_ix)$ 后利用 $(uv)’ = u’v + uv’$,有

提出 $e^{a_ix}$ 合并一下求和,那么有

后面竟然正好是斯特林数的递推公式。。。

对于 $B$ ,不难发现它和 $A$ 形式几乎是一摸一样,只是把 $2m$ 变成了 $m$ ,所以

然后记 $A_i(x) = e^{a_ix} A_i^1 (x) , B_i(x) = e^{a_ix}B_i^1(x)$

我们让后面的两个式子都变成了 $O(m)$ 项了!

于是后面乘积部分可以分治 NTT 乘出来除开某一项的 $B$ 的积再乘上 $A$ 加进去,总共也才 $O(nm)$ 项,于是复杂度是 $O(nm\log^2 (nm))$。前面的 $\exp$ 直接算一下就好了。

结 束 了 吗 ?

如果直接交上去,或许在 LOJ 5s 可以跑过去,但是在 luogu 会得到一个 TLE。


有没有更加优秀的优化呢?

回到最初的式子,令

注意此处非求导。

对比之前的式子,

我们只要提出去一个 $\prod a_i$ 就可以把 $A_i,B_j$ 换成 $A’,B’$ ,也就是

我们考虑交换 $\prod$ 和 $\sum$ 的顺序,

然后把乘积通过 $\ln,\exp$ 拆成求和的形式

我们现在可以预处理出

考虑 $[x^t]f(a_ix) = a_i^t[x^t]f(x)$ ,所以

所以现在我们需要对于 $k\in [1,n]$ 快速求出 $\sum a_i^k$。

我们构造它的 OGF

所以做一次 分治 FFT 就求出了 $T$ ,从而可以求出答案了。

总复杂度是分治 FFT 的 $O(n\log^2 n)$ 。

写这题顺便可以顺便复习 $\ln,\exp$ 求逆以及分治2333

无卡常情况下 LOJ Rank 7。。复杂度很优秀

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 100006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
int n , m;
int A[MAXN];
#define P 998244353
int Pow( int x , int a ) {
int ret = 1;
while( a ) {
if( a & 1 ) ret = 1ll * ret * x % P;
x = 1ll * x * x % P , a >>= 1;
}
return ret;
}
int wn[2][MAXN];
void getwn( int len ) {
int w0 , w1;
for( int mid = 1 ; mid < len ; mid <<= 1 ) {
w0 = Pow( 3 , ( P - 1 ) / ( mid << 1 ) ) , w1 = Pow( 3 , P - 1 - ( P - 1 ) / ( mid << 1 ) );
wn[0][mid] = wn[1][mid] = 1;
for( int i = 1 ; i < mid ; ++ i )
wn[0][mid + i] = wn[0][mid + i - 1] * 1ll * w0 % P,
wn[1][mid + i] = wn[1][mid + i - 1] * 1ll * w1 % P;
}
}
int rev[MAXN];
void getr( int len ) {
int t = __builtin_ctz( len ) - 1;
for( int i = 1 ; i < len ; ++ i ) rev[i] = ( rev[i >> 1] >> 1 ) | ( ( i & 1 ) << t );
}
void NTT( int A[] , int len , int f ) {
for( int i = 0 ; i < len ; ++ i ) if( i < rev[i] ) A[i] ^= A[rev[i]] , A[rev[i]] ^= A[i] , A[i] ^= A[rev[i]];
int t0 , t1;
for( int mid = 1 ; mid < len ; mid <<= 1 )
for( int i = 0 ; i < len ; i += ( mid << 1 ) )
for( int j = 0 ; j < mid ; ++ j ) {
t0 = A[i + j] , t1 = A[i + mid + j] * 1ll * wn[f][mid + j] % P;
A[i + j] = ( t0 + t1 ) % P , A[i + mid + j] = ( t0 + P - t1 ) % P;
}
if( f ) for( int inv = Pow( len , P - 2 ) , i = 0 ; i < len ; ++ i ) A[i] = 1ll * A[i] * inv % P;
}
void NTT( vi& A , int len , int f ) {
for( int i = 0 ; i < len ; ++ i ) if( i < rev[i] ) A[i] ^= A[rev[i]] , A[rev[i]] ^= A[i] , A[i] ^= A[rev[i]];
int t0 , t1;
for( int mid = 1 ; mid < len ; mid <<= 1 )
for( int i = 0 ; i < len ; i += ( mid << 1 ) )
for( int j = 0 ; j < mid ; ++ j ) {
t0 = A[i + j] , t1 = A[i + mid + j] * 1ll * wn[f][mid + j] % P;
A[i + j] = ( t0 + t1 ) % P , A[i + mid + j] = ( t0 + P - t1 ) % P;
}
if( f ) for( int inv = Pow( len , P - 2 ) , i = 0 ; i < len ; ++ i ) A[i] = 1ll * A[i] * inv % P;
}
int tmpa[MAXN];
void Inv( const int A[] , int B[] , int n ) { // B = Inv A (mod x^n)
if( n == 1 ) { B[0] = Pow( A[0] , P - 2 ); return; }
Inv( A , B , ( n + 1 ) >> 1 );
rep( i , 0 , n - 1 ) tmpa[i] = A[i];
int len = 1;
while( len <= n + n ) len <<= 1;
rep( i , n , len - 1 ) tmpa[i] = 0;
getwn( len ) , getr( len );
NTT( B , len , 0 ) , NTT( tmpa , len , 0 );
rep( i , 0 , len - 1 ) B[i] = 1ll * B[i] * ( 2 - 1ll * tmpa[i] * B[i] % P + P ) % P;
NTT( B , len , 1 );
rep( i , n , len - 1 ) B[i] = 0;
}
int ta[MAXN];
void Ln( const int A[] , int B[] , int n ) { // B = Ln A (mod x^n)
rep( i , 0 , n - 1 ) ta[i] = 1ll * ( i + 1 ) * A[i + 1] % P;
int len = 1;
Inv( A , B , n );
while( len <= n + n ) len <<= 1;
rep( i , n , len - 1 ) ta[i] = 0;
getr( len ) , getwn( len );
NTT( ta , len , 0 ) , NTT( B , len , 0 );
rep( i , 0 , len - 1 ) B[i] = 1ll * B[i] * ta[i] % P;
NTT( B , len , 1 );
per( i , n - 1 , 1 ) B[i] = 1ll * B[i - 1] * Pow( i , P - 2 ) % P;
rep( i , n , len ) B[i] = 0;
B[0] = 0;
}
int tln[MAXN];
void Exp( const int A[] , int B[] , int n ) { // B = Exp A (mod x^n)
if( n == 1 ) { B[0] = 1; return; }
Exp( A , B , ( n + 1 ) >> 1 );
int len = 1;
while( len <= n + n ) len <<= 1;
rep( i , 0 , len - 1 ) tln[i] = 0;
Ln( B , tln , n );
rep( i , 0 , n - 1 ) tln[i] = ( A[i] + P - tln[i] ) % P;
++ tln[0];
getwn( len ) , getr( len );
NTT( B , len , 0 ) , NTT( tln , len , 0 );
rep( i , 0 , len - 1 ) B[i] = 1ll * B[i] * tln[i] % P;
NTT( B , len , 1 );
rep( i , n , len - 1 ) B[i] = 0;
}

vi& mul( vi& a , vi& b , int f ) { // a = a * b
int sa = a.size() , sb = b.size() , len = 1;
while( len <= sa + sb - 2 ) len <<= 1;
a.resize( len ) , b.resize( len );
getwn( len ) , getr( len );
NTT( a , len , 0 ) , NTT( b , len , 0 );
for( int i = 0 ; i < len ; ++ i ) a[i] = 1ll * a[i] * b[i] % P;
NTT( a , len , 1 );
if( f ) NTT( b , len , 1 );
return a;
}
vi S[MAXN];
vi& solve( int l , int r ) {
if( l == r ) return S[l];
int m = l + r >> 1;
return mul( solve( l , m ) , solve( m + 1 , r ) , 0 );
}
int R[MAXN] , T[MAXN];

int J[MAXN] , iJ[MAXN];
int A1[MAXN] , B1[MAXN] , a[MAXN] , b[MAXN] , rb[MAXN];
void solve() {
cin >> n >> m;
J[0] = iJ[0] = 1;
for( int i = 1 ; i <= n ; ++ i ) J[i] = 1ll * J[i - 1] * i % P , iJ[i] = Pow( J[i] , P - 2 );
int pro = 1;
rep( i , 0 , n - 1 ) {
scanf("%d",A + i) , pro = 1ll * pro * A[i] % P;
S[i].pb( 1 ) , S[i].pb( P - A[i] );
}
vi& re = solve( 0 , n - 1 );
rep( i , 0 , re.size() - 1 ) R[i] = re[i];
Ln( R , T , n + 1 );
for( int i = 0 ; i <= n ; ++ i ) T[i] = 1ll * i * T[i] % P;
T[0] = P - n;
for( int i = 0 ; i <= n ; ++ i )
A1[i] = 1ll * Pow( i + 1 , 2 * m ) * iJ[i] % P , B1[i] = 1ll * Pow( i + 1 , m ) * iJ[i] % P , T[i] = P - T[i];
Ln( B1 , b , n + 1 );
Inv( B1 , a , n + 1 );
int len = 1;
while( len <= n + n ) len <<= 1;
getwn( len ) , getr( len );
NTT( a , len , 0 ) , NTT( A1 , len , 0 );
rep( i , 0 , len - 1 ) a[i] = 1ll * a[i] * A1[i] % P;
NTT( a , len , 1 );
for( int i = 0 ; i <= n ; ++ i ) a[i] = 1ll * a[i] * T[i] % P , b[i] = 1ll * b[i] * T[i] % P;
for( int i = n + 1 ; i < len ; ++ i ) a[i] = 0;
// printf("Before Exp : "); for( int i = 0 ; i <= n ; ++ i ) printf("%d ",b[i]); puts("");
Exp( b , rb , n + 1 );
// printf("After Exp : "); for( int i = 0 ; i <= n ; ++ i ) printf("%d ",rb[i]); puts("");
while( len <= n + n ) len <<= 1;
getwn( len ) , getr( len );
NTT( rb , len , 0 ) , NTT( a , len , 0 );
for( int i = 0 ; i < len ; ++ i ) rb[i] = 1ll * rb[i] * a[i] % P;
NTT( rb , len , 1 );
cout << J[n - 2] * 1ll * pro % P * rb[n - 2] % P << endl;
}

signed main() {
// freopen("6.in","r",stdin);
// freopen("fuckout","w",stdout);
// int T;cin >> T;while( T-- ) solve();
solve();
}

文章作者: yijan
文章链接: https://yijan.co/qing-hua-ji-xun-2017-sheng-cheng-shu-ji-shu/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yijan's Blog