apo.py

import tensorflow as tf

from tonic import logger, replays
from tonic.tensorflow import agents, updaters, models

from segments import Segment


def default_model(actor_sizes=(64, 64), actor_activation='tanh',
                  critic_sizes=(64, 64), critic_activation='tanh',
                  observation_normalizer=None):
    return models.ActorCritic(
        actor=models.Actor(
            encoder=models.ObservationEncoder(),
            torso=models.MLP(actor_sizes, actor_activation),
            head=models.DetachedScaleGaussianPolicyHead()),
        critic=models.Critic(
            encoder=models.ObservationEncoder(),
            torso=models.MLP(critic_sizes, critic_activation),
            head=models.ValueHead()),
        observation_normalizer=observation_normalizer)


class PPO(agents.PPO):
    '''Proximal Policy Optimization with lambda-returns from a Segment
    replay; the objective is undiscounted by default (discount_factor=1.0).
    '''

    def __init__(
        self, discount_factor=1.0, trace_decay=0.95
    ):
        model = default_model()
        replay = replays.Segment(
            discount_factor=discount_factor, batch_size=256,
            batch_iterations=10, trace_decay=trace_decay)
        actor_updater = updaters.ClippedRatio(gradient_clip=10)
        critic_updater = updaters.VRegression(gradient_clip=10)
        super().__init__(
            model=model, replay=replay, actor_updater=actor_updater,
            critic_updater=critic_updater)

    @tf.function
    def _test_step(self, observations):
        return self.model.actor(observations).mean()


class APO(agents.A2C):
    '''Average-Reward Reinforcement Learning with Trust Region Methods.
    APO: https://arxiv.org/pdf/2106.03442.pdf
    '''
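    # Descriptive comments (added for clarity; see the methods below): the
    # agent keeps two exponential moving averages across updates,
    #   eta -- the mean batch reward (average-reward estimate), passed to the
    #          replay when computing the lambda-returns;
    #   b   -- the mean batch value (average-value estimate), used to shift
    #          the critic regression targets by v * b (the average value
    #          constraint).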

    def __init__(
        self, alpha=0.1, v=0.1, trace_decay=0.95
    ):
        model = default_model()
        replay = Segment(
            discount_factor=1.0, batch_size=256, batch_iterations=10,
            trace_decay=trace_decay)
        actor_updater = updaters.ClippedRatio(gradient_clip=10)
        critic_updater = updaters.VRegression(gradient_clip=10)
        # alpha is the moving-average rate for eta and b; v weights the
        # average value constraint on the critic targets.
        self.alpha = alpha
        self.v = v
        super().__init__(
            model=model, replay=replay, actor_updater=actor_updater,
            critic_updater=critic_updater)

    def initialize(self, observation_space, action_space, seed=None):
        super().initialize(observation_space, action_space, seed=seed)
        self.model.initialize(observation_space, action_space)
        self.replay.initialize(seed)
        self.actor_updater.initialize(self.model)
        self.critic_updater.initialize(self.model)
        # Running estimates of the average reward and the average value.
        self.eta = 0
        self.b = 0

    def _update(self):
        # Compute the lambda-returns.
        batch = self.replay.get_full('observations', 'next_observations')
        values, next_values = self._evaluate(**batch)
        values, next_values = values.numpy(), next_values.numpy()
        rewards = self.replay.get_full('rewards')['rewards']
        # Update the moving averages of the average reward and average value.
        self.eta = (1 - self.alpha) * self.eta + self.alpha * rewards.mean()
        self.b = (1 - self.alpha) * self.b + self.alpha * values.mean()
        self.replay.compute_returns(values, next_values, self.eta)

        train_actor = True
        actor_iterations = 0
        critic_iterations = 0
        keys = 'observations', 'actions', 'advantages', 'log_probs', 'returns'

        # Update both the actor and the critic multiple times.
        for batch in self.replay.get(*keys):
            if train_actor:
                infos = self._update_actor_critic(**batch)
                actor_iterations += 1
            else:
                batch = {k: batch[k] for k in ('observations', 'returns')}
                infos = dict(critic=self.critic_updater(**batch))
                critic_iterations += 1

            # Stop training the actor early if the updater signals it.
            if train_actor:
                train_actor = not infos['actor']['stop'].numpy()

            for key in infos:
                for k, v in infos[key].items():
                    logger.store(key + '/' + k, v.numpy())

        logger.store('actor/iterations', actor_iterations)
        logger.store('critic/iterations', critic_iterations)
        logger.store('average_values', values.mean())

        # Update the normalizers.
        if self.model.observation_normalizer:
            self.model.observation_normalizer.update()
        if self.model.return_normalizer:
            self.model.return_normalizer.update()

    @tf.function
    def _update_actor_critic(
        self, observations, actions, advantages, log_probs, returns
    ):
        actor_infos = self.actor_updater(
            observations, actions, advantages, log_probs)
        # Average value constraint: shift the critic targets by v * b.
        returns = returns - self.v * self.b
        critic_infos = self.critic_updater(observations, returns)
        return dict(actor=actor_infos, critic=critic_infos)

    @tf.function
    def _test_step(self, observations):
        return self.model.actor(observations).mean()
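

if __name__ == '__main__':
    # Usage sketch (not part of the original module): launching a training
    # run through tonic's `tonic.train` launcher, assuming it accepts the
    # `header`/`agent`/`environment` keyword arguments shown in tonic's
    # README examples.  The environment, run name, and hyperparameter values
    # below are illustrative assumptions; adjust them to your setup.
    import tonic

    tonic.train(
        header='from apo import APO',
        agent='APO(alpha=0.1, v=0.1)',
        environment="tonic.environments.Gym('HalfCheetah-v3')",
        name='APO-HalfCheetah',
        seed=0,
    )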