51nod 1709 复杂度分析

51nod 1709 复杂度分析

考虑定义$F(x)$为$x$为根的子树所有点与$x$的深度差(其实就是$x$到每个子树内点的距离)的 1 的个数和。

注意,$F(x)$的值不是答案,但是只需要一点树形dp的基础内容就可以变成要求的答案。

对于一个点$u$, 考虑它的一个儿子$v$, 我们此时已经计算出了$F( v )$的值那么怎么统计$v$中所有点对于$u$的贡献呢?首先考虑$F(v)$ 的变化,由于当前的点$u$是$v$的父亲,$v$中所有点到$u$的距离实际上是原来到$v$的路径长度 + 1。那么二进制中1的个数加了多少呢?

对于一个$v$子树中点$k$,假设它到$v$的距离是$d$,则:

  • 如果$d \equiv 0 \pmod 2$那么显然二进制1的个数直接+1
  • 如果$d \equiv 1 \pmod 2$那么二进制中1的个数 不变
  • 如果$d \equiv 3 \pmod {2^2}$那么二进制中1的个数 少1
  • 如果$d \equiv 7 \pmod {2^3}$那么二进制中1的个数 少1

那么就有了一个思路,把$v$子树中与$v$距离$d \equiv {2^k - 1} \pmod {2^k}$的点的个数存着,这个可以倍增预处理。

那么对于$F$我们就会转移了,先+上子树的size,然后减去$v$子树中$2^k - 1$距离的点的个数。

转移了$F$后,直接给$v$中的$F$乘上$size(u) - size(v)$(这个是显然的树形dp了)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAXN 100006
typedef long long ll;
int n;

int read( ) {
int ret = 0; char ch = ' ';
while( ch < '0' || ch > '9' ) ch = getchar();
while( ch >= '0' && ch <= '9' ) ret *= 10 , ret += ch - '0' , ch = getchar();
return ret;
}

int head[MAXN] , to[MAXN << 1] , nex[MAXN << 1] , ecn = 0;
void ade( int u , int v ) {
to[++ecn] = v , nex[ecn] = head[u] , head[u] = ecn;
}

int G[MAXN][18] , GG[MAXN][18]; ll t[MAXN][18] ; // G 2^k , GG 2^{k - 1} , t how many nodes at dep % 2^k = 2^k - 1
int siz[MAXN];
void dfs( int u , int fa ) {
siz[u] = 1;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
G[v][0] = u , GG[v][0] = v;
for( int k = 1 ; k < 18 ; ++ k ) {
if( G[G[v][k-1]][k-1] )
G[v][k] = G[G[v][k-1]][k-1];
if( G[GG[v][k-1]][k-1] )
GG[v][k] = G[GG[v][k-1]][k-1];
else break;
}
dfs( v , u );
siz[u] += siz[v];
}
for( int k = 1 ; k < 18 ; ++ k ) {
if( G[u][k] )
t[G[u][k]][k] += t[u][k];
if( GG[u][k] )
++ t[GG[u][k]][k];
else break;
}
}
ll res = 0;
ll T[MAXN];
ll solve( int u , int fa ) {
ll R = 0 , ret = 0;
for( int i = head[u] ; i ; i = nex[i] ) {
int v = to[i];
if( v == fa ) continue;
R = 0;
ll lst = solve( v , u );
R += lst + siz[v];
R -= T[v];
res += R * ( siz[u] - siz[v] );
ret += R;
}
return ret;
}

signed main( ) {
n = read();
for( int i = 1 , u , v ; i < n ; ++ i ) {
u = read() , v = read();
ade( u , v ) , ade( v , u );
}
dfs( 1 , 1 );
for( int i = 1 ; i <= n ; ++ i )
for( int k = 1 ; k < 18 ; ++ k )
T[i] += t[i][k];
solve( 1 , 1 );
printf("%lld",res);
}
文章作者: yijan
文章链接: https://yijan.co/old19/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Yijan's Blog