-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsegment-tree.py
95 lines (83 loc) · 2.42 KB
/
segment-tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy as np
import math
# Class definition
##################
class SegmentTree():
    """Array-backed segment tree answering range queries over a fixed sequence.

    The aggregate is defined entirely by ``query_fun``, which the tree calls
    in three ways:

    * ``query_fun(None)`` -- value used for a segment that does not
      intersect the queried range;
    * ``query_fun(node)`` -- result for a single stored node row;
    * ``query_fun(a, b)`` -- combination of two partial results.

    Node ``i`` keeps its children at ``2 * i + 1`` and ``2 * i + 2``; each
    node is one row of the integer array ``self.st``.
    """

    def __init__(self, array, query_fun):
        """Build the tree over ``array``.

        Args:
            array: Non-empty indexable sequence; values are stored in a
                NumPy array created with ``dtype=int``.
            query_fun: Aggregation callback (see the class docstring).

        Raises:
            ValueError: If ``array`` is empty.
        """
        self.query_fun = query_fun
        # Width of one node row: the length of the callback's "empty" result
        # when it is a sequence, otherwise 1. The handler is deliberately
        # broad (a user callback may fail in arbitrary ways when probed) but
        # no longer swallows KeyboardInterrupt / SystemExit.
        try:
            t = len(query_fun(None))
        except Exception:
            t = 1
        self.array = array
        self.n = len(array)
        if self.n == 0:
            raise ValueError("SegmentTree requires a non-empty array")
        self.n_nodes = 2 * self.n - 1
        # Exact integer equivalent of ceil(log2(n)) for n >= 1.
        self.height = (self.n - 1).bit_length()
        # Node count of a complete binary tree of that height.
        self.size = 2 ** (self.height + 1) - 1
        self.st = np.empty((self.size, t), dtype=int)
        self._build(0, self.n - 1, 0)

    def _center(self, sx, ex):
        """Return the midpoint index used to split the segment [sx, ex]."""
        return sx + (ex - sx) // 2

    def _build(self, sx, ex, i):
        """Fill node ``i`` for segment [sx, ex] and return its aggregate."""
        if sx == ex:
            # Leaf: the scalar is broadcast across every column of the row.
            self.st[i] = self.array[sx]
            return self.query_fun(self.st[i])
        c = self._center(sx, ex)
        left = self._build(sx, c, i * 2 + 1)
        right = self._build(c + 1, ex, i * 2 + 2)
        self.st[i] = self.query_fun(left, right)
        return self.st[i]

    def _query(self, l, r, sx, ex, i):
        """Aggregate the part of node ``i``'s segment [sx, ex] inside [l, r]."""
        if l <= sx and r >= ex:
            # Segment lies completely inside the queried range.
            return self.query_fun(self.st[i])
        if l > ex or r < sx:
            # No intersection with the queried range.
            return self.query_fun(None)
        c = self._center(sx, ex)
        left = self._query(l, r, sx, c, 2 * i + 1)
        right = self._query(l, r, c + 1, ex, 2 * i + 2)
        return self.query_fun(left, right)

    def query(self, l, r):
        """Return the aggregate over the inclusive index range [l, r]."""
        return self._query(l, r, 0, self.n - 1, 0)
# Query functions to run queries
################################
def minmax_fun(x, y=None):
    """Query function that tracks the minimum and the maximum together.

    Args:
        x: ``None`` for the empty case, otherwise a ``(min, max)`` pair
            (any object indexable at 0 and 1).
        y: Optional second ``(min, max)`` pair to combine with ``x``.

    Returns:
        A ``(min, max)`` tuple. The empty case returns the identity pair
        (largest int, smallest int) so that combining it with any real
        result leaves that result unchanged, including for negative values.
    """
    if x is None:
        # Empty case: identity element for (min, max).
        limits = np.iinfo(int)
        return (limits.max, limits.min)
    if y is None:
        # Single case
        return (x[0], x[1])
    # Double case
    return (min(x[0], y[0]), max(x[1], y[1]))
def min_fun(x, y=None):
    """Query function that computes a range minimum.

    Args:
        x: ``None`` for the empty case, otherwise a value to pass through
            or combine.
        y: Optional second value to combine with ``x``.

    Returns:
        The largest representable ``int`` for the empty case (the identity
        for ``min``), ``x`` itself when ``y`` is omitted, otherwise
        ``min(x, y)``.
    """
    if x is None:
        # Empty case: the identity is taken from the int dtype directly, so
        # the function does not depend on any module-level tree instance.
        return np.iinfo(int).max
    if y is None:
        # Single case
        return x
    # Double case
    return min(x, y)
def max_fun(x, y=None):
    """Query function that computes a range maximum.

    Args:
        x: ``None`` for the empty case, otherwise a value to pass through
            or combine.
        y: Optional second value to combine with ``x``.

    Returns:
        The smallest representable ``int`` for the empty case (the identity
        for ``max``, so negative inputs are handled correctly), ``x`` itself
        when ``y`` is omitted, otherwise ``max(x, y)``.
    """
    if x is None:
        # Empty case: identity element for max.
        return np.iinfo(int).min
    if y is None:
        # Single case
        return x
    # Double case
    return max(x, y)
# Test code
###########
x = [1, 4, 9, 8, 7, 2, 3, 5, 6]
l, r = 2, 6
print("Array:", x)
st = SegmentTree(x, minmax_fun)
# Unpack into names that do not shadow the builtins `min` and `max`: the
# query functions above call those builtins, so rebinding them at module
# level would break every query issued after this point.
range_min, range_max = st.query(l, r)
print("Min:", range_min, "Max:", range_max)
print(f"for bounds of [{l},{r}]")