1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>

using namespace std;

// Size of every array in this file that is indexed by a node id of the
// auxiliary graph `g` (tree vertices, label nodes and lifting nodes).
#define MAXN 4000006

// Inclusive counting loops. The bound `b` is evaluated once, before the
// first iteration, and stored in an int named `<i>end`.
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)

// Short-hands for common standard-library spellings.
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
// Zero-fills an array; `a` must be a real array, because `sizeof a` is used.
#define mem( a ) memset( a , 0 , sizeof a )

typedef long long ll;

// n, k : the two integers read first by solve().
// A[i] : the label read for tree vertex i (1-based).
int n , k;
int A[200006];

// G : adjacency lists of the input tree.
// g : adjacency lists of the directed auxiliary graph (written by ade()).
vi G[200006] , g[MAXN];

// NOTE(review): cne is not referenced anywhere in the visible code.
int cne;
// Inserts the directed edge u -> v into the auxiliary graph by appending
// v to the adjacency list g[u]. No de-duplication is performed.
void ade( int u , int v )
{
    g[u].push_back( v );
}
// Binary-lifting tables filled by dfs():
//   J[v][k]  : the 2^k-th ancestor of v, or 0 when there is none.
//   nd[v][k] : graph node attached to (v, k). nd[v][0] is v itself; for
//              k >= 1 it is a fresh node that receives an edge from
//              nd[v][k-1] and from nd[J[v][k-1]][k-1].
//   cn       : highest graph node id handed out so far.
//   dep[v]   : depth of v (the caller initialises the root's depth).
int J[200006][18] , nd[200006][18] , cn , dep[200006];
// dfn[v] is the preorder index of v; cnt is the running preorder counter.
int dfn[MAXN] , cnt;

// Walks the tree G from u (whose parent is f), assigning preorder numbers
// and, for every child, its depth, ancestor table and lifting nodes.
// Recursion depth equals the height of the tree.
void dfs( int u , int f )
{
    dfn[u] = ++cnt;
    for( int child : G[u] )
    {
        if( child == f )
            continue;

        dep[child] = dep[u] + 1;
        J[child][0] = u;
        nd[child][0] = child;

        for( int lvl = 1 ; lvl <= 17 ; ++lvl )
        {
            const int half = J[child][lvl - 1];
            const int anc = J[half][lvl - 1];
            // No 2^lvl-th ancestor: leave J and nd at 0 for this level.
            if( anc == 0 )
                continue;

            nd[child][lvl] = ++cn;
            ade( nd[child][lvl - 1] , nd[child][lvl] );
            ade( nd[half][lvl - 1] , nd[child][lvl] );
            J[child][lvl] = anc;
        }

        dfs( child , u );
    }
}
vi pr[MAXN]; void addit( int c , int u , int v ) { if( dep[u] < dep[v] ) swap( u , v ); if( dep[u] != dep[v] ) for( int k = 17 ; k >= 0 ; -- k ) if( dep[J[u][k]] >= dep[v] ) ade( nd[u][k] , c ) , u = J[u][k]; if( u == v ) { ade( v , c ); return; } per( k , 17 , 0 ) if( J[u][k] != J[v][k] ) ade( nd[u][k] , c ) , ade( nd[v][k] , c ) , u = J[u][k] , v = J[v][k]; ade( u , c ) , ade( v , c ) , ade( J[u][0] , c ); }
int sz[MAXN] , dft[MAXN] , low[MAXN] , clo , bel[MAXN] , stk[MAXN] , top , ins[MAXN] , scc; void tarjan( int u ) { dft[u] = low[u] = ++ clo; stk[++ top] = u , ins[u] = 1; for( int v : g[u] ) { if( !dft[v] ) tarjan( v ) , low[u] = min( low[u] , low[v] ); else if( ins[v] ) low[u] = min( low[u] , dft[v] ); } if( dft[u] == low[u] ) { int x; ++ scc; do { x = stk[top--]; ins[x] = 0; bel[x] = scc; } while( x != u ); } }
int ind[MAXN]; void solve() { cin >> n >> k; rep( i , 2 , n ) { int u , v; scanf("%d%d",&u,&v); G[u].pb( v ) , G[v].pb( u ); } cn = n + k; rep( i , 1 , n ) scanf("%d",A + i) , ade( A[i] + n , i ) , pr[A[i]].pb( i ); dep[1] = nd[1][0] = 1 , dfs( 1 , 1 ); rep( i , 1 , k ) { sort( all( pr[i] ) , [&]( int a , int b ) { return dfn[a] < dfn[b]; } ); rep( j , 0 , pr[i].size() - 1 ) { int u = pr[i][j] , v = pr[i][( j + 1 ) % pr[i].size()]; addit( i + n , u , v ); } } rep( i , 1 , cn ) if( !dft[i] ) tarjan( i ); rep( i , 1 , cn ) for( int v : g[i] ) if( bel[v] != bel[i] ) ++ ind[bel[v]]; rep( i , n + 1 , n + k ) ++ sz[bel[i]];
int as = 0x3f3f3f3f; rep( i , 1 , scc ) if( !ind[i] ) as = min( as , sz[i] - 1 );
cout << as << endl; }
// Program entry point: runs the single test case handled by solve().
int main()
{
    solve();
    return 0;
}
|