forked from EndlessCheng/codeforces-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtreap_kthsum.go
160 lines (143 loc) · 3.01 KB
/
treap_kthsum.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
package copypasta
import "time"
// Treap that maintains the sum of the k smallest elements (see treapSum.kth);
// elements can be added and removed.
//
// Problems:
//   https://leetcode.cn/problems/divide-an-array-into-subarrays-with-minimum-cost-ii/
//   https://atcoder.jp/contests/abc306/tasks/abc306_e
//   https://atcoder.jp/contests/abc287/tasks/abc287_g
//
// nodeSum is one treap node. It stores a distinct key together with its
// multiplicity keyCnt (which can drop to 0 after removals) and keySum, the
// cached product key*keyCnt. subSize and subSum aggregate the multiplicities
// and the element sum over the whole subtree rooted at this node.
type nodeSum struct {
	lr       [2]*nodeSum // lr[0]: left child (cmp == 0), lr[1]: right child (cmp == 1)
	priority uint        // random priority; rotations keep parents above children

	key, keyCnt, subSize, keySum, subSum int
}
// cmp locates key a relative to o.key: it returns -1 when a equals o.key,
// 0 when a is smaller (descend into lr[0]) and 1 when a is larger (lr[1]).
func (o *nodeSum) cmp(a int) int {
	switch {
	case a == o.key:
		return -1
	case a < o.key:
		return 0
	default:
		return 1
	}
}
// getSize returns the number of elements (counting multiplicity) in the
// subtree rooted at o; a nil subtree is empty.
func (o *nodeSum) getSize() int {
	if o == nil {
		return 0
	}
	return o.subSize
}
// getSum returns the sum of all elements in the subtree rooted at o;
// a nil subtree sums to 0.
func (o *nodeSum) getSum() int {
	if o == nil {
		return 0
	}
	return o.subSum
}
// maintain recomputes o's subtree aggregates (subSize, subSum) from o's own
// keyCnt/keySum and its children's aggregates, which must already be current.
func (o *nodeSum) maintain() {
	left, right := o.lr[0], o.lr[1]
	o.subSize = left.getSize() + right.getSize() + o.keyCnt
	o.subSum = left.getSum() + right.getSum() + o.keySum
}
// rotate performs a single rotation at o and returns the new subtree root.
// The child on side d^1 (the pivot) takes o's place and o becomes the pivot's
// d-side child: d=0 lifts the right child (left rotation), d=1 lifts the left
// child (right rotation). o is re-aggregated first because it is now below
// the pivot.
func (o *nodeSum) rotate(d int) *nodeSum {
	pivot := o.lr[d^1]
	o.lr[d^1], pivot.lr[d] = pivot.lr[d], o
	o.maintain()
	pivot.maintain()
	return pivot
}
// treapSum is the treap itself: root is the top node and rd is the xorshift
// state that fastRand advances to draw node priorities.
type treapSum struct {
	rd   uint     // PRNG state; newTreapSum seeds it with a nonzero value
	root *nodeSum // nil until the first put
}
// fastRand advances the xorshift state (shift triple 13, 17, 5) by one step
// and returns the new state, which is used as a node priority.
// A zero state would stay zero forever, hence the nonzero seed in newTreapSum.
func (t *treapSum) fastRand() uint {
	x := t.rd
	x ^= x << 13
	x ^= x >> 17
	x ^= x << 5
	t.rd = x
	return x
}
// _put adds num copies of key to the subtree rooted at o and returns the
// subtree's new root. num may be negative to remove copies.
// A key with no node gets a fresh node with a random priority. After
// recursing, a child whose priority exceeds its parent's is rotated above it,
// so parents keep the higher priority.
func (t *treapSum) _put(o *nodeSum, key, num int) *nodeSum {
	if o == nil {
		o = &nodeSum{priority: t.fastRand(), key: key, keyCnt: num, keySum: key * num}
	} else if side := o.cmp(key); side < 0 {
		// The key already has a node: only its multiplicity changes. The node
		// stays in the tree even if keyCnt reaches 0.
		o.keyCnt += num
		o.keySum += key * num
	} else {
		child := t._put(o.lr[side], key, num)
		o.lr[side] = child
		if child.priority > o.priority {
			// Lift the child into o's position.
			o = o.rotate(side ^ 1)
		}
	}
	o.maintain()
	return o
}
// put changes the multiplicity of key by num:
// num=1 adds one copy of key, num=-1 removes one copy.
// NOTE: a key whose count drops to 0 keeps its node in the tree, and removing
// a key that is not present inserts a node with a negative count, so callers
// should only remove keys they previously added.
func (t *treapSum) put(key, num int) {
	updated := t._put(t.root, key, num)
	t.root = updated
}
// newTreapSum returns an empty treap whose priority generator is seeded from
// the current time. The seed is x/2+1, which is never 0 and cannot overflow,
// so the xorshift state never gets stuck at zero.
func newTreapSum() *treapSum {
	seed := uint(time.Now().UnixNano())/2 + 1
	return &treapSum{rd: seed}
}
// cntSum returns how many elements are <= size (counting multiplicity) and
// the sum of those elements.
// LC3245 https://leetcode.cn/problems/alternating-groups-iii/
func (t *treapSum) cntSum(size int) (cnt, sum int) {
	node := t.root
	for node != nil {
		side := node.cmp(size)
		if side != 0 {
			// node.key <= size: the left subtree and node's own copies all qualify.
			cnt += node.keyCnt + node.lr[0].getSize()
			sum += node.keySum + node.lr[0].getSum()
			if side < 0 {
				break // node.key == size: nothing further right can qualify
			}
		}
		// side is 0 (size < node.key, go left) or 1 (go right).
		node = node.lr[side]
	}
	return cnt, sum
}
// kth returns the sum of the k smallest elements, counted with multiplicity
// (k is 1-based: kth(1) is the minimum, and kth(0) is 0).
// It panics with -1 when k exceeds the number of elements.
func (t *treapSum) kth(k int) (sum int) {
	if t.root.getSize() < k {
		panic(-1)
	}
	o := t.root
	for o != nil {
		left := o.lr[0]
		leftSize := left.getSize()
		if k < leftSize {
			// All k elements lie inside the left subtree.
			o = left
			continue
		}
		// Consume the whole left subtree.
		sum += left.getSum()
		k -= leftSize
		if k <= o.keyCnt {
			// The remaining k elements are copies of o.key.
			return sum + o.key*k
		}
		// Consume every copy of o.key and continue to the right.
		sum += o.keySum
		k -= o.keyCnt
		o = o.lr[1]
	}
	return sum
}
// rank returns the minimum number of elements, taken from the largest down,
// whose sum is at least need. need <= 0 requires no elements and yields 0.
// It panics with -1 when even the sum of all elements is below need.
//
// Requires the treap to be ordered largest-first: cmp must be changed to
// compare with > so that larger keys go to lr[0].
// NOTE: also assumes all keys are positive. The descent relies on the sum
// growing as elements are taken, and the ceiling division needs key > 0.
// Adapted from a more roundabout solution to
// https://codeforces.com/contest/1978/problem/D
func (t *treapSum) rank(need int) (cnt int) {
	if need <= 0 {
		return 0
	}
	for o := t.root; o != nil; {
		leftSum := o.lr[0].getSum()
		if leftSum >= need {
			// The larger elements in the left subtree already suffice.
			o = o.lr[0]
			continue
		}
		// Take every element of the left subtree.
		need -= leftSum
		cnt += o.lr[0].getSize()
		own := o.keyCnt * o.key
		if own >= need {
			// Some copies of o.key finish the job: ceil(need / key) of them.
			cnt += (need + o.key - 1) / o.key
			return
		}
		// Take all copies of o.key (they must be counted too), then continue
		// with the smaller elements on the right.
		need -= own
		cnt += o.keyCnt
		o = o.lr[1]
	}
	panic(-1)
}