5.25 训练

三道 Petrozavodsk 和 Open Cup

Different Summands Counng

链接

反套路题。介绍一下 O(m3logn)O(m^3\log n) 的卡常做法和 O(m3)O(m^3) 的标算。

我们考虑对于一个数,算出它恰好没有出现过的方案数,然后用总方案减去这个就是出现过的方案数量。前面这个可以方便地容斥成钦定,而如果是钦定且这题是有序拆分就可以用插板来安排剩下的数。

a=1n((n1m1)k=0min(m,n/x)(1)k(mk)(nka1mk1))=a=1n(k=1min(m,n/x)(1)k+1(mk)(nka1mk1))=k=1m(1)k+1(mk)a=1n/m(nka1mk1)\sum_{a=1}^n\left(\binom{n-1}{m-1} - \sum_{k = 0}^{\min(m,n / x)} (-1)^k \binom{m}{k} \binom{n-ka-1}{m-k-1}\right)\\ = \sum_{a=1}^n\left(\sum_{k = 1}^{\min(m,n / x)} (-1)^{k+1} \binom{m}{k} \binom{n-ka-1}{m-k-1}\right)\\ = \sum_{k = 1}^{m}(-1)^{k+1} \binom{m}{k} \sum_{a=1}^{n/m}\binom{n-ka-1}{m-k-1}\\

这个东西看起来比较难算。我们考虑第一层直接枚举,怎么计算后面这个东西。

我们倍增计算出所有的 (nak1j)\binom{n-ak-1}{j} 。具体来说,我们先计算出

F(j)=a=0n/2(nak1j)F(j) = \sum_{a=0}^{n/2} \binom{n-ak-1}{j}

然后考虑怎么去计算前半部分。

由于一个组合恒等式,有

(nak1j)=t(nak1nk2jt)(nk2t)\binom{n-ak-1}{j} = \sum_t \binom{n-ak-1-\frac{nk}{2}}{j-t} \binom{\frac{nk}{2}}{t}

所以可以 O(m2)O(m^2) 转移出所有的 F(j)F(j) 。复杂度 O(m3logn)O(m^3 \log n) ,转移时可以 NTT 做到 O(m2logmlogn)O(m^2\log m\log n),写着很难受而且很不优秀。

标算做法非常反套路。具体来说,我们把 aa 看作 xx ,于是会发现 (nkx1mk1)\binom{n-kx-1}{m-k-1} 这个玩意是个 O(m)O(m) 次项的多项式。所以说我们直接把多项式暴力算出来(这部分可以暴力或者 NTT),然后去求一个自然数幂和来算这个多项式的 F(a)\sum F(a) 。这样复杂度仅仅是 O(m3)O(m^3) 或者 O(m2logm)O(m^2\log m) ,非常优秀。

第一种做法是当 nn 很大的时候比较容易想到的套路做法,也就是去分治,第二种是一种优秀但是比较特殊的做法。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 1006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 998244353;
ll n;
int m;

int Pow( int x , int a ) {
	int ret = 1;
	while( a ) {
		if( a & 1 ) ret = ret * 1ll * x % P;
		x = x * 1ll * x % P , a >>= 1;
	}
	return ret;
}

int J[MAXN] , iJ[MAXN] , iv[MAXN];
int C( int a , int b ) {
	if( b > a || b < 0 || a < 0 ) return 0;
	return J[a] * 1ll * iJ[b] % P * iJ[a - b] % P;
}
int CC( ll a , int b ) {
	if( a < b || b < 0 || a < 0 ) return 0;
	int re = 1;
	per( i , a % P , a % P - b + 1 )
		re = re * 1ll * ( i + P ) % P;
	re = re * 1ll * iJ[b] % P;
	return re;
}

void Inc( int& x , int y ) {
	x = ( x + y < P ? x + y : x + y - P );
}

int S[506] , i , gc[506];
int gs[506];
ll tot;
int solve( ll x ) {
	gs[0] = x % P;
	rep( k , 1 , m - i - 1 ) {
		gs[k] = Pow( ( x + 1 ) % P , k + 1 );
		rep( j , 0 , k - 1 )
			gs[k] = ( gs[k] + P - C( k + 1 , j ) * 1ll * gs[j] % P ) % P;
		gs[k] = ( gs[k] + P - 1 ) * 1ll * iv[k + 1] % P;
	}
	S[0] = 1;
	int sz = 0;
	rep( j , 0 , m - i - 2 ) {
		per( t , sz + 1 , 1 )
			S[t] = ( S[t] * 1ll * ( n % P + P - 1 - j ) + S[t - 1] * 1ll * ( P - i ) ) % P;
		S[0] = S[0] * 1ll * ( n % P + P - 1 - j ) % P;
		++ sz;
	}
	int re = 0;
	rep( t , 0 , m - i - 1 )
		re = ( re + S[t] * 1ll * gs[t] ) % P;
	re = re * 1ll * iJ[m - i - 1] % P;
	rep( j , 0 , m - i - 1 ) S[j] = 0;
	return re;
}

