// Minimum total "equal-pair" cost of splitting a sequence into k segments.
//
// Input : n k, followed by n integers A[1..n].
// Output: the minimum, over all ways to cut A[1..n] into exactly k non-empty
//         contiguous segments, of the sum over segments of the number of
//         index pairs (i < j) inside a segment with A[i] == A[j].
//         If no such split exists (n < k) the INF sentinel is printed.
//
// Layered DP:  dp_t[j] = min_{1 <= i <= j} dp_{t-1}[i-1] + cost(i, j),
// evaluated with the divide-and-conquer optimisation: the (largest) optimal
// start of the last segment is non-decreasing in j because the pair-count
// cost satisfies the quadrangle inequality.  cost(i, j) is maintained by a
// window whose two ends move one element at a time (as in Mo's algorithm).
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>

namespace {

constexpr int MAXN = 100006;
// Sentinel for "no valid split" (same bit pattern as memset with 0x3f).
// It dominates any real cost (< n*(n-1)/2 < 5e9 per segment set), and
// INF + cost still fits in a long long, so candidate sums cannot overflow.
constexpr long long INF = 0x3f3f3f3f3f3f3f3fLL;

int A[MAXN];    // input sequence, 1-indexed; values must lie in [0, MAXN)
int cnt[MAXN];  // occurrences of each value inside the current window

// The window currently covers A[winL..winR]; it starts empty.
int winL = 1;
int winR = 0;
long long winCost = 0;  // number of equal-value pairs inside the window

long long dp[MAXN];   // previous layer: best cost for each prefix A[1..j]
long long ndp[MAXN];  // layer currently being computed

// Adds A[pos] to the window: it pairs with every equal element already there.
void addAt(int pos) { winCost += cnt[A[pos]]++; }

// Removes A[pos] from the window, dropping the pairs it formed.
void removeAt(int pos) { winCost -= --cnt[A[pos]]; }

// Returns cost(l, r) (requires l <= r) by sliding the window to A[l..r].
// Expanding before shrinking keeps the window a valid range at all times.
long long windowCost(int l, int r) {
    while (winL > l) addAt(--winL);
    while (winR < r) addAt(++winR);
    while (winL < l) removeAt(winL++);
    while (winR > r) removeAt(winR--);
    return winCost;
}

// Fills ndp[lo..hi] for the current layer, given that for every position in
// [lo, hi] the optimal start of the last segment lies in [optLo, optHi].
// Invariant (holds for the initial call and is preserved below):
// optLo <= lo and optLo <= optHi, so the candidate range is never empty.
void solve(int lo, int hi, int optLo, int optHi) {
    if (lo > hi) return;
    const int mid = lo + (hi - lo) / 2;
    const int last = std::min(optHi, mid);

    long long best = INF;
    // Fallback when every candidate is >= INF (prefix too short for this
    // layer): all starts <= last are then infeasible for larger positions as
    // well, so restricting the right half to [last, optHi] loses nothing.
    int bestStart = last;
    // Scan downwards with strict '<' so ties keep the largest start index,
    // which is the argmin whose monotonicity the recursion relies on.
    for (int i = last; i >= optLo; --i) {
        const long long candidate = dp[i - 1] + windowCost(i, mid);
        if (candidate < best) {
            best = candidate;
            bestStart = i;
        }
    }
    ndp[mid] = best;

    solve(lo, mid - 1, optLo, bestStart);
    solve(mid + 1, hi, bestStart, optHi);
}

}  // namespace

int main() {
    int n = 0;
    int k = 0;
    if (std::scanf("%d %d", &n, &k) != 2 || n < 1 || n >= MAXN) return 1;
    for (int i = 1; i <= n; ++i) {
        // Values index cnt[], so reject anything that would write out of bounds.
        if (std::scanf("%d", &A[i]) != 1 || A[i] < 0 || A[i] >= MAXN) return 1;
    }

    // Layer 0: only the empty prefix is achievable (with zero segments).
    std::fill(dp, dp + n + 1, INF);
    dp[0] = 0;

    for (int layer = 1; layer <= k; ++layer) {
        // An empty prefix cannot hold layer >= 1 non-empty segments.
        ndp[0] = INF;
        solve(1, n, 1, n);
        std::swap(dp, ndp);
    }

    std::printf("%lld\n", dp[n]);
    return 0;
}