segments.py
import numpy as np

from tonic import replays


def lambda_returns(
    values, next_values, rewards, resets, terminations, discount_factor,
    trace_decay, eta
):
    '''Function used to calculate lambda-returns on parallel buffers.'''

    returns = np.zeros_like(values)
    last_returns = next_values[-1]
    for t in reversed(range(len(rewards))):
        # Mix the one-step bootstrap with the running lambda-return.
        bootstrap = (
            (1 - trace_decay) * next_values[t] + trace_decay * last_returns)
        # On resets (e.g. time limits), bootstrap from the next value only.
        bootstrap *= (1 - resets[t])
        bootstrap += resets[t] * next_values[t]
        # On terminations, do not bootstrap at all.
        bootstrap *= (1 - terminations[t])
        returns[t] = last_returns = (
            # eta: average TD residual, subtracted from the rewards
            rewards[t] - eta) + discount_factor * bootstrap
    return returns
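
# A minimal illustration (not part of the original module) of the recursion
# above: assuming no resets or terminations and eta = 0, each return follows
#     G[t] = rewards[t] + discount_factor * (
#         (1 - trace_decay) * next_values[t] + trace_decay * G[t + 1]),
# with G bootstrapped from next_values[-1]. The numbers below are made up.
#
#     rewards = np.ones((2, 1), np.float32)            # 2 steps, 1 worker
#     next_values = np.array([[2.], [3.]], np.float32)
#     zeros = np.zeros((2, 1), np.float32)
#     lambda_returns(np.zeros_like(rewards), next_values, rewards, zeros,
#                    zeros, discount_factor=0.9, trace_decay=0.8, eta=0.)
#     # -> approximately [[4.024], [3.7]]: 3.7 = 1 + 0.9 * 3 and
#     #    4.024 = 1 + 0.9 * (0.2 * 2 + 0.8 * 3.7)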

class Segment:
    '''Replay storing recent transitions for on-policy learning.'''

    def __init__(
        self, size=4096, batch_iterations=80, batch_size=None,
        discount_factor=0.99, trace_decay=0.97
    ):
        self.max_size = size
        self.batch_iterations = batch_iterations
        self.batch_size = batch_size
        self.discount_factor = discount_factor
        self.trace_decay = trace_decay

    def initialize(self, seed=None):
        self.np_random = np.random.RandomState(seed)
        self.buffers = None
        self.index = 0

    def ready(self):
        return self.index == self.max_size
    def store(self, **kwargs):
        if self.buffers is None:
            # Lazily create the buffers from the first stored transition.
            self.num_workers = len(list(kwargs.values())[0])
            self.buffers = {}
            for key, val in kwargs.items():
                shape = (self.max_size,) + np.array(val).shape
                self.buffers[key] = np.zeros(shape, np.float32)
        for key, val in kwargs.items():
            self.buffers[key][self.index] = val
        self.index += 1
    def get_full(self, *keys):
        self.index = 0
        if 'advantages' in keys:
            # Advantages are normalized to zero mean and unit variance.
            advs = self.buffers['returns'] - self.buffers['values']
            std = advs.std()
            if std != 0:
                advs = (advs - advs.mean()) / std
            self.buffers['advantages'] = advs
        return {k: replays.flatten_batch(self.buffers[k]) for k in keys}

    def get(self, *keys):
        '''Get mini-batches from named buffers.'''
        batch = self.get_full(*keys)
        if self.batch_size is None:
            # Without a batch size, yield the full batch at every iteration.
            for _ in range(self.batch_iterations):
                yield batch
        else:
            # Otherwise, yield shuffled mini-batches of the requested size.
            size = self.max_size * self.num_workers
            all_indices = np.arange(size)
            for _ in range(self.batch_iterations):
                self.np_random.shuffle(all_indices)
                for i in range(0, size, self.batch_size):
                    indices = all_indices[i:i + self.batch_size]
                    yield {k: v[indices] for k, v in batch.items()}
    def compute_returns(self, values, next_values, eta):
        # Reshape the flat value estimates back to (time, workers) buffers
        # before computing the lambda-returns.
        shape = self.buffers['rewards'].shape
        self.buffers['values'] = values.reshape(shape)
        self.buffers['next_values'] = next_values.reshape(shape)
        self.buffers['returns'] = lambda_returns(
            values=self.buffers['values'],
            next_values=self.buffers['next_values'],
            rewards=self.buffers['rewards'],
            resets=self.buffers['resets'],
            terminations=self.buffers['terminations'],
            discount_factor=self.discount_factor,
            trace_decay=self.trace_decay,
            eta=eta)
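

# Illustrative usage sketch, not part of the original module. The worker
# count, observation size and placeholder value estimates below are
# assumptions made only to keep the example self-contained; in Tonic the
# values would come from an agent's critic network.
if __name__ == '__main__':
    num_workers, observation_size = 2, 3
    segment = Segment(size=4, batch_iterations=2, batch_size=4)
    segment.initialize(seed=0)

    # Fill the segment with dummy transitions from the parallel workers.
    while not segment.ready():
        segment.store(
            observations=np.zeros((num_workers, observation_size), np.float32),
            rewards=np.ones(num_workers, np.float32),
            resets=np.zeros(num_workers, np.float32),
            terminations=np.zeros(num_workers, np.float32))

    # Placeholder value estimates, flattened as (time * workers,).
    flat_size = segment.max_size * num_workers
    values = np.zeros(flat_size, np.float32)
    next_values = np.ones(flat_size, np.float32)
    segment.compute_returns(values, next_values, eta=0.)

    # Iterate over shuffled mini-batches of the requested keys.
    for batch in segment.get('observations', 'returns', 'advantages'):
        print(batch['observations'].shape, batch['returns'].shape)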