Enable using squashed gaussian in PPO when entropy coef is 0
ishihara-y committed Sep 17, 2024
1 parent 194aa9b commit 03908bc
Showing 2 changed files with 75 additions and 8 deletions.
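
For context, the policy loss touched below is the standard PPO clipped surrogate plus an optional entropy bonus. The following is a minimal NumPy sketch, illustrative only and not the nnabla_rl implementation (the function name and parameters are invented for the example); it shows that the entropy term is the only place entropy_coefficient appears, so with a coefficient of 0.0 the policy distribution's entropy never needs to be evaluated:

    # Standalone sketch of the PPO policy loss. `ratio` is pi_new/pi_old per sample,
    # `advantage` holds the estimated advantages, `entropy` optional per-sample entropies.
    import numpy as np

    def ppo_policy_loss(ratio, advantage, entropy=None, epsilon=0.2, entropy_coefficient=0.0):
        clipped_ratio = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
        clip_loss = np.mean(np.minimum(ratio * advantage, clipped_ratio * advantage))
        loss = -clip_loss
        if entropy_coefficient != 0.0:
            # Entropy is only needed when its coefficient is non-zero,
            # which is exactly the guard this commit introduces in the trainer.
            loss -= entropy_coefficient * np.mean(entropy)
        return loss
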
10 changes: 6 additions & 4 deletions nnabla_rl/model_trainers/policy/ppo_policy_trainer.py
@@ -108,10 +108,12 @@ def _build_one_step_graph(self, models: Sequence[Model], training_variables: Tra
        lower_bounds = NF.minimum2(probability_ratio * advantage, clipped_ratio * advantage)
        clip_loss = NF.mean(lower_bounds)

-        entropy = distribution.entropy()
-        entropy_loss = NF.mean(entropy)
-
-        self._pi_loss += 0.0 if ignore_loss else (-clip_loss - self._config.entropy_coefficient * entropy_loss)
+        if self._config.entropy_coefficient != 0.0:
+            entropy = distribution.entropy()
+            entropy_loss = NF.mean(entropy)
+            self._pi_loss += 0.0 if ignore_loss else (-clip_loss - self._config.entropy_coefficient * entropy_loss)
+        else:
+            self._pi_loss += 0.0 if ignore_loss else -clip_loss

    def _setup_training_variables(self, batch_size) -> TrainingVariables:
        # Training input variables
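Why the guard above matters: the entropy of a tanh-squashed Gaussian has no closed form, so a squashed-Gaussian distribution generally cannot offer an analytic entropy(), and raising NotImplementedError is the natural behaviour. Below is a toy standalone sketch of that failure mode, not the nnabla_rl SquashedGaussian class; the class name and numerical details are assumptions made for illustration:

    # Toy stand-in for a tanh-squashed Gaussian over a scalar action.
    import numpy as np

    class ToySquashedGaussian:
        def __init__(self, mean, ln_var):
            self.mean = mean
            self.std = np.exp(0.5 * ln_var)

        def log_prob(self, x):
            # x = tanh(u) with u ~ N(mean, std^2); apply the change-of-variables
            # correction -log(1 - tanh(u)^2) to the Gaussian log-density.
            u = np.arctanh(np.clip(x, -0.999999, 0.999999))
            gaussian_logp = -0.5 * (((u - self.mean) / self.std) ** 2 + np.log(2.0 * np.pi * self.std ** 2))
            return gaussian_logp - np.log(1.0 - np.tanh(u) ** 2 + 1e-8)

        def entropy(self):
            # No analytic expression exists after the tanh squashing.
            raise NotImplementedError

    # With entropy_coefficient == 0.0 the trainer never calls entropy(), so such a
    # distribution can now be used as a PPO policy output.
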
73 changes: 69 additions & 4 deletions tests/model_trainers/test_policy_trainers.py
@@ -24,10 +24,12 @@
import nnabla.parametric_functions as NPF
import nnabla_rl.model_trainers as MT
from nnabla_rl.distributions.gaussian import Gaussian
+from nnabla_rl.distributions.squashed_gaussian import SquashedGaussian
from nnabla_rl.environments.dummy import DummyContinuous
from nnabla_rl.environments.environment_info import EnvironmentInfo
from nnabla_rl.model_trainers.model_trainer import LossIntegration
from nnabla_rl.model_trainers.policy.dpg_policy_trainer import DPGPolicyTrainer
+from nnabla_rl.model_trainers.policy.ppo_policy_trainer import PPOPolicyTrainer
from nnabla_rl.model_trainers.policy.soft_policy_trainer import AdjustableTemperature, SoftPolicyTrainer
from nnabla_rl.model_trainers.policy.trpo_policy_trainer import (
    _concat_network_params_in_ndarray,
@@ -65,11 +67,30 @@ def pi(self, s):
        return s


+class StochasticNonRnnPolicy(StochasticPolicy):
+    def __init__(self, scope_name: str, squash: bool = False):
+        super().__init__(scope_name)
+        self._squash = squash
+
+    def pi(self, s):
+        if self._squash:
+            return SquashedGaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
+        else:
+            return Gaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )


class StochasticRnnPolicy(StochasticPolicy):
-    def __init__(self, scope_name: str):
+    def __init__(self, scope_name: str, squash: bool = False):
        super().__init__(scope_name)
        self._internal_state_shape = (10,)
        self._fake_internal_state = None
+        self._squash = squash

    def is_recurrent(self) -> bool:
        return True
@@ -85,9 +106,16 @@ def get_internal_states(self):

    def pi(self, s):
        self._fake_internal_state = self._fake_internal_state * 2
-        return Gaussian(
-            mean=nn.Variable.from_numpy_array(np.zeros(s.shape)), ln_var=nn.Variable.from_numpy_array(np.ones(s.shape))
-        )
+        if self._squash:
+            return SquashedGaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )
+        else:
+            return Gaussian(
+                mean=nn.Variable.from_numpy_array(np.zeros(s.shape)),
+                ln_var=nn.Variable.from_numpy_array(np.ones(s.shape)),
+            )


class TrainerTest(metaclass=ABCMeta):
@@ -349,6 +377,43 @@ def test_with_rnn_model(self, unroll_steps, burn_in_steps, loss_integration):
        # pass: If no error occurs


+class TestPPOPolicyTrainer(TrainerTest):
+    def setup_method(self, method):
+        nn.clear_parameters()
+
+    @pytest.mark.parametrize("unroll_steps", [1, 2])
+    @pytest.mark.parametrize("burn_in_steps", [0, 1, 2])
+    @pytest.mark.parametrize("loss_integration", [LossIntegration.LAST_TIMESTEP_ONLY, LossIntegration.ALL_TIMESTEPS])
+    @pytest.mark.parametrize("entropy_coefficient", [0.0, 1.0])
+    @pytest.mark.parametrize("squash", [True, False])
+    def test_with_non_rnn_model(self, unroll_steps, burn_in_steps, loss_integration, entropy_coefficient, squash):
+        env_info = EnvironmentInfo.from_env(DummyContinuous())
+
+        policy = StochasticNonRnnPolicy("stub_pi", squash=squash)
+        config = MT.policy_trainers.PPOPolicyTrainerConfig(
+            unroll_steps=unroll_steps,
+            burn_in_steps=burn_in_steps,
+            loss_integration=loss_integration,
+            entropy_coefficient=entropy_coefficient,
+        )
+        if squash and entropy_coefficient != 0.0:
+            with pytest.raises(NotImplementedError):
+                PPOPolicyTrainer(
+                    policy,
+                    solvers={},
+                    env_info=env_info,
+                    config=config,
+                )
+        else:
+            PPOPolicyTrainer(
+                policy,
+                solvers={},
+                env_info=env_info,
+                config=config,
+            )
+        # pass: If no error occurs
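
A hypothetical usage sketch assembled only from names visible in this diff (the stub policy and empty solvers mirror the test above rather than real training code): with entropy_coefficient set to 0.0 the trainer builds its graph without calling entropy(), so a squashed-Gaussian policy is accepted, while a non-zero coefficient still surfaces the NotImplementedError.

    # Hypothetical sketch; identifiers are taken from this commit's test file.
    env_info = EnvironmentInfo.from_env(DummyContinuous())
    config = MT.policy_trainers.PPOPolicyTrainerConfig(entropy_coefficient=0.0)
    trainer = PPOPolicyTrainer(
        StochasticNonRnnPolicy("stub_pi", squash=True),
        solvers={},
        env_info=env_info,
        config=config,
    )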


class TestAdjustableTemperature(TrainerTest):
    def test_initial_temperature(self):
        initial_value = 5.0
