胡闹 序列

题目大意

给定一个长度为$n\le10^6$的序列$x$。

你需要从序列中选出一些位置。对于第$i$个位置,如果它被选中,你会获得$x_i$的收益;如果它没被选中,找到最小的$j$使得第$j$个位置到第$i$个位置都没有被选中,你需要付出$i−j+1$的代价。

此外,你选出的位置必须满足$x_i$是单调不下降的。

最大化收益减去代价的结果。

解析

可以先考虑一下$x_i$单调不降的情况。

设$dp_i$表示最后一个选的数是第$i$个的最佳答案,则:
$$
dp_i=x_i+max{dp_j+\frac{(i-j)\times(i-j-1)}{2}}
$$
这个东西是一个非常板子的斜率优化对吧。

然后怎么考虑$x_i$无序的情况呢?

我们发现,我们采用的顺序是可以换的。

那么我们对$x_i$进行分治,就是把$x_i$分成$(l,r)$的区间分别计算答案,然后归并一下就好了。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
//for test ? why
#define REP(i, s, e) for (int i = s; i <= e ; i++)
#include <stdio.h>
#include <algorithm>

typedef long long ll;

namespace io {
const int SIZE = (1 << 21) + 1;
char ibuf[SIZE], *iS, *iT, obuf[SIZE], *oS = obuf, *oT = oS + SIZE - 1, c, qu[55]; int f, qr;
// getchar
#define gc() (iS == iT ? (iT = (iS = ibuf) + fread (ibuf, 1, SIZE, stdin), (iS == iT ? EOF : *iS ++)) : *iS ++)
// print the remaining part
inline void flush () {
fwrite (obuf, 1, oS - obuf, stdout);
oS = obuf;
}
// putchar
inline void putc (char x) {
*oS ++ = x;
if (oS == oT) flush ();
}
// input a signed integer
template <class I>
inline void gi (I &x) {
for (f = 1, c = gc(); c < '0' || c > '9'; c = gc()) if (c == '-') f = -1;
for (x = 0; c <= '9' && c >= '0'; c = gc()) x = x * 10 + (c & 15); x *= f;
}
// print a signed integer
template <class I>
inline void print (I &x) {
if (!x) putc ('0'); if (x < 0) putc ('-'), x = -x;
while (x) qu[++ qr] = x % 10 + '0', x /= 10;
while (qr) putc (qu[qr --]);
}
}
using io :: gi;
using io :: putc;
using io :: print;

const int N = 1000005, G = 21;

int n, a[N], id[N], rid[N], qu[N], ql, qr, pos[G][N], m;
ll dp[N], ret;

bool check (int x, int y, int z) { return (dp[z] - dp[y]) * (y - x) >= (dp[y] - dp[x]) * (z - y); }

void mergesort (int l, int r, int d) {
if (l == r) return (void)(pos[d][l] = id[l]);
int mi = (l + r) >> 1, i = l, j = mi + 1;
mergesort (l, mi, d + 1); mergesort (mi + 1, r, d + 1);
int *x = pos[d], *y = pos[d + 1], p = l;
while (p <= r) if (i <= mi && (y[i] < y[j] || j > r)) x[p ++] = y[i ++]; else x[p ++] = y[j ++];
}

void solve (int l, int r, int d) {
if (l == r) {
int u = id[l];
dp[u] += a[u] - (ll)u * u;
return ;
}
int mi = (l + r) >> 1, i;
solve (l, mi, d + 1);
for (ql = 1, qr = 0, i = l; i <= r; i ++) {
int u = pos[d][i];
if (rid[u] <= mi) {
while (ql < qr && check (qu[qr - 1], qu[qr], u)) qr --;
qu[++ qr] = u;
} else {
while (ql < qr && dp[qu[ql]] + (ll)u * qu[ql] <= dp[qu[ql + 1]] + (ll)u * qu[ql + 1]) ql ++;
if (ql <= qr) dp[u] = std :: max (dp[u], dp[qu[ql]] + (ll)u * qu[ql]);
}
}
solve (mi + 1, r, d + 1);
}

int main () {
int i;
for (gi (n), i = 1; i <= n; i ++) gi (a[i]), id[i] = i, dp[i] = 0;
std :: sort (id + 1, id + n + 1, [&] (const int &x, const int &y) { return a[x] == a[y] ? x < y : a[x] < a[y]; });
for (i = 1; i <= n; i ++) rid[id[i]] = i;
;mergesort (1, n, 0); solve (1, n, 0);
for (ret = -(ll)n * (n + 1) / 2, i = 1; i <= n; i ++) ret = std :: max (ret, dp[i] + (ll)i * (i + 1) / 2 - (ll)(n - i) * (n - i + 1) / 2);
printf ("%lld\n", ret);
return 0;
}

评论

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×