Skip to content

Commit

Permalink
Parse noises in runners
Browse files: browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Jan 28, 2025
1 parent d57c8ea commit ee204db
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
19 changes: 16 additions & 3 deletions skrl/utils/runner/jax/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skrl.agents.jax import Agent
from skrl.envs.wrappers.jax import MultiAgentEnvWrapper, Wrapper
from skrl.models.jax import Model
from skrl.resources.noises.jax import GaussianNoise, OrnsteinUhlenbeckNoise # noqa
from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa
from skrl.trainers.jax import Trainer
Expand Down Expand Up @@ -148,6 +149,8 @@ def _process_cfg(self, cfg: dict) -> dict:
"state_preprocessor",
"value_preprocessor",
"amp_state_preprocessor",
"noise",
"smooth_regularization_noise",
]

def reward_shaper_function(scale):
Expand All @@ -162,7 +165,7 @@ def update_dict(d):
update_dict(value)
else:
if key in _direct_eval:
if type(d[key]) is str:
if isinstance(value, str):
d[key] = eval(value)
elif key.endswith("_kwargs"):
d[key] = value if value is not None else {}
Expand Down Expand Up @@ -311,8 +314,18 @@ def _generate_agent(
agent_id = possible_agents[0]
agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy()
agent_cfg.update(self._process_cfg(cfg["agent"]))
agent_cfg["state_preprocessor_kwargs"].update({"size": observation_spaces[agent_id], "device": device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device})
agent_cfg.get("state_preprocessor_kwargs", {}).update(
{"size": observation_spaces[agent_id], "device": device}
)
agent_cfg.get("value_preprocessor_kwargs", {}).update({"size": 1, "device": device})
if agent_cfg.get("exploration", {}).get("noise", None):
agent_cfg["exploration"]["noise"] = agent_cfg["exploration"]["noise"](
**agent_cfg["exploration"].get("noise_kwargs", {})
)
if agent_cfg.get("smooth_regularization_noise", None):
agent_cfg["smooth_regularization_noise"] = agent_cfg["smooth_regularization_noise"](
**agent_cfg.get("smooth_regularization_noise_kwargs", {})
)
agent_kwargs = {
"models": models[agent_id],
"memory": memories[agent_id],
Expand Down
23 changes: 18 additions & 5 deletions skrl/utils/runner/torch/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skrl.agents.torch import Agent
from skrl.envs.wrappers.torch import MultiAgentEnvWrapper, Wrapper
from skrl.models.torch import Model
from skrl.resources.noises.torch import GaussianNoise, OrnsteinUhlenbeckNoise # noqa
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa
from skrl.trainers.torch import Trainer
Expand Down Expand Up @@ -160,6 +161,8 @@ def _process_cfg(self, cfg: dict) -> dict:
"state_preprocessor",
"value_preprocessor",
"amp_state_preprocessor",
"noise",
"smooth_regularization_noise",
]

def reward_shaper_function(scale):
Expand All @@ -174,7 +177,7 @@ def update_dict(d):
update_dict(value)
else:
if key in _direct_eval:
if type(d[key]) is str:
if isinstance(value, str):
d[key] = eval(value)
elif key.endswith("_kwargs"):
d[key] = value if value is not None else {}
Expand Down Expand Up @@ -263,7 +266,7 @@ def _generate_models(
roles = list(models_cfg.keys())
if len(roles) != 2:
raise ValueError(
"Runner currently only supports shared models, made up of exactly two models."
"Runner currently only supports shared models, made up of exactly two models. "
"Set 'separate' field to True to create non-shared models for the given cfg"
)
# get shared model structure and parameters
Expand Down Expand Up @@ -402,12 +405,22 @@ def _generate_agent(
"reply_buffer": reply_buffer,
"collect_reference_motions": lambda num_samples: env.collect_reference_motions(num_samples),
}
if agent_class in ["a2c", "cem", "ddpg", "ddqn", "dqn", "ppo", "rpo", "sac", "td3", "trpo"]:
elif agent_class in ["a2c", "cem", "ddpg", "ddqn", "dqn", "ppo", "rpo", "sac", "td3", "trpo"]:
agent_id = possible_agents[0]
agent_cfg = self._component(f"{agent_class}_DEFAULT_CONFIG").copy()
agent_cfg.update(self._process_cfg(cfg["agent"]))
agent_cfg["state_preprocessor_kwargs"].update({"size": observation_spaces[agent_id], "device": device})
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": device})
agent_cfg.get("state_preprocessor_kwargs", {}).update(
{"size": observation_spaces[agent_id], "device": device}
)
agent_cfg.get("value_preprocessor_kwargs", {}).update({"size": 1, "device": device})
if agent_cfg.get("exploration", {}).get("noise", None):
agent_cfg["exploration"]["noise"] = agent_cfg["exploration"]["noise"](
**agent_cfg["exploration"].get("noise_kwargs", {})
)
if agent_cfg.get("smooth_regularization_noise", None):
agent_cfg["smooth_regularization_noise"] = agent_cfg["smooth_regularization_noise"](
**agent_cfg.get("smooth_regularization_noise_kwargs", {})
)
agent_kwargs = {
"models": models[agent_id],
"memory": memories[agent_id],
Expand Down

0 comments on commit ee204db

Please sign in to comment.