void solve() {
	cin >> n >> m;
	J[0] = iJ[0] = 1;
	rep( i , 1 , 506 ) J[i] = J[i - 1] * 1ll * i % P , iv[i] = Pow( i , P - 2 ) , iJ[i] = iJ[i - 1] * 1ll * iv[i] % P;
	int as = 0;
//	tot = n - 1;
//	solve( n - 1 );
//	rep( i , 0 , m ) cout << S[i] << ' '; puts("");
	for( i = 1 ; i < m ; ++ i ) {
		rep( c , 0 , m ) S[c] = 0;
		tot = ( n - m + i ) / i;
		int re = solve( tot );
		int w = C( m , i ) * 1ll * re % P;
		if( i & 1 ) as = ( as + w ) % P;
		else as = ( as + P - w ) % P;
	}
	if( n % m == 0 ) {
		if( m & 1 ) Inc( as , 1 );
		else Inc( as , P - 1 );
	}
	cout << as << endl;
}

signed main() {
//	freopen("in.in","r",stdin);
//	freopen("ot","w",stdout);
//    int T;cin >> T;while( T-- ) solve();
	solve();
}

B Process with Constant Sum

链接

我可能比较擅长这种题。

我们来手玩一下样例以及手捏的数据,会发现最后一段区间操作完后一定会成为一段 00 再拼一段 11 再拼一个 xx 。具体来说,我们可以这样操作:首先对每个数 mod2\bmod 2 ,把所有 2\ge 2 的数用操作把 xxmod2x - x\bmod 2 送到结尾。然后找到倒数第二段 11 ,如果它和最后一段之间隔了奇数个 00 ,会发现你可以使用操作把最后两段互相削除掉。如果它和最后一段之间隔了偶数个00 ,会发现可以使用操作把这两段拼起来(这两部分需要仔细想想,有点细节)。而且会发现先拼 / 削任意两段是不会有区别的,也就是操作顺序其实不影响。

然后就是一些套路操作了,因为顺序不影响,我们把一个区间会得到的最终状态放到线段树上,在一个区间维护前面 00 的长度,中间 11 的长度,最后 xx 是多少。合并就是上面说的,需要想一想,细节可以参考代码。修改就是正常的线段树单点修改。

复杂度 O(nlogn)O(n\log n)

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
#define MAXN 1000006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
int n , q , k;
int A[MAXN];

struct tcurts {
	int r , t , le;
	tcurts( ) : r( 0 ) , t( 0 ) , le( 0 ) {}
	tcurts( int r , int t , int le ) : r(r) , t(t) , le(le) {}
} T[MAXN << 2] ;

tcurts merge( tcurts ls , tcurts rs ) {
	if( !ls.le ) return rs;
	int tl = ls.t , tr = rs.t , sl = ls.r , sr = rs.r , cr = rs.le - rs.t - 1;
	tcurts as;
	as.le = ls.le + rs.le;
	as.r = sr + sl / 2 * 2;
	if( sl & 1 ) ++ tl; else ++ cr;
	if( cr & 1 ) {
		if( tl <= tr ) as.r += tl * 2 , as.t = tr - tl;
		else as.r += tr * 2 + 1 , as.t = tl - tr - 1;
	} else {
		as.t = tr + tl;
	}
	return as;
}

void pu( int rt ) {
	T[rt] = merge( T[rt << 1] , T[rt << 1 | 1] );
}
void build( int rt , int l , int r ) {
	if( l == r ) { T[rt] = tcurts( A[l] , 0 , 1 ); return; }
	int m = l + r >> 1;
	build( rt << 1 , l , m ) , build( rt << 1 | 1 , m + 1 , r );
	pu( rt );
}

void mdfy( int rt , int l , int r , int p , int c ) {
	if( l == r ) { T[rt] = tcurts( c , 0 , 1 ); return; }
	int m = l + r >> 1;
	if( p <= m ) mdfy( rt << 1 , l , m , p , c );
	else mdfy( rt << 1 | 1 , m + 1 , r , p , c );
	pu( rt );
}

tcurts que( int rt , int l , int r , int L , int R ) {
	if( L <= l && R >= r ) return T[rt];
	int m = l + r >> 1;
	tcurts as;
	if( L <= m ) as = que( rt << 1 , l , m , L , R );
	if( R > m ) as = merge( as , que( rt << 1 | 1 , m + 1 , r , L , R ) );
	return as;
}

