
Add scale reward annealing
hartikainen committed Apr 10, 2018
1 parent b3ed03d commit e346359
Showing 3 changed files with 34 additions and 10 deletions.

examples/mujoco_all_sac_real_nvp.py: 9 additions & 1 deletion
@@ -280,6 +280,14 @@ def run_experiment(variant):
         n_map_action_candidates=variant['n_map_action_candidates']
     )

+    if variant['scale_reward'] == 'piecewise_constant':
+        boundaries = variant['scale_reward_boundaries']
+        values = variant['scale_reward_values']
+        scale_reward = lambda iteration: (
+            tf.train.piecewise_constant(iteration, boundaries, values))
+    else:
+        scale_reward = variant['scale_reward']
+
     algorithm = SACV2(
         base_kwargs=base_kwargs,
         env=env,
@@ -289,7 +297,7 @@
         vf=vf,
         lr=variant['lr'],
         policy_lr=variant['policy_lr'],
-        scale_reward=variant['scale_reward'],
+        scale_reward=scale_reward,
         discount=variant['discount'],
         tau=variant['tau'],
         target_update_interval=variant['target_update_interval'],
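
As an aside on what the new 'piecewise_constant' branch produces: a minimal, self-contained sketch of the resulting schedule. The boundaries and values below are made up for illustration; the real ones come from the experiment's variant spec, which is not part of this diff.

    import tensorflow as tf

    # Hypothetical schedule parameters (the variant would supply the real ones).
    boundaries = [1000, 5000]      # iteration thresholds
    values = [10.0, 3.0, 1.0]      # one more value than there are boundaries

    iteration_pl = tf.placeholder(tf.int64, shape=None, name='iteration')
    scale_reward = tf.train.piecewise_constant(iteration_pl, boundaries, values)

    with tf.Session() as sess:
        print(sess.run(scale_reward, {iteration_pl: 500}))    # 10.0
        print(sess.run(scale_reward, {iteration_pl: 2500}))   # 3.0
        print(sess.run(scale_reward, {iteration_pl: 10000}))  # 1.0

The else branch keeps the old behaviour: a plain scalar taken from the variant is passed through unchanged.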

sac/algos/base.py: 3 additions & 2 deletions
@@ -179,11 +179,12 @@ def _evaluate(self, epoch):
         if self._eval_render:
             self._eval_env.render(paths)

+        iteration = epoch*self._epoch_length
         batch = self._pool.random_batch(self._batch_size)
-        self.log_diagnostics(batch)
+        self.log_diagnostics(iteration, batch)

     @abc.abstractmethod
-    def log_diagnostics(self, batch):
+    def log_diagnostics(self, iteration, batch):
         raise NotImplementedError

     @abc.abstractmethod
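
For orientation, the iteration computed here is the global gradient-step count at the end of the epoch; it is the value that log_diagnostics forwards into the new iteration placeholder in sac_v2.py, so the annealing schedule is indexed by gradient steps rather than by epochs. A tiny illustration, assuming a hypothetical epoch length of 1000 steps (the real value is whatever self._epoch_length is configured to):

    epoch_length = 1000  # illustrative only
    for epoch in range(3):
        iteration = epoch * epoch_length
        print(epoch, iteration)  # prints 0 0, then 1 1000, then 2 2000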

sac/algos/sac_v2.py: 22 additions & 7 deletions
@@ -163,6 +163,8 @@ def _init_placeholders(self):
             - reward
             - terminals
         """
+        self._iteration_pl = tf.placeholder(
+            tf.int64, shape=None, name='iteration')

         self._obs_pl = tf.placeholder(
             tf.float32,
@@ -193,6 +195,16 @@ def _init_placeholders(self):
             name='terminals',
         )

+    @property
+    def scale_reward(self):
+        if callable(self._scale_reward):
+            return self._scale_reward(self._iteration_pl)
+        elif isinstance(self._scale_reward, Number):
+            return self._scale_reward
+
+        raise ValueError(
+            'scale_reward must be either callable or scalar')
+
     def _init_critic_update(self):
         """Create minimization operation for critic Q-function.

@@ -212,7 +224,7 @@ def _init_critic_update(self):
         self._vf_target_params = self._vf.get_params_internal()

         ys = tf.stop_gradient(
-            self._scale_reward * self._reward_pl +
+            self.scale_reward * self._reward_pl +
             (1 - self._terminal_pl) * self._discount * vf_next_target_t
         ) # N

@@ -305,17 +317,17 @@ def _init_training(self, env, policy, pool):
         self._sess.run(self._target_ops)

     @overrides
-    def _do_training(self, itr, batch):
+    def _do_training(self, iteration, batch):
         """Runs the operations for updating training and target ops."""

-        feed_dict = self._get_feed_dict(batch)
+        feed_dict = self._get_feed_dict(iteration, batch)
         self._sess.run(self._training_ops, feed_dict)

-        if itr % self._target_update_interval == 0:
+        if iteration % self._target_update_interval == 0:
             # Run target ops here.
             self._sess.run(self._target_ops)

-    def _get_feed_dict(self, batch):
+    def _get_feed_dict(self, iteration, batch):
         """Construct TensorFlow feed_dict from sample batch."""

         feed_dict = {
@@ -326,10 +338,13 @@
             self._terminal_pl: batch['terminals'],
         }

+        if iteration is not None:
+            feed_dict[self._iteration_pl] = iteration
+
         return feed_dict

     @overrides
-    def log_diagnostics(self, batch):
+    def log_diagnostics(self, iteration, batch):
         """Record diagnostic information to the logger.

         Records mean and standard deviation of Q-function and state
@@ -339,7 +354,7 @@
         Also calls the `draw` method of the plotter, if plotter defined.
         """

-        feed_dict = self._get_feed_dict(batch)
+        feed_dict = self._get_feed_dict(iteration, batch)
         qf, vf, td_loss = self._sess.run(
             [self._qf_t, self._vf_t, self._td_loss_t], feed_dict)

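
To spell out the dispatch that the new scale_reward property performs: a standalone sketch of the same callable-versus-scalar pattern. It assumes Number comes from the standard numbers module (as the isinstance check suggests), and resolve_scale_reward is just an illustrative name with a made-up schedule, not code copied from the repository.

    from numbers import Number

    import tensorflow as tf


    def resolve_scale_reward(scale_reward, iteration_pl):
        # Callable case: build a schedule tensor from the iteration placeholder,
        # e.g. the tf.train.piecewise_constant lambda set up in the example script.
        if callable(scale_reward):
            return scale_reward(iteration_pl)
        # Scalar case: a plain number acts as a constant reward scale.
        elif isinstance(scale_reward, Number):
            return scale_reward
        raise ValueError('scale_reward must be either callable or scalar')


    iteration_pl = tf.placeholder(tf.int64, shape=None, name='iteration')

    constant_scale = resolve_scale_reward(3.0, iteration_pl)  # plain float
    annealed_scale = resolve_scale_reward(
        lambda it: tf.train.piecewise_constant(it, [1000], [10.0, 1.0]),
        iteration_pl)                                         # tf.Tensor

Either way, the result multiplies the reward inside the stop-gradient Bellman target, so a scalar configuration reproduces the previous constant scaling while a callable anneals the scale as training progresses.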
