-
Notifications
You must be signed in to change notification settings - Fork 335
/
segment_tree.py
142 lines (107 loc) · 4.18 KB
/
segment_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# -*- coding: utf-8 -*-
"""Segment tree for Prioritized Replay Buffer."""
import operator
from typing import Callable
class SegmentTree:
    """Array-backed binary segment tree over a fixed number of leaves.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py

    The root is node 1, node ``i`` has children ``2 * i`` and ``2 * i + 1``,
    and leaf ``k`` is stored at ``tree[capacity + k]``. Every internal node
    holds ``operation(left_child, right_child)``.

    Attributes:
        capacity (int): number of leaves (a power of 2)
        tree (list): flat node storage of length ``2 * capacity``
        operation (function): binary function combining two node values

    """

    def __init__(self, capacity: int, operation: Callable, init_value: float):
        """Initialization.

        Args:
            capacity (int): number of leaves; must be a positive power of 2
            operation (function): binary function used to combine two values
            init_value (float): value every node starts with

        Raises:
            AssertionError: if capacity is not a positive power of 2

        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be positive and a power of 2."
        self.capacity = capacity
        self.tree = [init_value for _ in range(2 * capacity)]
        self.operation = operation

    def _operate_helper(
        self, start: int, end: int, node: int, node_start: int, node_end: int
    ) -> float:
        """Returns result of operation over the inclusive range [start, end].

        `node` is the tree index currently visited and
        [node_start, node_end] is the inclusive leaf range it covers. The
        caller must guarantee node_start <= start <= end <= node_end,
        otherwise the recursion never reaches a matching node.
        """
        if start == node_start and end == node_end:
            return self.tree[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Query lies entirely in the left child.
            return self._operate_helper(start, end, 2 * node, node_start, mid)
        if mid + 1 <= start:
            # Query lies entirely in the right child.
            return self._operate_helper(start, end, 2 * node + 1, mid + 1, node_end)
        # Query straddles the midpoint: combine both halves.
        return self.operation(
            self._operate_helper(start, mid, 2 * node, node_start, mid),
            self._operate_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end),
        )

    def operate(self, start: int = 0, end: int = 0) -> float:
        """Returns result of applying `self.operation` over leaves [start, end).

        Args:
            start (int): first leaf index included in the range
            end (int): leaf index one past the last one included. A value
                <= 0 has `capacity` added to it first, so the default 0
                means "through the last leaf" and -k leaves out the last
                k leaves.

        Raises:
            IndexError: if the resulting range is empty, inverted, or not
                contained in [0, capacity)

        """
        stop = end
        if stop <= 0:
            stop += self.capacity
        last = stop - 1  # the helper works with an inclusive upper bound
        # Without this check an invalid range makes the helper recurse
        # until the interpreter's recursion limit is hit.
        if not 0 <= start <= last < self.capacity:
            raise IndexError(
                "invalid range: start={}, end={} for capacity {}".format(
                    start, end, self.capacity
                )
            )
        return self._operate_helper(start, last, 1, 0, self.capacity - 1)

    def __setitem__(self, idx: int, val: float):
        """Set the value of leaf `idx` and refresh all of its ancestors.

        Raises:
            IndexError: if idx is outside [0, capacity). Rejecting negative
                indices matters: they would otherwise land on an internal
                node (or wrap around the list) and corrupt the tree
                silently.

        """
        if not 0 <= idx < self.capacity:
            raise IndexError(
                "index {} out of range for capacity {}".format(idx, self.capacity)
            )
        node = idx + self.capacity
        self.tree[node] = val
        node //= 2
        while node >= 1:
            self.tree[node] = self.operation(
                self.tree[2 * node], self.tree[2 * node + 1]
            )
            node //= 2

    def __getitem__(self, idx: int) -> float:
        """Get real value in leaf node of tree.

        Raises:
            AssertionError: if idx is outside [0, capacity)

        """
        assert 0 <= idx < self.capacity
        return self.tree[self.capacity + idx]
class SumSegmentTree(SegmentTree):
    """Segment tree whose combining operation is addition.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    """

    def __init__(self, capacity: int):
        """Initialization.

        Every node starts at 0.0 and nodes are combined with `operator.add`.

        Args:
            capacity (int): number of leaves, forwarded to SegmentTree
        """
        super().__init__(capacity=capacity, operation=operator.add, init_value=0.0)

    def sum(self, start: int = 0, end: int = 0) -> float:
        """Returns arr[start] + ... + arr[end - 1].

        Arguments are handed to `operate` unchanged, so `end` is exclusive
        and the default `end=0` sums through the last leaf.
        """
        return super().operate(start, end)

    def retrieve(self, upperbound: float) -> int:
        """Walk from the root to a leaf guided by `upperbound`.

        At each internal node the left child is taken when its stored sum is
        strictly greater than the remaining bound; otherwise that sum is
        subtracted from the bound and the right child is taken. The index of
        the leaf reached is returned.
        """
        # TODO: Check assert case and fix bug
        assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound)

        remaining = upperbound
        node = 1
        while node < self.capacity:  # still on an internal node
            left_child = 2 * node
            left_sum = self.tree[left_child]
            if left_sum > remaining:
                node = left_child
                continue
            # Skip the whole left subtree and account for its mass.
            remaining -= left_sum
            node = left_child + 1
        return node - self.capacity
class MinSegmentTree(SegmentTree):
    """Segment tree whose combining operation is the built-in `min`.

    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    """

    def __init__(self, capacity: int):
        """Initialization.

        Every node starts at positive infinity and nodes are combined
        with `min`.

        Args:
            capacity (int): number of leaves, forwarded to SegmentTree
        """
        largest = float("inf")
        super().__init__(capacity=capacity, operation=min, init_value=largest)

    def min(self, start: int = 0, end: int = 0) -> float:
        """Returns min(arr[start], ..., arr[end - 1]).

        Arguments are handed to `operate` unchanged, so `end` is exclusive
        and the default `end=0` covers through the last leaf.
        """
        return super().operate(start, end)