void solve() {
	cin >> n;
	rep( i , 1 , n ) scanf("%d",A + i);
	build( 1 , 1 , n );
	cin >> q;
	rep( i , 1 , q ) {
		int op , l , r;
		scanf("%d%d%d",&op,&l,&r);
		if( op == 1 ) {
			mdfy( 1 , 1 , n , l , r );
		} else {
			tcurts pr = que( 1 , 1 , n , l , r );
			printf("%d\n",pr.le - pr.t - ( !!pr.r ));
		}
	}
}

signed main() {
//	freopen("in.in","r",stdin);
//	freopen("ot","w",stdout);
//    int T;cin >> T;while( T-- ) solve();
	solve();
}

C Interest

神仙题

链接

会发现这是以前 DLS round 的第三题,AUOJ 459。

我们考虑 O(n2)O(n^2) 暴力怎么做。可以看出选出两条边不相交路径可以直接看成一个网络流问题,也就是求出 11 到点 uu 的大小为 22 的最小费用流,且每条边容量为 11 ,然后就可以类似环游那个题真正的跑 22 次单路增广。

首先第一次流其实可以非常简单的得出,也就是当我们建出最短路树后,第一次流一定是从 11 在最短路树上流到 uu ,可以一次 dijkstra 实现。我们考虑能否通过一次 dijkstra 做完第二次流。但是在流了一次并反向后会得到负权边,所以我们可以对最开始的图做一次 primal dual ,也就是设到每个点最短路为 dud_u ,那么把 uvu \to v 边权为 ww 的边的边权设置为 w=wdv+duw' = w - d_v + d_u 。不难发现这样操作后最短路树上的边一定全部变成了 00 ,并且对于一条 1k1 \to k 的路径,实际长度一定正好是 dk+disd_k + dis 。并且此时,如果我们对 1u1\to u 的路径全部反向,所有边的长度仍然是 0\ge 0 的!

于是我们考虑做一个 dijkstra 的过程,也就是我们设 dud_u' 为在 1u1 \to u 的边被反向时,1u1 \to u 的最短路。然后我们仍然考虑类 dijkstra 的转移,也就是每次拿出堆顶来更新其他点。考虑当前堆顶为 uu ,那么转移有两种。首先,这个点自己连出去的边肯定可以转移,对于一条非树边 (u,v,w)(u,v,w) ,一定可以用 du+wd_u + w 来更新 dvd_v

然后考虑还可以用 dud_u 更新什么,会发现,其实 1u1 \to u 的边被反向后等价于当前的根换成了 uu ,我们考虑对于一对其他的边 (a,b,w)(a,b,w) ,我们可以用 du+wd_u + w 来更新 bb 当且仅当 a,ba,b 不在 uu 的同一棵子树。因为如果把 1b1 \to b 的路径反向其实等价于在这个图上把 ubu \to b 的路径反向,然后如果 a,ba,b 不在同一个子树,那么把 1b1 \to b 反向后仍然可以从 uu00 的代价走到 aa ,否则一定走不到。所以可以这样更新。

但是直接这样更新需要枚举 uu 的所有子树,这是不优秀的。但是考虑 uu 的某个子树中如果存在某个点 vv ,而 dv<dud_v < d_u ,即 vv 已经作为堆顶过了,那么 vv 的子树一定是不用更新的。因为 vv 子树内部一定不能更新, vv 子树向外已经被 vv 更新过了。所以我们更新完一个点 uu 后其实可以直接把 uu 删去,也就是做一个类似点分的过程。

但是这种启发式点分也很容易被卡到 O(n2)O(n^2) 。我们考虑在更新一个点 uu 时,其实可以不用枚举这个点的最大的子树。我们可以对每个点存下这个点的所有出边即入边,然后枚举所有非最大的子树以及它们的出入边即可。我们可以证明这样每个点被枚举的次数是不超过 O(logn)O(\log n) 的,因为每次它作为轻子树出现,就说明这次点分之前它所在子树大小会翻倍。

然后还需要考虑最后一个问题,怎么判断哪个子树是最大的。考虑维护一个当前最大的子树 mxmx 。考虑用 O(min(szu,szv))O(\min (sz_u,sz_v)) 的复杂度比大小。具体来说,可以倍增一个上界 limlim ,然后在两个子树分别 dfs limlim 次,如果其中某一个没 dfs 到这么多次即可直接比较,否则倍增 limlim 。可以发现这样实际上只会用 O(min(sz))O(\min(sz)) 次。所以,每次枚举一个子树后如果它比当前子树大,那么我们去把当前最大子树给 dfs 一次枚举所有出入边,然后把最大的子树修改成当前子树,否则对当前子树做即可。

