Count on a Tree II

Count on a Tree II

以前写的用可持久化分块维护。。既难写而且常数还大。。最后在 Luogu 还是开了读优挂才跑过去的。。感觉题解区的 Bitset 维护比我这个不知道高到哪里去了。。

考虑随机在这个树上打 $\sqrt n$ 个点作为关键点,然后每个点向上跳找到第一个关键点的期望长度是 $ \sqrt n $ 的。当然,也有没那么玄学的打点方法,用深度是 $\sqrt n$ 倍数并且向下最深深度大于等于根号的点来打,这样打出来点的个数是 $\sqrt n$ 的并且明显从每个点向上的长度也是 $\sqrt n$ 的。

然后我们先预处理出关键点间两两的答案,做法是从每个关键点跑一次 DFS 。

然后考虑一次询问,对于一次询问我们需要查询这两个点到关键点的路径上有没有什么颜色是不存在于关键点的路径上的。也就是说我们需要维护根到每个节点的某个颜色的数量。直接主席树维护得带 $ \log $ 不优秀,可以考虑用根号平衡,修改 $O(\sqrt n)$ 查询 $O(1)$ ,套个可持久化就好了。

还需要用 ST 表 $ O(1) $ 维护 LCA。这样做复杂度就是在线的 $O(m\sqrt n)$ 了(虽然常数有点大)。空间是 $n\sqrt n$ 的。

代码写的很丑。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
using namespace std;
#define MAXN 40006
#define B 206
#define swap( a , b ) ( (a) ^= (b) , (b) ^= (a) , (a) ^= (b) )
int n , m , blo , sz , bl;
int c[MAXN] , C[MAXN];
int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn;
void ade( int u , int v ) {
to[++ ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}
int ps[B] , im[MAXN] , pre[B][B] , cid , ee;
int w[MAXN] , an;
void dfs( int u , int fa ) {
if( !w[c[u]] ) ++ an;
++ w[c[u]];
if( im[u] ) pre[cid][im[u]] = an;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
dfs( v , u );
}
-- w[c[u]];
if( !w[c[u]] ) -- an;
}

int kk[MAXN << 2][B] , idx = 0 , rt[MAXN];
void build( ) {
idx = rt[0] = 1;
for( int i = 1 ; ( i - 1 ) * bl < sz ; ++ i ) kk[1][i] = i + 1 , ++ idx;
}
void add( int rt , int old , int x ) {
memcpy( kk[rt] , kk[old] , sizeof kk[old] );
++ idx;
memcpy( kk[idx] , kk[kk[rt][( x - 1 ) / bl + 1]] , sizeof kk[1] );
kk[rt][( x - 1 ) / bl + 1] = idx;
++ kk[idx][( x - 1 ) % bl + 1];
}
int que( int rt , int x ) {
int t = kk[rt][( x - 1 ) / bl + 1];
return kk[t][( x - 1 ) % bl + 1];
}

int dfn[MAXN << 1] , dep[MAXN << 1] , pos[MAXN << 1] , en = 0 , d[MAXN] , par[MAXN];
int dfs1( int u , int fa ) {
dfn[++en] = u , pos[u] = en , dep[en] = d[u] , par[u] = fa;
rt[u] = ++ idx;
add( rt[u] , rt[fa] , c[u] );
int re = 0;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
d[v] = d[u] + 1;
re = max( re , dfs1( v , u ) );
dfn[++ en] = u , dep[en] = d[u];
}
if( re - d[u] >= blo && d[u] % blo == 0 ) ps[++ ee] = u , im[u] = ee;
return re ? re : d[u];
}
int lg[MAXN << 1];
int st[MAXN << 1][17];
void ST() {
for( int i = 1 ; i <= en ; ++ i ) lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for( int i = 1 ; i <= en ; ++ i ) st[i][0] = i;
for( int i = 1 ; ( 1 << i ) <= en ; ++ i )
for (int j = 1; j + (1 << i) - 1 <= en; j++)
st[j][i] = dep[st[j][i - 1]] < dep[st[j + (1 << (i - 1))][i - 1]] ? st[j][i - 1] : st[j + (1 << (i - 1))][i - 1];
}
int lca( int l , int r ) {
l = pos[l] , r = pos[r];
if( l > r ) swap( l , r );
int k = lg[r - l + 1] - 1;
return dep[st[l][k]] <= dep[st[r-(1<<k)+1][k]] ? dfn[st[l][k]] : dfn[st[r-(1<<k)+1][k]];
}
int col[MAXN] , cn , l;
int jump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
if( u == l ) return -1;
return jump( par[u] );
}
int rejjump( int u ) {
if( im[u] ) return im[u];
col[++ cn] = c[u];
return rejjump( par[u] );
}
int oc[MAXN];
int main() {
// freopen("10.in","r",stdin);
// freopen("ot","w",stdout);
cin >> n >> m;
blo = sqrt( n );
for( int i = 1 ; i <= n ; ++ i ) scanf("%d",&c[i]) , C[i] = c[i];
sort( C + 1 , C + 1 + n ); sz = unique( C + 1 , C + 1 + n ) - C - 1;
bl = sqrt( sz );
for( int i = 1 ; i <= n ; ++ i ) c[i] = lower_bound( C + 1 , C + 1 + sz , c[i] ) - C;
for( int i = 1 , u , v ; i < n ; ++ i ) {
scanf("%d%d",&u,&v);
ade( u , v ) , ade( v , u );
}
build( );
dfs1( 1 , 0 );
// cout << que( rt[3] , 3 ) << endl;
ST( );
for( int i = 1 ; i <= blo ; ++ i ) {
int p = ps[i]; cid = i;
dfs( p , p );
}
int u , v , psu , psv , flg = 0 , re , last = 0 , t , pr;
while( m-- ) {
scanf("%d%d",&u,&v);
u ^= last;
// cout << u << ' ' << v << endl;
l = lca( u , v );
cn = re = 0;
psu = jump( u ) , psv = jump( v );
if( psu == psv ) {
oc[c[l]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) oc[col[i]] = 1 , ++ re;
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
oc[c[l]] = 0;
} else if( ~psu && ~psv ) {
re = pre[psu][psv];
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psu]] , col[i] ) + que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) * 2 == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
} else {
if( ~psu ) swap( u , v ) , swap( psu , psv );
pr = l; t = cn;
while( !im[pr] ) {
pr = par[pr];
if( !oc[c[pr]] ) {
oc[c[pr]] = 1;
if (que(rt[ps[psv]], c[pr]) - que(rt[par[l]], c[pr]) == 0)
--re;
col[++ cn] = c[pr];
}
}
re += pre[im[pr]][psv];
for( int i = t + 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
cn = t;
for( int i = 1 ; i <= cn ; ++ i ) if( !oc[col[i]] ) {
oc[col[i]] = 1;
if( que( rt[ps[psv]] , col[i] ) - que( rt[par[l]] , col[i] ) == 0 )
++ re;
}
for( int i = 1 ; i <= cn ; ++ i ) oc[col[i]] = 0;
}
printf("%d\n",last = re);
}
}

文章作者: yijan
文章链接: https://yijan.co/old60/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yijan's Blog