LOJ3399「2020-2021 集训队作业」Communication Network

看到这题第一想法就是口罩,感觉做过口罩这题的过程就很自然了。我们考虑设 f(k)f(k) 表示恰好 kk 条边与原树相同的树的个数,g(k)g(k) 表示钦定 kk 条边,于是:

k0f(k)k2kk0k2kjkg(j)(1)jk(jk)j0jg(j)kj2k(1)jk(j1k1)j02jg(j)kj2k1(1)jk(j1k1)j02jg(j)\sum_{k \ge 0} f(k)k2^k\\ \sum_{k \ge 0} k2^k \sum_{j \ge k} g(j)(-1)^{j-k} \binom j k\\ \sum_{j \ge 0} jg(j) \sum_{k \le j} 2^k(-1)^{j-k} \binom {j-1} {k-1}\\ \sum_{j \ge 0} 2jg(j) \sum_{k \le j} 2^{k-1}(-1)^{j-k} \binom {j-1} {k-1}\\ \sum_{j \ge 0} 2jg(j)

然后现在求的就是 jg(j)\sum jg(j) 这个东西。类似口罩的做法,我们知道钦定 kk 条边的时候树的方案数量是 nnk2din^{n-k-2} \prod d_i ,后面这团考虑组合意义,相当于是对每个钦定边的连通块选择一个点,前面这团就是 nn 的钦定块的数量次幂。再考虑 jg(j)jg(j) 的这个 jj ,仍然考虑组合意义,即从所有钦定边中选择一个边即可。

于是考虑设计 dp[u][0/1][0/1]dp[u][0/1][0/1] 表示 uu 子树之内,uu 所在的连通块是否被钦定过了点,以及 uu 整个子树是否有一个保留边被钦定即可。复杂度 O(n)O(n) 即可解决问题。

这题不读优过不了。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 2000006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 998244353;

namespace IO {
const int MAXSIZE = 1 << 20;
char buf[MAXSIZE], *p1, *p2;
#define gc()                                                               \
  (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) \
       ? EOF                                                               \
       : *p1++)
inline int rd() {
  int x = 0, f = 1;
  char c = gc();
  while (!isdigit(c)) {
    if (c == '-') f = -1;
    c = gc();
  }
  while (isdigit(c)) x = x * 10 + (c ^ 48), c = gc();
  return x * f;
}
char pbuf[1 << 20], *pp = pbuf;
inline void push(const char &c) {
  if (pp - pbuf == 1 << 20) fwrite(pbuf, 1, 1 << 20, stdout), pp = pbuf;
  *pp++ = c;
}
inline void write(int x) {
  static int sta[35];
  int top = 0;
  do {
    sta[top++] = x % 10, x /= 10;
  } while (x);
  while (top) push(sta[--top] + '0');
  push( 10 );
}
}

int n , m;
int A[MAXN];
vi G[MAXN];

int Pow( int x , int a ) {
    int ret = 1;
    while( a ) {
        if( a & 1 ) ret = ret * 1ll * x % P;
        x = x * 1ll * x % P , a >>= 1;
    }
    return ret;
}

int dp[MAXN][2][2];
void dfs( int u , int f ) {
    dp[u][0][0] = dp[u][1][0] = 1;
    int t[2][2];
    for( int v : G[u] ) if( v != f ) {
        dfs( v , u );
        memcpy( t , dp[u] , sizeof dp[u] );
        mem( dp[u] );
        // Not cut
        dp[u][0][0] = ( dp[u][0][0] + t[0][0] * 1ll * dp[v][0][0] ) % P,
        dp[u][1][0] = ( dp[u][1][0] + t[1][0] * 1ll * dp[v][0][0] ) % P,
        dp[u][1][0] = ( dp[u][1][0] + t[0][0] * 1ll * dp[v][1][0] ) % P;
        rep( w , 0 , 1 )
            dp[u][0][1] = ( dp[u][0][1] + t[0][w] * 1ll * dp[v][0][w ^ 1] ) % P,
            dp[u][1][1] = ( dp[u][1][1] + t[1][w] * 1ll * dp[v][0][w ^ 1] ) % P,
            dp[u][1][1] = ( dp[u][1][1] + t[0][w] * 1ll * dp[v][1][w ^ 1] ) % P;
        dp[u][0][1] = ( dp[u][0][1] + t[0][0] * 1ll * dp[v][0][0] ) % P,
        dp[u][1][1] = ( dp[u][1][1] + t[1][0] * 1ll * dp[v][0][0] ) % P,
        dp[u][1][1] = ( dp[u][1][1] + t[0][0] * 1ll * dp[v][1][0] ) % P;

        // Cut
        
        dp[u][0][0] = ( dp[u][0][0] + t[0][0] * 1ll * dp[v][1][0] % P * n ) % P,
        dp[u][1][0] = ( dp[u][1][0] + t[1][0] * 1ll * dp[v][1][0] % P * n ) % P;
        rep( w , 0 , 1 )
            dp[u][0][1] = ( dp[u][0][1] + t[0][w] * 1ll * dp[v][1][w ^ 1] % P * n ) % P,
            dp[u][1][1] = ( dp[u][1][1] + t[1][w] * 1ll * dp[v][1][w ^ 1] % P * n ) % P;
    }
}

void solve() {
	n = IO::rd();
    rep( i , 2 , n ) {
        int u , v;
        u = IO::rd() , v = IO::rd();
        G[u].pb( v ) , G[v].pb( u );
    }
    dfs( 1 , 1 );
    int as = dp[1][1][1] * 1ll * Pow( n , P - 2 ) % P * 2 % P;
    cout << as << endl;
}

signed main() {
//    int T;cin >> T;while( T-- ) solve();
    solve();
}
\