1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
// Tree path-hash counting via centroid decomposition.
//
// For a directed path z_0 -> z_1 -> ... -> z_L in the tree, define
//     G(path) = sum_i w[z_i] * k^i  (mod P).
// For every vertex x the program counts
//     to[x] = number of paths starting at x with G == 0,
//     co[x] = number of paths ending   at x with G == 0
// (the single-vertex path x -> x is included in both). It then prints
//     n^3 - (sum_x [2*to*(n-to) + 2*co*(n-co) + to*(n-co) + co*(n-to)]) / 2.
// NOTE(review): this appears to be the triple-counting scheme of Codeforces
// 434E ("Furukawa Nagisa's Tree") with target value 0 -- confirm against the
// problem statement.
//
// Preconditions assumed by the math (not checked):
//   * P is prime and k % P != 0, so Pow(k, P - 2) is the inverse of k.
//   * n < MAXN.
// dfs1..dfs4 are recursive; on a path-shaped tree the depth approaches n, so
// a large n needs a correspondingly large stack.
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <set>
#include <vector>
#include <map>

using namespace std;

constexpr int MAXN = 100006;

int n, k, P;
int w[MAXN], pw[MAXN], ivp[MAXN];  // weights, k^i mod P, k^{-i} mod P
vector<int> G[MAXN];               // adjacency lists (1-indexed vertices)
int siz[MAXN], mx[MAXN], vis[MAXN], rt, pr;

// Returns x^a mod P by binary exponentiation.
inline int Pow(int x, int a) {
    int cur = x % P, ans = 1;
    while (a) {
        if (a & 1) ans = 1ll * ans * cur % P;
        cur = 1ll * cur * cur % P;
        a >>= 1;
    }
    return ans;
}

// Roots the current (unvisited) component at u: fills siz[] with subtree
// sizes and mx[] with the size of the largest child subtree.
void dfs1(int u, int fa) {
    siz[u] = 1;
    mx[u] = 0;
    for (int v : G[u]) {
        if (v == fa || vis[v]) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        mx[u] = max(mx[u], siz[v]);
    }
}

// Adds the "upward" part (component size minus own subtree) to mx[] and keeps
// in pr the vertex whose largest remaining piece is smallest.
void dfs2(int u, int fa) {
    mx[u] = max(mx[u], siz[rt] - siz[u]);
    if (mx[u] < mx[pr]) pr = u;
    for (int v : G[u])
        if (v != fa && !vis[v]) dfs2(v, u);
}

// Returns the centroid of the unvisited component containing u.
int findrt(int u) {
    pr = rt = u;
    dfs1(u, u);
    dfs2(u, u);
    return pr;
}

// Per-vertex values relative to the current centroid c:
//   f[x] = G(c -> x) = sum over path vertices y of w[y] * k^{dep(y)}
//   g[x] = sum over path vertices y != c of w[y] * k^{n - dep(y)}
// A path a -> c -> b (a, b in different subtrees of c) has
//   G = k^{dep(a)} * (g[a] * k^{-n} + f[b]),
// which is 0 exactly when f[b] == -g[a] * k^{-n} (k invertible mod P).
int f[MAXN], g[MAXN], dep[MAXN];
int to[MAXN], co[MAXN];
map<int, int> ff, gg;  // multiplicities of f / g over already-inserted vertices

// Number of stored occurrences of key (0 if absent). Unlike operator[], this
// never inserts into the map.
int count_of(const map<int, int>& m, int key) {
    auto it = m.find(key);
    return it == m.end() ? 0 : it->second;
}

// Walks one subtree of the centroid. Each vertex u is matched against the
// vertices already stored in ff/gg. f[] and g[] are filled for u's children
// before recursing.
void dfs3(int u, int fa) {
    // Good paths u -> b need f[b] == -g[u] * k^{-n}.
    const int need_f = static_cast<int>((P - 1ll * g[u] * ivp[n] % P) % P);
    // Good paths a -> u need g[a] == -f[u] * k^{n}.
    const int need_g = static_cast<int>((P - 1ll * f[u] * pw[n] % P) % P);
    to[u] += count_of(ff, need_f);
    co[u] += count_of(gg, need_g);
    for (int v : G[u]) {
        if (v == fa || vis[v]) continue;
        dep[v] = dep[u] + 1;
        f[v] = (f[u] + 1ll * w[v] * pw[dep[v]] % P) % P;
        g[v] = (g[u] + 1ll * w[v] * pw[n - dep[v]] % P) % P;
        dfs3(v, u);
    }
}

// Inserts every vertex of one subtree into ff/gg.
void dfs4(int u, int fa) {
    ++ff[f[u]];
    ++gg[g[u]];
    for (int v : G[u])
        if (v != fa && !vis[v]) dfs4(v, u);
}

// Processes centroid u, then recurses into the remaining components.
void work(int u) {
    vis[u] = 1;
    ff.clear();
    gg.clear();

    // The centroid itself takes part as an endpoint: f[u] = w[u], g[u] = 0.
    f[u] = w[u];
    g[u] = 0;
    ++ff[f[u]];
    ++gg[0];

    // Forward pass. Each vertex is matched against u and all earlier subtrees.
    // Only the vertex currently visited by dfs3 gets its counters bumped.
    for (int v : G[u]) {
        if (vis[v]) continue;
        dep[v] = 1;
        f[v] = (1ll * w[v] * k % P + w[u]) % P;
        g[v] = 1ll * w[v] * pw[n - 1] % P;
        dfs3(v, u);
        dfs4(v, u);
    }

    // Paths starting / ending at u (including u -> u), against the whole
    // component.
    const int need_g_for_u = static_cast<int>((P - 1ll * f[u] * pw[n] % P) % P);
    to[u] += count_of(ff, 0);
    co[u] += count_of(gg, need_g_for_u);

    // Reverse pass, without u's own entries. Each vertex now picks up partners
    // from the subtrees that came after it in the forward pass. dep/f/g of the
    // subtree roots are still the values set in the forward pass.
    ff.clear();
    gg.clear();
    for (auto it = G[u].rbegin(); it != G[u].rend(); ++it) {
        const int v = *it;
        if (vis[v]) continue;
        dfs3(v, u);
        dfs4(v, u);
    }

    for (int v : G[u])
        if (!vis[v]) work(findrt(v));
}

int main() {
    cin >> n >> k >> P;
    k = (k % P + P) % P;  // normalize into [0, P), same as the weights below

    // Powers and inverse powers of k. Pow(k, P - 2) is k^{-1} only when P is
    // prime and k != 0 (mod P); see the preconditions at the top.
    pw[0] = ivp[0] = 1;
    pw[1] = k;
    ivp[1] = Pow(k, P - 2);
    for (int i = 2; i <= n; ++i) {
        pw[i] = 1ll * pw[i - 1] * k % P;
        ivp[i] = 1ll * ivp[i - 1] * ivp[1] % P;
    }

    for (int i = 1; i <= n; ++i) {
        scanf("%d", &w[i]);
        w[i] = (w[i] % P + P) % P;
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }

    work(findrt(1));

    // Summed over all vertices, the n*(to+co) terms equal 2*n*T (T = total
    // number of good paths), so res is even and res / 2 is exact.
    long long res = 0;
    for (int i = 1; i <= n; ++i) {
        res += 2ll * to[i] * (n - to[i]) + 2ll * co[i] * (n - co[i]);
        res += 1ll * to[i] * (n - co[i]) + 1ll * co[i] * (n - to[i]);
    }
    cout << (1ll * n * n * n - res / 2) << '\n';
}
|