前面说明了,这样的复杂度是 O(mlogn)O(m\log n) 的。

#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "cmath"
#include "vector"
#include "map"
#include "set"
#include "queue"
using namespace std;
//#pragma GCC optimize(3)
#define MAXN 200006
//#define int long long
#define rep(i, a, b) for (int i = (a), i##end = (b); i <= i##end; ++i)
#define per(i, a, b) for (int i = (a), i##end = (b); i >= i##end; --i)
#define pii pair<int,int>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector<int>
#define all(x) (x).begin() , (x).end()
#define mem( a ) memset( a , 0 , sizeof a )
typedef long long ll;
const int P = 998244353;
int n , m;
vector<pair<int,ll> > G[MAXN] , F[MAXN];

ll dis[MAXN];
int vis[MAXN] , pre[MAXN];
priority_queue<pair<ll,int> > Q;
void dijk( ) {
	rep( i , 1 , n ) dis[i] = 1e18 , vis[i] = 0;
	Q.emplace( 0 , 1 );
	dis[1] = 0;
	while( !Q.empty() ) {
		int u = Q.top().se; Q.pop();
		if( vis[u] ) continue;
		vis[u] = 1;
		for( auto t : G[u] ) if( dis[t.fi] > dis[u] + t.se ) {
				int v = t.fi;
				dis[v] = dis[u] + t.se , Q.emplace( -dis[v] , v );
				pre[v] = u;
			}
	}
}

vi T[MAXN];
ll pr[MAXN];

int cnt = 0 , lim;
void dfs( int u , int f ) {
	++ cnt;
	for( int v : T[u] ) if( v != f && !vis[v] ) {
			if( cnt >= lim ) break;
			dfs( v , u );
		}
}

bool chk( int u , int v ) { // return siz[u] > siz[v]
	lim = 1;
	while( 233 ) {
		int a , b;
		dfs( u , 0 ) , a = cnt , cnt = 0;
		dfs( v , 0 ) , b = cnt , cnt = 0;
		if( a != lim || b != lim ) return a > b;
		lim <<= 1;
	}
}

int col[MAXN] , co;
void assign( int u , int f ) {
	col[u] = co;
	for( int v : T[u] ) if( v != f && !vis[v] )
			assign( v , u );
}

void wkr( int u , int f , ll D ) {
	for( auto t : G[u] ) {
		int v = t.fi;
		if( col[v] != col[u] && dis[v] > D + t.se )
			dis[v] = D + t.se , Q.emplace( -dis[v] , v );
	}
	for( auto t : F[u] ) {
		int v = t.fi;
		if( !vis[v] && col[v] != col[u] && dis[u] > D + t.se )
			dis[u] = D + t.se , Q.emplace( -dis[u] , u );
	}
	for( int v : T[u] ) if( v != f && !vis[v] ) wkr( v , u , D );
}

void solve() {
	cin >> n >> m;
	rep( i , 1 , m ) {
		int u , v , w;
		scanf("%d%d%d",&u,&v,&w);
		G[u].eb( v , w );
	}
	dijk( );
	rep( u , 1 , n ) {
		for( auto& t : G[u] ) {
			t.se -= dis[t.fi] - dis[u];
			F[t.fi].eb( u , t.se );
		}
		if( pre[u] ) T[pre[u]].pb( u ) , T[u].pb( pre[u] );
	}
	rep( i , 1 , n ) pr[i] = dis[i] , vis[i] = 0 , dis[i] = 1e18;
	dis[1] = 0;
	Q.emplace( 0 , 1 );
	co = 0;
	while( !Q.empty() ) {
		int u = Q.top().se; Q.pop();
		if( vis[u] ) continue;
		vis[u] = 1;
		for( auto t : G[u] ) {
			int v = t.fi;
			if( pre[v] == u || dis[v] <= dis[u] + t.se ) continue;
			dis[v] = dis[u] + t.se;
			Q.emplace( -dis[v] , v );
		}
		int hea = 0;
		for( int v : T[u] ) {
			if( vis[v] ) continue;
			if( !hea ) { hea = v; continue; }
			if( chk( v , hea ) )
				swap( hea , v );
			++ co , assign( v , u ) , wkr( v , u , dis[u] );
		}
	}
	rep( i , 2 , n ) printf("%lld\n",( dis[i] + pr[i] * 2 > 1e17 ? -1 : dis[i] + pr[i] * 2 ) );
	rep( i , 1 , n ) G[i].clear() , F[i].clear() , col[i] = 0 , T[i].clear();
}

signed main() {
//	freopen("color3_1.in","r",stdin);
//	freopen("ot","w",stdout);
//    int T;cin >> T;while( T-- ) solve();
	solve();
}
\