1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;
typedef long long ll;
#define MAXN 2000006
// File-scope globals.
int n;          // not read by the visible code: wtf::main reads into its own wtf::n
int lst;        // SAM state that the next SAM::ins() call extends (set by SAM::init / SAM::build)
int m;          // not read by the visible code
char ch[MAXN];  // not read by the visible code: SAM's member `ch` shadows it inside the struct
struct SAM{ int son[MAXN][11]; int par[MAXN] , len[MAXN]; int cnt , ecnt; void init( ) { memset( son , 0 , sizeof son ); cnt = lst = 1; ecnt = 0; } void ins( int x ) { int cur = ++ cnt; len[cur] = len[lst] + 1; int p = lst; while( p && !son[p][x] ) son[p][x] = cur , p = par[p]; if( !p ) par[cur] = 1; else { int q = son[p][x]; if( len[q] == len[p] + 1 ) par[cur] = q; else { int cl = ++ cnt; memcpy( son[cl] , son[q] , sizeof son[q] ); par[cl] = par[q]; len[cl] = len[p] + 1 , par[q] = par[cur] = cl; for( ; son[p][x] == q ; p = par[p] ) son[p][x] = cl; } } lst = cur; } int ch[MAXN][10] , pos[MAXN]; void build( ) { init( ); queue<int> Q; Q.push( 0 ); pos[0] = 1; while( !Q.empty() ) { int u = Q.front(); Q.pop(); for( int i = 0 ; i < 10 ; ++ i ) if( ch[u][i] ) { Q.push( ch[u][i] ); lst = pos[u]; ins( i ); pos[ch[u][i]] = lst; } } } long long work( ) { long long res = 0; for( int i = 2 ; i <= cnt ; ++ i ) { res += len[i] - len[par[i]]; } return res; } } S ; namespace wtf {
int n , m;
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn; void ade( int u , int v ) { to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn; }
int w[MAXN] , cn;
void build( int u , int ps , int fa ) { for( int i = head[u] ; i ; i = nex[i] ) { int v = to[i]; if( v == fa ) continue; int& x = S.ch[ps][w[v]]; if( !x ) x = ++ cn; build( v , x , u ); } } int c; void main() { cin >> n >> c; for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&w[i]); for( int i = 1 , u , v ; i < n ; ++ i ) scanf("%d%d",&u,&v) , ade( u , v ) , ade( v , u ); S.init(); for( int i = 1 ; i <= n ; ++ i ) if( !nex[head[i]] ) build( i , S.ch[0][w[i]] ? S.ch[0][w[i]] : ( S.ch[0][w[i]] = ++ cn ) , i ); S.build(); cout << S.work( ) << endl; } } int main() { wtf::main(); }
|