From 550d397b88f9c1a9e5f17968e736e347027869ab Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 14:59:04 +0200 Subject: [PATCH 001/109] vtrace module --- examples/impala/config.yaml | 35 ++++ examples/impala/impala.py | 215 +++++++++++++++++++++ examples/impala/utils.py | 237 +++++++++++++++++++++++ torchrl/objectives/value/vtrace.py | 294 ++++++++++++++++++++++++++++- 4 files changed, 777 insertions(+), 4 deletions(-) create mode 100644 examples/impala/config.yaml create mode 100644 examples/impala/impala.py create mode 100644 examples/impala/utils.py diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml new file mode 100644 index 00000000000..1bee59c3adb --- /dev/null +++ b/examples/impala/config.yaml @@ -0,0 +1,35 @@ +# Environment +env: +env_name: PongNoFrameskip-v4 + +# collector +collector: +frames_per_batch: 4096 +total_frames: 40_000_000 + +# logger +logger: +backend: wandb +exp_name: Atari_Schulman17 +test_interval: 40_000_000 +num_test_episodes: 3 + +# Optim +optim: +lr: 2.5e-4 +eps: 1.0e-6 +weight_decay: 0.0 +max_grad_norm: 0.5 +anneal_lr: True + +# loss +loss: +gamma: 0.99 +mini_batch_size: 1024 +ppo_epochs: 3 +gae_lambda: 0.95 +clip_epsilon: 0.1 +anneal_clip_epsilon: True +critic_coef: 1.0 +entropy_coef: 0.01 +loss_critic_type: l2 \ No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py new file mode 100644 index 00000000000..edc666dfe9d --- /dev/null +++ b/examples/impala/impala.py @@ -0,0 +1,215 @@ +""" +This script reproduces the Proximal Policy Optimization (PPO) Algorithm +results from Schulman et al. 2017 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import numpy as np + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import RPCDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.record.loggers import generate_exp_name, get_logger + from torchrl.objectives.value.vtrace import VTrace + from utils import make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.is_available() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( + actor.to(device), + critic.to(device), + critic_head.to(device), + ) + + # Create collector + # collector = RPCDataCollector( + # create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 2, + # policy=actor, + # frames_per_batch=frames_per_batch, + # total_frames=total_frames, + # storing_device="cpu", + # max_frames_per_traj=-1, + # sync=False, + # ) + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + 
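The divisions above convert emulator frames into agent steps, since the Atari wrapper repeats every action frame_skip times. A quick arithmetic check of the resulting sizes, as a minimal sketch that only uses the variables defined above (the concrete numbers come from the config.yaml added in this patch):

# With frame_skip = 4 and the example config:
#   total_frames      40_000_000 // 4 -> 10_000_000 agent steps overall
#   frames_per_batch       4_096 // 4 ->      1_024 agent steps per collected batch
#   mini_batch_size        1_024 // 4 ->        256 samples per mini-batch
# so each collected batch is split into frames_per_batch // mini_batch_size = 4 mini-batches.
assert frames_per_batch % mini_batch_size == 0, (
    "frames_per_batch is expected to be divisible by mini_batch_size"
)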
# Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + vtrace_module = VTrace( + gamma=cfg.loss.gamma, + lmbda=cfg.loss.gae_lambda, + value_network=critic, + average_gae=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + for data in collector: + + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Train loging + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + logger.log_scalar( + "reward_train", episode_rewards.mean().item(), collected_frames + ) + + # Apply episodic end of life + data["done"].copy_(data["end_of_life"]) + data["next", "done"].copy_(data["next", "end_of_life"]) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + for j in range(cfg.loss.ppo_epochs): + + # Compute VTrace + with torch.no_grad(): + import ipdb; ipdb.set_trace() + data = vtrace_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for i, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1 - (num_network_updates / total_network_updates) + if cfg.optim.anneal_lr: + for g in optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + if cfg.loss.anneal_clip_epsilon: + loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, i] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + logger.log_scalar(key, value.item(), collected_frames) + logger.log_scalar("lr", alpha * cfg.optim.lr, collected_frames) + logger.log_scalar( + "clip_epsilon", alpha * cfg.loss.clip_epsilon, collected_frames + ) + + # Test logging + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if (collected_frames - frames_in_batch) // test_interval < ( + 
collected_frames // test_interval + ): + actor.eval() + test_rewards = [] + for _ in range(cfg.logger.num_test_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + logger.log_scalar("reward_test", test_rewards.mean(), collected_frames) + actor.train() + + collector.update_policy_weights_() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/utils.py b/examples/impala/utils.py new file mode 100644 index 00000000000..e9112ead762 --- /dev/null +++ b/examples/impala/utils.py @@ -0,0 +1,237 @@ +import random + +import gym +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + default_info_dict_reader, + DoubleToFloat, + EnvCreator, + ExplorationType, + GrayScale, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +class NoopResetEnv(gym.Wrapper): + def __init__(self, env, noop_max=30): + """Sample initial states by taking random number of no-ops on reset.""" + gym.Wrapper.__init__(self, env) + self.noop_max = noop_max + self.override_num_noops = None + self.noop_action = 0 # No-op is assumed to be action 0. + assert env.unwrapped.get_action_meanings()[0] == "NOOP" + + def reset(self, **kwargs): + """Do no-op action for a number of steps in [1, noop_max].""" + self.env.reset(**kwargs) + if self.override_num_noops is not None: + noops = self.override_num_noops + else: + noops = random.randint(1, self.noop_max + 1) + assert noops > 0 + obs = None + for _ in range(noops): + obs, _, done, *other = self.env.step(self.noop_action) + if done: + obs = self.env.reset(**kwargs) + return obs + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. since it helps value estimation. 
+ """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + + def step(self, action): + obs, rew, done, info = self.env.step(action) + lives = self.env.unwrapped.ale.lives() + info["end_of_life"] = False + if (lives < self.lives) or done: + info["end_of_life"] = True + self.lives = lives + return obs, rew, done, info + + def reset(self, **kwargs): + reset_data = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() + return reset_data + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = gym.make(env_name) + if not is_test: + env = NoopResetEnv(env, noop_max=30) + env = EpisodicLifeEnv(env) + env = GymWrapper( + env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + ) + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) + return env + + +def make_parallel_env(env_name, device, is_test=False): + num_envs = 8 + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + 
distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 43f5246502f..2199dde2ba5 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -4,9 +4,22 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Tuple, Union +from typing import List, Optional, Union, Tuple import torch +from tensordict.nn import ( + dispatch, + is_functional, + set_skip_existing, + TensorDictModule, + TensorDictModuleBase, +) +from tensordict.tensordict import TensorDictBase +from tensordict.utils import NestedKey +from torch import nn, Tensor +from torchrl.objectives.utils import hold_out_net +from advantages import ValueEstimatorBase, _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets +from functional import _transpose_time, SHAPE_ERR def _c_val( @@ -27,8 +40,8 @@ def _dv_val( ) -> Tuple[torch.Tensor, torch.Tensor]: rho = _c_val(log_pi, log_mu, rho_bar) next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) - dv = rho * (rewards + gamma * next_vals - vals) - return dv, rho + deltas = rho * (rewards + gamma * next_vals - vals) + return deltas, rho def _vtrace( @@ -40,6 +53,7 @@ def _vtrace( rho_bar: Union[float, torch.Tensor] = 1.0, c_bar: Union[float, torch.Tensor] = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: + T = vals.shape[1] if not isinstance(gamma, torch.Tensor): gamma = torch.full_like(vals, gamma) @@ -54,5 +68,277 @@ def _vtrace( vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1]) ) v_out.append(_v_out) - v_out = torch.stack(list(reversed(v_out)), 1) + v_out = torch.stack(list(reversed(v_out)), 1) # values return v_out, rho + +@_transpose_time +def vtrace_correction( + gamma: float, + lmbda: float, + state_value: torch.Tensor, + next_state_value: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + time_dim: int = -2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """V-Trace off-policy correction method. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + lmbda (scalar): trajectory discount. + log_pi (Tensor): log probability of taking actions in the environment. + log_mu (Tensor): log probability of taking actions in the environment. 
+ state_value (Tensor): value function result with old_state input. + next_state_value (Tensor): value function result with new_state input. + reward (Tensor): reward of taking actions in the environment. + done (Tensor): boolean flag for end of episode. + rho_bar (Union[float, Tensor]): clipping parameter for importance weights. + c_bar (Union[float, Tensor]): clipping parameter for importance weights. + time_dim (int): dimension where the time is unrolled. Defaults to -2. + + All tensors (values, reward and done) must have shape + ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. + + """ + if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + raise RuntimeError(SHAPE_ERR) + dtype = next_state_value.dtype + device = state_value.device + + not_done = (~done).int() + *batch_size, time_steps, lastdim = not_done.shape + advantage = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) + + prev_advantage = 0 + gnotdone = gamma * not_done + delta = reward + (gnotdone * next_state_value) - state_value + discount = lmbda * gnotdone + for t in reversed(range(time_steps)): + prev_advantage = advantage[..., t, :] = delta[..., t, :] + (prev_advantage * discount[..., t, :]) + + value_target = advantage + state_value + + return advantage, value_target + + +class VTrace(ValueEstimatorBase): + """A class wrapper around V-Trace estimate functional. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + lmbda (scalar): trajectory discount. + value_network (TensorDictModule): value operator used to retrieve the value estimates. + average_gae (bool): if ``True``, the resulting GAE values will be standardized. + Default is ``False``. + differentiable (bool, optional): if ``True``, gradients are propagated through + the computation of the value function. Default is ``False``. + + .. note:: + The proper way to make the function call non-differentiable is to + decorate it in a `torch.no_grad()` context manager/decorator or + pass detached parameters for functional modules. + + # vectorized (bool, optional): whether to use the vectorized version of the + # lambda return. Default is `True`. + + skip_existing (bool, optional): if ``True``, the value network will skip + modules which outputs are already present in the tensordict. + Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + is not affected. + Defaults to "state_value". + advantage_key (str or tuple of str, optional): [Deprecated] the key of + the advantage entry. Defaults to ``"advantage"``. + value_target_key (str or tuple of str, optional): [Deprecated] the key + of the advantage entry. Defaults to ``"value_target"``. + value_key (str or tuple of str, optional): [Deprecated] the value key to + read from the input tensordict. Defaults to ``"state_value"``. + + # shifted (bool, optional): if ``True``, the value and next value are + # estimated with a single call to the value network. This is faster + # but is only valid whenever (1) the ``"next"`` value is shifted by + # only one time step (which is not the case with multi-step value + # estimation, for instance) and (2) when the parameters used at time + # ``t`` and ``t+1`` are identical (which is not the case when target + # parameters are to be used). Defaults to ``False``. + + VTrace will return an :obj:`"advantage"` entry containing the advantage value. 
It will also + return a :obj:`"value_target"` entry with the V-Trace target value. + + # Finally, if :obj:`gradient_mode` is ``True``, + # an additional and differentiable :obj:`"value_error"` entry will be returned, + # which simple represents the difference between the return and the value network + # output (i.e. an additional distance loss should be applied to that signed value). + + # .. note:: + # As other advantage functions do, if the ``value_key`` is already present + # in the input tensordict, the VTrace module will ignore the calls to the value + # network (if any) and use the provided value instead. + + """ + + def __init__( + self, + *, + gamma: Union[float, torch.Tensor], + lmbda: float, + rho_bar: Union[float, torch.Tensor] = 1.0, + c_bar: Union[float, torch.Tensor] = 1.0, + value_network: TensorDictModule, + average_gae: bool = False, + differentiable: bool = False, + vectorized: bool = True, + skip_existing: Optional[bool] = None, + advantage_key: NestedKey = None, + value_target_key: NestedKey = None, + value_key: NestedKey = None, + shifted: bool = False, + ): + super().__init__( + shifted=shifted, + value_network=value_network, + differentiable=differentiable, + advantage_key=advantage_key, + value_target_key=value_target_key, + value_key=value_key, + skip_existing=skip_existing, + ) + try: + device = next(value_network.parameters()).device + except (AttributeError, StopIteration): + device = torch.device("cpu") + self.register_buffer("gamma", torch.tensor(gamma, device=device)) + self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) + self.register_buffer("rho_bar", torch.tensor(rho_bar, device=device)) + self.register_buffer("c_bar", torch.tensor(c_bar, device=device)) + self.average_gae = average_gae + self.vectorized = vectorized + + @_self_set_skip_existing + @_self_set_grad_enabled + @dispatch + def forward( + self, + tensordict: TensorDictBase, + *unused_args, + params: Optional[List[Tensor]] = None, + target_params: Optional[List[Tensor]] = None, + ) -> TensorDictBase: + """Computes the V-Trace correction given the data in tensordict. + + If a functional module is provided, a nested TensorDict containing the parameters + (and if relevant the target parameters) can be passed to the module. + + Args: + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the GAE. + The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + params (TensorDictBase, optional): A nested TensorDict containing the params + to be passed to the functional value network module. + target_params (TensorDictBase, optional): A nested TensorDict containing the + target params to be passed to the functional value network module. + + Returns: + An updated TensorDict with an advantage and a value_error keys as defined in the constructor. + + Examples: + >>> from tensordict import TensorDict + >>> value_net = TensorDictModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... lmbda=0.94, + ... value_network=value_net, + ... differentiable=False, + ... 
) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) + >>> _ = module(tensordict) + >>> assert "advantage" in tensordict.keys() + + The module supports non-tensordict (i.e. unpacked tensordict) inputs too: + + Examples: + >>> value_net = TensorDictModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... lmbda=0.94, + ... value_network=value_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + + """ + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got " + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get(("next", self.tensor_keys.reward)) + device = reward.device + gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) + if steps_to_next_obs is not None: + gamma = gamma ** steps_to_next_obs.view_as(reward) + + if self.value_network is not None: + if params is not None: + params = params.detach() + if target_params is None: + target_params = params.clone(False) + with hold_out_net(self.value_network): + # we may still need to pass gradient, but we don't want to assign grads to + # value net params + value, next_value = _call_value_nets( + value_net=self.value_network, + data=tensordict, + params=params, + next_params=target_params, + single_call=self.shifted, + value_key=self.tensor_keys.value, + detach_next=True, + ) + else: + value = tensordict.get(self.tensor_keys.value) + next_value = tensordict.get(("next", self.tensor_keys.value)) + + done = tensordict.get(("next", self.tensor_keys.done)) + if self.vectorized: + raise NotImplementedError + else: + adv, value_target = vtrace_correction( + gamma, + lmbda, + value, + next_value, + reward, + done, + time_dim=tensordict.ndim - 1, + ) + + if self.average_gae: + loc = adv.mean() + scale = adv.std().clamp_min(1e-4) + adv = adv - loc + adv = adv / scale + + tensordict.set(self.tensor_keys.advantage, adv) + tensordict.set(self.tensor_keys.value_target, value_target) + + return tensordict + From 2a693a737d929c1bc052a2a3fd799ba57b4f6178 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 15:28:33 +0200 Subject: [PATCH 002/109] vtrace module --- torchrl/objectives/value/vtrace.py | 52 +++++++++++++++++++----------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 2199dde2ba5..1e37a3a9146 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -32,14 +32,13 @@ def _c_val( def _dv_val( rewards: torch.Tensor, - vals: torch.Tensor, + next_vals: torch.Tensor, gamma: Union[float, torch.Tensor], rho_bar: Union[float, torch.Tensor], log_pi: torch.Tensor, log_mu: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: rho = _c_val(log_pi, log_mu, rho_bar) - next_vals = torch.cat([vals[:, 1:], torch.zeros_like(vals[:, :1])], 1) deltas = rho * (rewards + gamma * next_vals - vals) return deltas, rho @@ -74,11 +73,14 @@ def _vtrace( @_transpose_time def 
vtrace_correction( gamma: float, - lmbda: float, + log_pi: torch.Tensor, + log_mu: torch.Tensor, state_value: torch.Tensor, next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + rho_bar: Union[float, torch.Tensor] = 1.0, + c_bar: Union[float, torch.Tensor] = 1.0, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """V-Trace off-policy correction method. @@ -88,7 +90,6 @@ def vtrace_correction( Args: gamma (scalar): exponential mean discount. - lmbda (scalar): trajectory discount. log_pi (Tensor): log probability of taking actions in the environment. log_mu (Tensor): log probability of taking actions in the environment. state_value (Tensor): value function result with old_state input. @@ -108,20 +109,23 @@ def vtrace_correction( dtype = next_state_value.dtype device = state_value.device + import ipdb; ipdb.set_trace() + delta, rho = _dv_val(reward, next_state_value, gamma, rho_bar, log_pi, log_mu) + c = _c_val(log_pi, log_mu, c_bar) + not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape - advantage = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) - - prev_advantage = 0 + acc = 0 + v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) gnotdone = gamma * not_done - delta = reward + (gnotdone * next_state_value) - state_value - discount = lmbda * gnotdone for t in reversed(range(time_steps)): - prev_advantage = advantage[..., t, :] = delta[..., t, :] + (prev_advantage * discount[..., t, :]) + import ipdb; ipdb.set_trace() # TODO: Review! + acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * c) + v_out[..., t, :].copy_(acc + state_value[..., t, :]) - value_target = advantage + state_value + advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! - return advantage, value_target + return advantage, v_out class VTrace(ValueEstimatorBase): @@ -132,7 +136,6 @@ class VTrace(ValueEstimatorBase): Args: gamma (scalar): exponential mean discount. - lmbda (scalar): trajectory discount. value_network (TensorDictModule): value operator used to retrieve the value estimates. average_gae (bool): if ``True``, the resulting GAE values will be standardized. Default is ``False``. @@ -186,9 +189,9 @@ def __init__( self, *, gamma: Union[float, torch.Tensor], - lmbda: float, rho_bar: Union[float, torch.Tensor] = 1.0, c_bar: Union[float, torch.Tensor] = 1.0, + actor_network: TensorDictModule = None, value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, @@ -213,11 +216,11 @@ def __init__( except (AttributeError, StopIteration): device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("lmbda", torch.tensor(lmbda, device=device)) self.register_buffer("rho_bar", torch.tensor(rho_bar, device=device)) self.register_buffer("c_bar", torch.tensor(c_bar, device=device)) self.average_gae = average_gae self.vectorized = vectorized + self.actor_network = actor_network @_self_set_skip_existing @_self_set_grad_enabled @@ -255,7 +258,6 @@ def forward( ... ) >>> module = VTrace( ... gamma=0.98, - ... lmbda=0.94, ... value_network=value_net, ... differentiable=False, ... ) @@ -274,7 +276,6 @@ def forward( ... ) >>> module = VTrace( ... gamma=0.98, - ... lmbda=0.94, ... value_network=value_net, ... differentiable=False, ... 
) @@ -291,11 +292,12 @@ def forward( ) reward = tensordict.get(("next", self.tensor_keys.reward)) device = reward.device - gamma, lmbda = self.gamma.to(device), self.lmbda.to(device) + gamma = self.gamma.to(device) steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) if steps_to_next_obs is not None: gamma = gamma ** steps_to_next_obs.view_as(reward) + # Make sure we have the value and next value if self.value_network is not None: if params is not None: params = params.detach() @@ -317,17 +319,29 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) + # Make sure we have the new log prob and the old log prob + if self.actor_network is not None: + import ipdb; ipdb.set_trace() + raise NotImplementedError + else: + import ipdb; ipdb.set_trace() + log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob + log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob + done = tensordict.get(("next", self.tensor_keys.done)) if self.vectorized: raise NotImplementedError else: adv, value_target = vtrace_correction( gamma, - lmbda, + log_pi, + log_mu, value, next_value, reward, done, + rho_bar=self.rho_bar, + c_bar=self.c_bar, time_dim=tensordict.ndim - 1, ) From 7a8ee3813a5983bb39227445bdc4d146c4c3c943 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 15:36:34 +0200 Subject: [PATCH 003/109] vtrace module --- examples/impala/config.yaml | 41 +++++++++++++++--------------- examples/impala/impala.py | 4 +-- torchrl/objectives/value/vtrace.py | 4 +-- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 1bee59c3adb..c9820adc4d3 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -1,35 +1,34 @@ # Environment env: -env_name: PongNoFrameskip-v4 + env_name: PongNoFrameskip-v4 # collector collector: -frames_per_batch: 4096 -total_frames: 40_000_000 + frames_per_batch: 4096 + total_frames: 40_000_000 # logger logger: -backend: wandb -exp_name: Atari_Schulman17 -test_interval: 40_000_000 -num_test_episodes: 3 + backend: csv + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 # Optim optim: -lr: 2.5e-4 -eps: 1.0e-6 -weight_decay: 0.0 -max_grad_norm: 0.5 -anneal_lr: True + lr: 2.5e-4 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True # loss loss: -gamma: 0.99 -mini_batch_size: 1024 -ppo_epochs: 3 -gae_lambda: 0.95 -clip_epsilon: 0.1 -anneal_clip_epsilon: True -critic_coef: 1.0 -entropy_coef: 0.01 -loss_critic_type: l2 \ No newline at end of file + gamma: 0.99 + mini_batch_size: 1024 + ppo_epochs: 3 + clip_epsilon: 0.1 + anneal_clip_epsilon: True + critic_coef: 1.0 + entropy_coef: 0.01 + loss_critic_type: l2 \ No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py index edc666dfe9d..a54af5d493f 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -5,7 +5,7 @@ import hydra -@hydra.main(config_path=".", config_name="config_atari", version_base="1.1") +@hydra.main(config_path=".", config_name="config", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time @@ -73,8 +73,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create loss and adv modules vtrace_module = VTrace( gamma=cfg.loss.gamma, - lmbda=cfg.loss.gae_lambda, value_network=critic, + actor_network=actor, average_gae=False, ) loss_module = ClipPPOLoss( diff --git 
a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 1e37a3a9146..b14360eac25 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -18,8 +18,8 @@ from tensordict.utils import NestedKey from torch import nn, Tensor from torchrl.objectives.utils import hold_out_net -from advantages import ValueEstimatorBase, _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets -from functional import _transpose_time, SHAPE_ERR +from torchrl.objectives.value.advantages import ValueEstimatorBase, _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets +from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR def _c_val( From a6434fd9a3ca9d395f8cf1418368ae9cf2b37b7a Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 15:37:12 +0200 Subject: [PATCH 004/109] vtrace module --- examples/impala/impala.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index a54af5d493f..31b2cee4613 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 vtrace_module = VTrace( gamma=cfg.loss.gamma, value_network=critic, - actor_network=actor, + # actor_network=actor, average_gae=False, ) loss_module = ClipPPOLoss( From 3e7645002f5d74ebdfb89293b40a638738ade00f Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 16:11:06 +0200 Subject: [PATCH 005/109] vtrace module --- torchrl/objectives/value/vtrace.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index b14360eac25..5648f57cbf6 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -27,11 +27,12 @@ def _c_val( log_mu: torch.Tensor, c: Union[float, torch.Tensor] = 1, ) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp() + return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) def _dv_val( rewards: torch.Tensor, + vals: torch.Tensor, next_vals: torch.Tensor, gamma: Union[float, torch.Tensor], rho_bar: Union[float, torch.Tensor], @@ -70,7 +71,7 @@ def _vtrace( v_out = torch.stack(list(reversed(v_out)), 1) # values return v_out, rho -@_transpose_time +# @_transpose_time def vtrace_correction( gamma: float, log_pi: torch.Tensor, @@ -104,23 +105,25 @@ def vtrace_correction( ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. """ + if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): raise RuntimeError(SHAPE_ERR) + dtype = next_state_value.dtype device = state_value.device - import ipdb; ipdb.set_trace() - delta, rho = _dv_val(reward, next_state_value, gamma, rho_bar, log_pi, log_mu) - c = _c_val(log_pi, log_mu, c_bar) + delta, rho = _dv_val(reward, state_value, next_state_value, gamma, rho_bar, log_pi, log_mu) + cs = _c_val(log_pi, log_mu, c_bar) not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape acc = 0 v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) gnotdone = gamma * not_done + import ipdb; ipdb.set_trace() # TODO: Review! for t in reversed(range(time_steps)): import ipdb; ipdb.set_trace() # TODO: Review! 
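The cs and rho terms manipulated in this loop are the clipped importance-sampling weights c_t and rho_t of the IMPALA paper, which _c_val computes from the two log-probabilities (plus a trailing unsqueeze in this revision). A small standalone illustration of that clipping, with made-up numbers and the default threshold of 1.0:

import torch

log_pi = torch.tensor([-0.1, -2.3, -0.7])  # log-prob of the taken action under the learner policy
log_mu = torch.tensor([-0.9, -0.2, -0.7])  # log-prob of the same action under the behaviour policy
# Same quantity as (log_pi - log_mu).clamp_max(math.log(1.0)).exp()
weights = (log_pi - log_mu).exp().clamp(max=1.0)
print(weights)  # tensor([1.0000, 0.1225, 1.0000]): ratios above the threshold are clipped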
- acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * c) + acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[t]) v_out[..., t, :].copy_(acc + state_value[..., t, :]) advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! @@ -195,7 +198,7 @@ def __init__( value_network: TensorDictModule, average_gae: bool = False, differentiable: bool = False, - vectorized: bool = True, + vectorized: bool = False, # TODO: Review! skip_existing: Optional[bool] = None, advantage_key: NestedKey = None, value_target_key: NestedKey = None, @@ -324,9 +327,10 @@ def forward( import ipdb; ipdb.set_trace() raise NotImplementedError else: - import ipdb; ipdb.set_trace() - log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob - log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob + # log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob + log_pi = tensordict.get("sample_log_prob") + # log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob + log_mu = tensordict.get("sample_log_prob") done = tensordict.get(("next", self.tensor_keys.done)) if self.vectorized: From 7613cb991366235a2285fe0613d4298c9cc9e5f5 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 16:42:26 +0200 Subject: [PATCH 006/109] vtrace module --- examples/impala/impala.py | 3 +-- torchrl/objectives/value/vtrace.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 31b2cee4613..25615e5b270 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821 vtrace_module = VTrace( gamma=cfg.loss.gamma, value_network=critic, - # actor_network=actor, + actor_network=actor, average_gae=False, ) loss_module = ClipPPOLoss( @@ -135,7 +135,6 @@ def main(cfg: "DictConfig"): # noqa: F821 # Compute VTrace with torch.no_grad(): - import ipdb; ipdb.set_trace() data = vtrace_module(data) data_reshape = data.reshape(-1) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 5648f57cbf6..04fbb8e7b30 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -120,13 +120,13 @@ def vtrace_correction( acc = 0 v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) gnotdone = gamma * not_done - import ipdb; ipdb.set_trace() # TODO: Review! + # TODO: Review! for t in reversed(range(time_steps)): - import ipdb; ipdb.set_trace() # TODO: Review! - acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[t]) + # TODO: Review! + acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[..., t, :]) v_out[..., t, :].copy_(acc + state_value[..., t, :]) - advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! + advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! 
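Since both the recursion and the advantage above are flagged with # TODO: Review!, a plain single-trajectory reference of the V-trace targets and advantages, following the IMPALA paper (Espeholt et al., 2018), may be useful for comparison. This is an independent sketch assuming [T]-shaped tensors and a boolean done mask, not the module's API:

import torch

def vtrace_reference(log_pi, log_mu, value, next_value, reward, done,
                     gamma=0.99, rho_bar=1.0, c_bar=1.0):
    # Clipped importance weights rho_t and c_t.
    ratio = (log_pi - log_mu).exp()
    rho = ratio.clamp(max=rho_bar)
    c = ratio.clamp(max=c_bar)
    # delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t)), with gamma zeroed after terminal steps.
    discount = gamma * (~done).float()
    delta = rho * (reward + discount * next_value - value)
    # Backward recursion: v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})).
    acc = torch.zeros(())
    vs = torch.zeros_like(value)
    for t in reversed(range(value.shape[0])):
        acc = delta[t] + discount[t] * c[t] * acc
        vs[t] = acc + value[t]
    # The policy-gradient advantage bootstraps with v_{s+1} (V(x_T) at the last step).
    vs_next = torch.cat([vs[1:], next_value[-1:]])
    advantage = rho * (reward + discount * vs_next - value)
    return advantage, vs

Note in particular that the advantage is built from vs_next, i.e. the V-trace target one step ahead, rather than from vs at the same index.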
return advantage, v_out @@ -322,15 +322,15 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) + # TODO: raise ValueError if log_mu is not present + log_mu = tensordict.get("sample_log_prob") + # Make sure we have the new log prob and the old log prob if self.actor_network is not None: - import ipdb; ipdb.set_trace() - raise NotImplementedError + # TODO: review + log_pi = self.actor_network(tensordict.select(self.actor_network.in_keys)).get("sample_log_prob") # old / distributed log prob else: - # log_pi = tensordict.get(self.tensor_keys.log_pi) # new / local log prob - log_pi = tensordict.get("sample_log_prob") - # log_mu = tensordict.get(self.tensor_keys.log_mu) # old / distributed log prob - log_mu = tensordict.get("sample_log_prob") + log_pi = tensordict.get("sample_log_prob") # new / local log prob # TODO: Review! done = tensordict.get(("next", self.tensor_keys.done)) if self.vectorized: From 8c70b0a5931a873f454a35c12a3750902370c759 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 17:50:59 +0200 Subject: [PATCH 007/109] vtrace module --- examples/impala/impala.py | 2 +- torchrl/objectives/value/vtrace.py | 119 +++++++++++++++-------------- 2 files changed, 62 insertions(+), 59 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 25615e5b270..de4150a459b 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -75,7 +75,7 @@ def main(cfg: "DictConfig"): # noqa: F821 gamma=cfg.loss.gamma, value_network=critic, actor_network=actor, - average_gae=False, + average_adv=False, ) loss_module = ClipPPOLoss( actor=actor, diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 04fbb8e7b30..de5bf1f78fb 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -18,8 +18,12 @@ from tensordict.utils import NestedKey from torch import nn, Tensor from torchrl.objectives.utils import hold_out_net -from torchrl.objectives.value.advantages import ValueEstimatorBase, _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets -from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR +from torchrl.objectives.value.advantages import ( + ValueEstimatorBase, + _self_set_skip_existing, + _self_set_grad_enabled, + _call_value_nets) +from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR, td0_return_estimate def _c_val( @@ -27,8 +31,7 @@ def _c_val( log_mu: torch.Tensor, c: Union[float, torch.Tensor] = 1, ) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) - + return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: Review! def _dv_val( rewards: torch.Tensor, @@ -80,8 +83,8 @@ def vtrace_correction( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, - rho_bar: Union[float, torch.Tensor] = 1.0, - c_bar: Union[float, torch.Tensor] = 1.0, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """V-Trace off-policy correction method. 
@@ -112,8 +115,10 @@ def vtrace_correction( dtype = next_state_value.dtype device = state_value.device - delta, rho = _dv_val(reward, state_value, next_state_value, gamma, rho_bar, log_pi, log_mu) - cs = _c_val(log_pi, log_mu, c_bar) + delta, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) + torch.clamp(torch.exp(log_rhos), max=clip_c_thres) + + clipped_cs = _c_val(log_pi, log_mu, c_thresh) not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape @@ -124,6 +129,8 @@ def vtrace_correction( for t in reversed(range(time_steps)): # TODO: Review! acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[..., t, :]) + + v_out[..., t, :].copy_(acc + state_value[..., t, :]) advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! @@ -140,7 +147,8 @@ class VTrace(ValueEstimatorBase): Args: gamma (scalar): exponential mean discount. value_network (TensorDictModule): value operator used to retrieve the value estimates. - average_gae (bool): if ``True``, the resulting GAE values will be standardized. + actor_network (TensorDictModule, optional): actor operator used to retrieve the log prob. + average_adv (bool): if ``True``, the resulting advantage values will be standardized. Default is ``False``. differentiable (bool, optional): if ``True``, gradients are propagated through the computation of the value function. Default is ``False``. @@ -149,10 +157,6 @@ class VTrace(ValueEstimatorBase): The proper way to make the function call non-differentiable is to decorate it in a `torch.no_grad()` context manager/decorator or pass detached parameters for functional modules. - - # vectorized (bool, optional): whether to use the vectorized version of the - # lambda return. Default is `True`. - skip_existing (bool, optional): if ``True``, the value network will skip modules which outputs are already present in the tensordict. Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` @@ -164,27 +168,21 @@ class VTrace(ValueEstimatorBase): of the advantage entry. Defaults to ``"value_target"``. value_key (str or tuple of str, optional): [Deprecated] the value key to read from the input tensordict. Defaults to ``"state_value"``. - - # shifted (bool, optional): if ``True``, the value and next value are - # estimated with a single call to the value network. This is faster - # but is only valid whenever (1) the ``"next"`` value is shifted by - # only one time step (which is not the case with multi-step value - # estimation, for instance) and (2) when the parameters used at time - # ``t`` and ``t+1`` are identical (which is not the case when target - # parameters are to be used). Defaults to ``False``. + shifted (bool, optional): if ``True``, the value and next value are + estimated with a single call to the value network. This is faster + but is only valid whenever (1) the ``"next"`` value is shifted by + only one time step (which is not the case with multi-step value + estimation, for instance) and (2) when the parameters used at time + ``t`` and ``t+1`` are identical (which is not the case when target + parameters are to be used). Defaults to ``False``. VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also return a :obj:`"value_target"` entry with the V-Trace target value. 
- # Finally, if :obj:`gradient_mode` is ``True``, - # an additional and differentiable :obj:`"value_error"` entry will be returned, - # which simple represents the difference between the return and the value network - # output (i.e. an additional distance loss should be applied to that signed value). - - # .. note:: - # As other advantage functions do, if the ``value_key`` is already present - # in the input tensordict, the VTrace module will ignore the calls to the value - # network (if any) and use the provided value instead. + .. note:: + As other advantage functions do, if the ``value_key`` is already present + in the input tensordict, the VTrace module will ignore the calls to the value + network (if any) and use the provided value instead. """ @@ -196,10 +194,10 @@ def __init__( c_bar: Union[float, torch.Tensor] = 1.0, actor_network: TensorDictModule = None, value_network: TensorDictModule, - average_gae: bool = False, + average_adv: bool = False, differentiable: bool = False, - vectorized: bool = False, # TODO: Review! skip_existing: Optional[bool] = None, + log_prob_key: NestedKey = "sample_log_prob", # TODO: should be added to _AcceptedKeys? advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, @@ -221,9 +219,13 @@ def __init__( self.register_buffer("gamma", torch.tensor(gamma, device=device)) self.register_buffer("rho_bar", torch.tensor(rho_bar, device=device)) self.register_buffer("c_bar", torch.tensor(c_bar, device=device)) - self.average_gae = average_gae - self.vectorized = vectorized + self.average_adv = average_adv self.actor_network = actor_network + self._log_prob_key = log_prob_key + + @property + def log_prob_key(self): + return self._log_prob_key @_self_set_skip_existing @_self_set_grad_enabled @@ -322,34 +324,35 @@ def forward( value = tensordict.get(self.tensor_keys.value) next_value = tensordict.get(("next", self.tensor_keys.value)) - # TODO: raise ValueError if log_mu is not present - log_mu = tensordict.get("sample_log_prob") + # Make sure we have the log prob computed at collection time + if self.log_prob_key not in tensordict.keys(): + raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") + log_mu = tensordict.get(self.log_prob_key) - # Make sure we have the new log prob and the old log prob - if self.actor_network is not None: - # TODO: review - log_pi = self.actor_network(tensordict.select(self.actor_network.in_keys)).get("sample_log_prob") # old / distributed log prob - else: - log_pi = tensordict.get("sample_log_prob") # new / local log prob # TODO: Review! + # Compute the current log prob + with hold_out_net(self.actor_network): + log_pi = self.actor_network( + tensordict.select(self.actor_network.in_keys) + ).get(self.log_prob_key) + # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) - if self.vectorized: - raise NotImplementedError - else: - adv, value_target = vtrace_correction( - gamma, - log_pi, - log_mu, - value, - next_value, - reward, - done, - rho_bar=self.rho_bar, - c_bar=self.c_bar, - time_dim=tensordict.ndim - 1, - ) + adv, value_target = vtrace_correction( + gamma, + log_pi, + log_mu, + value, + next_value, + reward, + done, + rho_bar=self.rho_bar, + c_bar=self.c_bar, + time_dim=tensordict.ndim - 1, + ) + + # TODO: where are returns computed? 
- if self.average_gae: + if self.average_adv: loc = adv.mean() scale = adv.std().clamp_min(1e-4) adv = adv - loc From e8dd5be5530dd05b05a815fa8c4ddf7e641c5524 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 6 Sep 2023 18:18:42 +0200 Subject: [PATCH 008/109] vtrace module --- torchrl/objectives/value/vtrace.py | 32 ++++++++++++++++-------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index de5bf1f78fb..fd1fee35af2 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -38,13 +38,13 @@ def _dv_val( vals: torch.Tensor, next_vals: torch.Tensor, gamma: Union[float, torch.Tensor], - rho_bar: Union[float, torch.Tensor], + rho_thresh: Union[float, torch.Tensor], log_pi: torch.Tensor, log_mu: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - rho = _c_val(log_pi, log_mu, rho_bar) - deltas = rho * (rewards + gamma * next_vals - vals) - return deltas, rho + clipped_rho = _c_val(log_pi, log_mu, rho_thresh) + deltas = clipped_rho * (rewards + gamma * next_vals - vals) + return deltas, clipped_rho def _vtrace( @@ -100,8 +100,8 @@ def vtrace_correction( next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. done (Tensor): boolean flag for end of episode. - rho_bar (Union[float, Tensor]): clipping parameter for importance weights. - c_bar (Union[float, Tensor]): clipping parameter for importance weights. + rho_thresh (Union[float, Tensor]): clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): clipping parameter for importance weights. time_dim (int): dimension where the time is unrolled. Defaults to -2. All tensors (values, reward and done) must have shape @@ -116,24 +116,28 @@ def vtrace_correction( device = state_value.device delta, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) - torch.clamp(torch.exp(log_rhos), max=clip_c_thres) + clipped_c = _c_val(log_pi, log_mu, c_thresh) - clipped_cs = _c_val(log_pi, log_mu, c_thresh) + ############################################################ + # FIX THIS PART! not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape acc = 0 v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) - gnotdone = gamma * not_done - # TODO: Review! + + discounts = gamma * not_done # TODO: Review! for t in reversed(range(time_steps)): # TODO: Review! - acc = delta[..., t, :] + (gnotdone[..., t, :] * acc * cs[..., t, :]) + acc = delta[..., t, :] + (discounts[..., t, :] * acc * clipped_c[..., t, :]) + v_out.append(acc) + v_out[..., t, :].copy_(acc + state_value[..., t, :]) - v_out[..., t, :].copy_(acc + state_value[..., t, :]) + # FIX THIS PART! + ############################################################ - advantage = rho * (reward + gamma * v_out - state_value) # TODO: Review! + advantage = clipped_rho * (reward + gamma * v_out - state_value) return advantage, v_out @@ -350,8 +354,6 @@ def forward( time_dim=tensordict.ndim - 1, ) - # TODO: where are returns computed? 
- if self.average_adv: loc = adv.mean() scale = adv.std().clamp_min(1e-4) From 3ace31f6204402ebd4ecfc12d86a486881b0be21 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 09:59:40 +0200 Subject: [PATCH 009/109] vtrace module --- examples/impala/config.yaml | 3 - examples/impala/impala.py | 95 ++++++++++++++---------------- torchrl/objectives/value/vtrace.py | 74 +++++++---------------- 3 files changed, 66 insertions(+), 106 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index c9820adc4d3..00780ee7213 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -26,9 +26,6 @@ optim: loss: gamma: 0.99 mini_batch_size: 1024 - ppo_epochs: 3 - clip_epsilon: 0.1 - anneal_clip_epsilon: True critic_coef: 1.0 entropy_coef: 0.01 loss_critic_type: l2 \ No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py index de4150a459b..3d79ecfb5d1 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import ClipPPOLoss + from torchrl.objectives import A2CLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.objectives.value.vtrace import VTrace from utils import make_parallel_env, make_ppo_models @@ -75,16 +75,14 @@ def main(cfg: "DictConfig"): # noqa: F821 gamma=cfg.loss.gamma, value_network=critic, actor_network=actor, - average_adv=False, + average_adv=True, ) - loss_module = ClipPPOLoss( + loss_module = A2CLoss( actor=actor, critic=critic, - clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, - normalize_advantage=True, ) # Create optimizer @@ -109,9 +107,7 @@ def main(cfg: "DictConfig"): # noqa: F821 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches - ) + total_network_updates = (total_frames // frames_per_batch) * num_mini_batches for data in collector: @@ -130,49 +126,46 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) - for j in range(cfg.loss.ppo_epochs): - - # Compute VTrace - with torch.no_grad(): - data = vtrace_module(data) - data_reshape = data.reshape(-1) - - # Update the data buffer - data_buffer.extend(data_reshape) - - for i, batch in enumerate(data_buffer): - - # Linearly decrease the learning rate and clip epsilon - alpha = 1 - (num_network_updates / total_network_updates) - if cfg.optim.anneal_lr: - for g in optim.param_groups: - g["lr"] = cfg.optim.lr * alpha - if cfg.loss.anneal_clip_epsilon: - loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) - num_network_updates += 1 - - # Get a data batch - batch = batch.to(device) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, i] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - 
torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() + losses = TensorDict({}, batch_size=[num_mini_batches]) + + # Compute VTrace + with torch.no_grad(): + data = vtrace_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for i, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1 - (num_network_updates / total_network_updates) + if cfg.optim.anneal_lr: + for g in optim.param_groups: + g["lr"] = cfg.optim.lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[i] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index fd1fee35af2..1ca45d2635f 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -47,34 +47,7 @@ def _dv_val( return deltas, clipped_rho -def _vtrace( - rewards: torch.Tensor, - vals: torch.Tensor, - log_pi: torch.Tensor, - log_mu: torch.Tensor, - gamma: Union[torch.Tensor, float], - rho_bar: Union[float, torch.Tensor] = 1.0, - c_bar: Union[float, torch.Tensor] = 1.0, -) -> Tuple[torch.Tensor, torch.Tensor]: - - T = vals.shape[1] - if not isinstance(gamma, torch.Tensor): - gamma = torch.full_like(vals, gamma) - - dv, rho = _dv_val(rewards, vals, gamma, rho_bar, log_pi, log_mu) - c = _c_val(log_pi, log_mu, c_bar) - - v_out = [] - v_out.append(vals[:, -1] + dv[:, -1]) - for t in range(T - 2, -1, -1): - _v_out = ( - vals[:, t] + dv[:, t] + gamma[:, t] * c[:, t] * (v_out[-1] - vals[:, t + 1]) - ) - v_out.append(_v_out) - v_out = torch.stack(list(reversed(v_out)), 1) # values - return v_out, rho - -# @_transpose_time +# @_transpose_time # TODO: is this needed? def vtrace_correction( gamma: float, log_pi: torch.Tensor, @@ -87,7 +60,7 @@ def vtrace_correction( c_thresh: Union[float, torch.Tensor] = 1.0, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: - """V-Trace off-policy correction method. + """Computes V-Trace off-policy actor critic targets. Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" https://arxiv.org/abs/1802.01561 for more context. @@ -115,31 +88,28 @@ def vtrace_correction( dtype = next_state_value.dtype device = state_value.device - delta, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) + deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) clipped_c = _c_val(log_pi, log_mu, c_thresh) ############################################################ - # FIX THIS PART! 
+ # MAKE THIS PART WORK; THEN WE CAN TRY TO MAKE IT FASTER not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape - acc = 0 - v_out = torch.empty(*batch_size, time_steps, lastdim, device=device, dtype=dtype) + discounts = gamma * not_done + vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] + for i in reversed(range(time_steps)): + discount_t, c_t, delta_t = discounts[..., i, :], clipped_c[..., i, :], deltas[..., i, :] + vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1]) + vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim) + vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) + vs = vs_minus_v_xs + state_value + vs_t_plus_1 = torch.cat([vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim) + advantages = clipped_rho * (reward + gamma * vs_t_plus_1 - state_value) - discounts = gamma * not_done # TODO: Review! - for t in reversed(range(time_steps)): - # TODO: Review! - acc = delta[..., t, :] + (discounts[..., t, :] * acc * clipped_c[..., t, :]) - v_out.append(acc) - - v_out[..., t, :].copy_(acc + state_value[..., t, :]) - - # FIX THIS PART! ############################################################ - advantage = clipped_rho * (reward + gamma * v_out - state_value) - - return advantage, v_out + return advantages, vs class VTrace(ValueEstimatorBase): @@ -194,8 +164,8 @@ def __init__( self, *, gamma: Union[float, torch.Tensor], - rho_bar: Union[float, torch.Tensor] = 1.0, - c_bar: Union[float, torch.Tensor] = 1.0, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, actor_network: TensorDictModule = None, value_network: TensorDictModule, average_adv: bool = False, @@ -221,8 +191,8 @@ def __init__( except (AttributeError, StopIteration): device = torch.device("cpu") self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("rho_bar", torch.tensor(rho_bar, device=device)) - self.register_buffer("c_bar", torch.tensor(c_bar, device=device)) + self.register_buffer("rho_thresh", torch.tensor(rho_thresh, device=device)) + self.register_buffer("c_thresh", torch.tensor(c_thresh, device=device)) self.average_adv = average_adv self.actor_network = actor_network self._log_prob_key = log_prob_key @@ -333,7 +303,7 @@ def forward( raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") log_mu = tensordict.get(self.log_prob_key) - # Compute the current log prob + # Compute log prob with current policy with hold_out_net(self.actor_network): log_pi = self.actor_network( tensordict.select(self.actor_network.in_keys) @@ -349,8 +319,8 @@ def forward( next_value, reward, done, - rho_bar=self.rho_bar, - c_bar=self.c_bar, + rho_thresh=self.rho_thresh, + c_thresh=self.c_thresh, time_dim=tensordict.ndim - 1, ) From 3efe6019ffdd3340fb19fb91b9c71f00ed3dcbe9 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 13:12:03 +0200 Subject: [PATCH 010/109] impala --- examples/impala/config.yaml | 12 ++++---- examples/impala/impala.py | 49 +++++++++++++++++------------- examples/impala/utils.py | 2 +- torchrl/objectives/value/vtrace.py | 6 ++-- 4 files changed, 38 insertions(+), 31 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 00780ee7213..62a30fce19f 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,28 +4,28 @@ env: # collector collector: - frames_per_batch: 4096 + frames_per_batch: 128 total_frames: 40_000_000 # logger logger: backend: csv - exp_name: Atari_Schulman17 + exp_name: 
Atari_Espeholt18 test_interval: 40_000_000 num_test_episodes: 3 # Optim optim: - lr: 2.5e-4 + lr: 0.0006 eps: 1.0e-6 weight_decay: 0.0 - max_grad_norm: 0.5 + max_grad_norm: 40.0 anneal_lr: True # loss loss: gamma: 0.99 - mini_batch_size: 1024 + mini_batch_size: 128 critic_coef: 1.0 - entropy_coef: 0.01 + entropy_coef: 0.025 loss_critic_type: l2 \ No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 3d79ecfb5d1..ca3711ebe41 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -1,6 +1,6 @@ """ -This script reproduces the Proximal Policy Optimization (PPO) Algorithm -results from Schulman et al. 2017 for the on Atari Environments. +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. """ import hydra @@ -15,12 +15,12 @@ def main(cfg: "DictConfig"): # noqa: F821 import tqdm from tensordict import TensorDict - from torchrl.collectors import SyncDataCollector + from torchrl.collectors import MultiaSyncDataCollector from torchrl.collectors.distributed import RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss + from torchrl.objectives import A2CLoss, ClipPPOLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.objectives.value.vtrace import VTrace from utils import make_parallel_env, make_ppo_models @@ -52,8 +52,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # max_frames_per_traj=-1, # sync=False, # ) - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, device), + collector = MultiaSyncDataCollector( + create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 4, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -86,16 +86,22 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizer - optim = torch.optim.Adam( - loss_module.parameters(), + optim_actor = torch.optim.Adam( + actor.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + optim_critic = torch.optim.Adam( + critic.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, ) # Create logger - exp_name = generate_exp_name("PPO", f"{cfg.logger.exp_name}_{cfg.env.env_name}") - logger = get_logger(cfg.logger.backend, logger_name="ppo", experiment_name=exp_name) + exp_name = generate_exp_name("IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="impala", experiment_name=exp_name) # Create test environment test_env = make_parallel_env(cfg.env.env_name, device, is_test=True) @@ -141,39 +147,40 @@ def main(cfg: "DictConfig"): # noqa: F821 # Linearly decrease the learning rate and clip epsilon alpha = 1 - (num_network_updates / total_network_updates) if cfg.optim.anneal_lr: - for g in optim.param_groups: + for g in optim_actor.param_groups: + g["lr"] = cfg.optim.lr * alpha + for g in optim_critic.param_groups: g["lr"] = cfg.optim.lr * alpha num_network_updates += 1 # Get a data batch batch = batch.to(device) - # Forward pass PPO loss + # Forward pass A2C loss loss = loss_module(batch) losses[i] = loss.select( "loss_critic", "loss_entropy", "loss_objective" ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) + loss_actor = loss["loss_objective"] + 
loss["loss_entropy"] + loss_critic = loss["loss_critic"] # Backward pass - loss_sum.backward() + loss_actor.backward() + loss_critic.backward() torch.nn.utils.clip_grad_norm_( list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm ) # Update the networks - optim.step() - optim.zero_grad() + optim_actor.step() + optim_critic.step() + optim_actor.zero_grad() + optim_critic.zero_grad() losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): logger.log_scalar(key, value.item(), collected_frames) logger.log_scalar("lr", alpha * cfg.optim.lr, collected_frames) - logger.log_scalar( - "clip_epsilon", alpha * cfg.loss.clip_epsilon, collected_frames - ) # Test logging with torch.no_grad(), set_exploration_type(ExplorationType.MODE): diff --git a/examples/impala/utils.py b/examples/impala/utils.py index e9112ead762..04d69d2abf1 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -102,7 +102,7 @@ def make_base_env( def make_parallel_env(env_name, device, is_test=False): - num_envs = 8 + num_envs = 1 env = ParallelEnv( num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 1ca45d2635f..5316bd4f471 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -23,7 +23,7 @@ _self_set_skip_existing, _self_set_grad_enabled, _call_value_nets) -from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR, td0_return_estimate +from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR def _c_val( @@ -31,7 +31,7 @@ def _c_val( log_mu: torch.Tensor, c: Union[float, torch.Tensor] = 1, ) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: Review! + return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: is unsqueeze needed? 
def _dv_val( rewards: torch.Tensor, @@ -89,7 +89,7 @@ def vtrace_correction( device = state_value.device deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) - clipped_c = _c_val(log_pi, log_mu, c_thresh) + clipped_c = torch.min(torch.tensor(c_thresh).to(device), clipped_rho) ############################################################ # MAKE THIS PART WORK; THEN WE CAN TRY TO MAKE IT FASTER From 245892739f73f86e8dc711b41af857e2301b5296 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 13:45:15 +0200 Subject: [PATCH 011/109] impala example --- examples/impala/config.yaml | 6 +++--- examples/impala/impala.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 62a30fce19f..466ecc86646 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 128 + frames_per_batch: 256 total_frames: 40_000_000 # logger @@ -19,13 +19,13 @@ optim: lr: 0.0006 eps: 1.0e-6 weight_decay: 0.0 - max_grad_norm: 40.0 + max_grad_norm: 1.0 anneal_lr: True # loss loss: gamma: 0.99 - mini_batch_size: 128 + mini_batch_size: 256 critic_coef: 1.0 entropy_coef: 0.025 loss_critic_type: l2 \ No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py index ca3711ebe41..c6563519b34 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -53,7 +53,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # sync=False, # ) collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 4, + create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 12, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -112,11 +112,17 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + total_network_updates = (total_frames // frames_per_batch) + acc_batch = [] for data in collector: + if len(acc_batch) < (cfg.collector.frames_per_batch // cfg.loss.mini_batch_size): + acc_batch.append(data) + continue + + data = torch.cat(acc_batch, dim=-1) + frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -132,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - losses = TensorDict({}, batch_size=[num_mini_batches]) + losses = TensorDict({}, batch_size=[1]) # Compute VTrace with torch.no_grad(): From c6a60cc30f577d8f139c206483884abca56bacee Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 15:52:43 +0200 Subject: [PATCH 012/109] impala example --- examples/impala/config.yaml | 4 ++-- examples/impala/impala.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 466ecc86646..b061153b5df 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 256 + frames_per_batch: 256 # 50 total_frames: 40_000_000 # logger @@ -25,7 +25,7 @@ optim: # loss loss: gamma: 0.99 - mini_batch_size: 256 + mini_batch_size: 256 # 500 critic_coef: 1.0 entropy_coef: 0.025 loss_critic_type: l2 \ 
No newline at end of file diff --git a/examples/impala/impala.py b/examples/impala/impala.py index c6563519b34..3847904b986 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -53,7 +53,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # sync=False, # ) collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 12, + create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 8, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -112,16 +112,16 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = (total_frames // frames_per_batch) + num_mini_batches = frames_per_batch // mini_batch_size if mini_batch_size < frames_per_batch else 1 + total_network_updates = (total_frames // max(frames_per_batch, mini_batch_size)) * num_mini_batches acc_batch = [] for data in collector: - if len(acc_batch) < (cfg.collector.frames_per_batch // cfg.loss.mini_batch_size): + if len(acc_batch) < (cfg.loss.mini_batch_size // cfg.collector.frames_per_batch): acc_batch.append(data) continue - - data = torch.cat(acc_batch, dim=-1) + data = torch.cat(acc_batch, dim=0) frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip @@ -138,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - losses = TensorDict({}, batch_size=[1]) + losses = TensorDict({}, batch_size=[num_mini_batches]) # Compute VTrace with torch.no_grad(): @@ -217,4 +217,4 @@ def main(cfg: "DictConfig"): # noqa: F821 if __name__ == "__main__": - main() + main() \ No newline at end of file From f1577663cef099efaa0fefb4791276b2cdd20428 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 15:53:58 +0200 Subject: [PATCH 013/109] impala example --- examples/impala/config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index b061153b5df..6fcf7c94169 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 256 # 50 + frames_per_batch: 128 total_frames: 40_000_000 # logger @@ -19,13 +19,13 @@ optim: lr: 0.0006 eps: 1.0e-6 weight_decay: 0.0 - max_grad_norm: 1.0 + max_grad_norm: 10.0 anneal_lr: True # loss loss: gamma: 0.99 - mini_batch_size: 256 # 500 + mini_batch_size: 128 critic_coef: 1.0 entropy_coef: 0.025 loss_critic_type: l2 \ No newline at end of file From f0cf4f542177f71dd23e1c8108aa0e9e105cf411 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 7 Sep 2023 16:19:18 +0200 Subject: [PATCH 014/109] impala example --- examples/impala/config.yaml | 6 +++--- examples/impala/impala.py | 12 +++--------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 6fcf7c94169..c3f73fca006 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -9,7 +9,7 @@ collector: # logger logger: - backend: csv + backend: wandb exp_name: Atari_Espeholt18 test_interval: 40_000_000 num_test_episodes: 3 @@ -19,7 +19,7 @@ optim: lr: 0.0006 eps: 1.0e-6 weight_decay: 0.0 - max_grad_norm: 10.0 + max_grad_norm: 5.0 anneal_lr: True # loss @@ -28,4 +28,4 @@ loss: mini_batch_size: 128 critic_coef: 1.0 entropy_coef: 0.025 - loss_critic_type: l2 \ No newline at end of file + loss_critic_type: l2 diff --git 
a/examples/impala/impala.py b/examples/impala/impala.py index 3847904b986..92373a811f6 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -112,17 +112,11 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size if mini_batch_size < frames_per_batch else 1 - total_network_updates = (total_frames // max(frames_per_batch, mini_batch_size)) * num_mini_batches - acc_batch = [] + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = (total_frames // frames_per_batch) * num_mini_batches for data in collector: - if len(acc_batch) < (cfg.loss.mini_batch_size // cfg.collector.frames_per_batch): - acc_batch.append(data) - continue - data = torch.cat(acc_batch, dim=0) - frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -217,4 +211,4 @@ def main(cfg: "DictConfig"): # noqa: F821 if __name__ == "__main__": - main() \ No newline at end of file + main() From ae6fb62704353281a0ddc865c55ae0d096494733 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 18 Sep 2023 10:52:08 +0200 Subject: [PATCH 015/109] docs clarifications --- torchrl/objectives/value/vtrace.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 5316bd4f471..c3d842ea15d 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -67,8 +67,8 @@ def vtrace_correction( Args: gamma (scalar): exponential mean discount. - log_pi (Tensor): log probability of taking actions in the environment. - log_mu (Tensor): log probability of taking actions in the environment. + log_pi (Tensor): collection actor log probability of taking actions in the environment. + log_mu (Tensor): current actor log probability of taking actions in the environment. state_value (Tensor): value function result with old_state input. next_state_value (Tensor): value function result with new_state input. reward (Tensor): reward of taking actions in the environment. 
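As a compact reference for what the docstring above describes: with truncated weights rho_t = min(rho_thresh, pi/mu) and c_t = min(c_thresh, pi/mu), V-trace forms the corrected TD errors delta_t = rho_t * (r_t + gamma_t * V(x_{t+1}) - V(x_t)), accumulates them backwards as A_t = delta_t + gamma_t * c_t * A_{t+1}, and returns the targets vs_t = V(x_t) + A_t together with the advantages rho_t * (r_t + gamma_t * vs_{t+1} - V(x_t)). A standalone sketch over a single trajectory of assumed shape [T, 1]; the helper name vtrace_reference is ours and independent of the in-progress module code:

    import torch

    def vtrace_reference(gamma, log_pi, log_mu, state_value, next_state_value,
                         reward, done, rho_thresh=1.0, c_thresh=1.0):
        # all inputs are [T, 1]; returns (advantage, value_target), both [T, 1]
        ratio = (log_pi - log_mu).exp()
        rho = ratio.clamp_max(rho_thresh)
        c = ratio.clamp_max(c_thresh)
        discount = gamma * (~done).float()
        delta = rho * (reward + discount * next_state_value - state_value)

        acc = torch.zeros_like(next_state_value[-1])
        vs_minus_v = []
        for t in reversed(range(reward.shape[0])):
            acc = delta[t] + discount[t] * c[t] * acc
            vs_minus_v.append(acc)
        vs = torch.stack(vs_minus_v[::-1]) + state_value

        vs_next = torch.cat([vs[1:], next_state_value[-1:]], dim=0)
        advantage = rho * (reward + discount * vs_next - state_value)
        return advantage, vs

    T = 4
    adv, vs = vtrace_reference(
        gamma=0.99,
        log_pi=torch.zeros(T, 1), log_mu=torch.zeros(T, 1),        # on-policy: ratios are 1
        state_value=torch.zeros(T, 1), next_state_value=torch.zeros(T, 1),
        reward=torch.ones(T, 1), done=torch.zeros(T, 1, dtype=torch.bool),
    )
    # with unit ratios and zero value estimates this reduces to discounted returns:
    # vs is approximately [3.94, 2.97, 1.99, 1.00]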
@@ -197,6 +197,9 @@ def __init__( self.actor_network = actor_network self._log_prob_key = log_prob_key + if not isinstance(gamma, torch.Tensor) and gamma.shape != (): + raise NotImplementedError("Per-value gamma is not supported yet") + @property def log_prob_key(self): return self._log_prob_key From 081c2da553c848333859df286fde0ea216bda30c Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 09:31:31 +0200 Subject: [PATCH 016/109] docs --- examples/impala/config.yaml | 15 ++++++++------- examples/impala/impala.py | 5 +++++ examples/impala/utils.py | 5 +++++ torchrl/objectives/value/vtrace.py | 7 ++++--- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index c3f73fca006..9d67995938f 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,7 +4,7 @@ env: # collector collector: - frames_per_batch: 128 + frames_per_batch: 80 total_frames: 40_000_000 # logger @@ -16,16 +16,17 @@ logger: # Optim optim: - lr: 0.0006 - eps: 1.0e-6 + lr: 0.0001 + eps: 1.0e-8 weight_decay: 0.0 - max_grad_norm: 5.0 + max_grad_norm: 40.0 anneal_lr: True # loss loss: gamma: 0.99 - mini_batch_size: 128 - critic_coef: 1.0 - entropy_coef: 0.025 + mini_batch_size: 80 + critic_coef: 0.25 + entropy_coef: 0.01 loss_critic_type: l2 + diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 92373a811f6..1fcc40fcc1a 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + """ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 04d69d2abf1..e0a60ff6c18 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -1,3 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import random import gym diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index c3d842ea15d..ebbfd109904 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -69,8 +69,8 @@ def vtrace_correction( gamma (scalar): exponential mean discount. log_pi (Tensor): collection actor log probability of taking actions in the environment. log_mu (Tensor): current actor log probability of taking actions in the environment. - state_value (Tensor): value function result with old_state input. - next_state_value (Tensor): value function result with new_state input. + state_value (Tensor): value function result with state input. + next_state_value (Tensor): value function result with next_state input. reward (Tensor): reward of taking actions in the environment. done (Tensor): boolean flag for end of episode. rho_thresh (Union[float, Tensor]): clipping parameter for importance weights. 
@@ -89,7 +89,8 @@ def vtrace_correction( device = state_value.device deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) - clipped_c = torch.min(torch.tensor(c_thresh).to(device), clipped_rho) + c_thresh = torch.tensor(c_thresh, device=device) + clipped_c = torch.min(c_thresh, clipped_rho) ############################################################ # MAKE THIS PART WORK; THEN WE CAN TRY TO MAKE IT FASTER From 89a5a9eb9d3034ddec6f663ba233d3cc923b03ea Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 09:42:49 +0200 Subject: [PATCH 017/109] fixes --- examples/impala/impala.py | 21 ++++++++------- examples/impala/utils.py | 42 +++++++----------------------- torchrl/objectives/value/vtrace.py | 4 ++- 3 files changed, 24 insertions(+), 43 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 1fcc40fcc1a..7baeb907bb7 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -105,8 +105,10 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create logger - exp_name = generate_exp_name("IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}") - logger = get_logger(cfg.logger.backend, logger_name="impala", experiment_name=exp_name) + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name("IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}") + logger = get_logger(cfg.logger.backend, logger_name="impala", experiment_name=exp_name) # Create test environment test_env = make_parallel_env(cfg.env.env_name, device, is_test=True) @@ -122,16 +124,15 @@ def main(cfg: "DictConfig"): # noqa: F821 for data in collector: + log_info = None frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) - # Train loging + # Get train reward episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: - logger.log_scalar( - "reward_train", episode_rewards.mean().item(), collected_frames - ) + log_info.update({"reward_train": episode_rewards.mean().item()}) # Apply episodic end of life data["done"].copy_(data["end_of_life"]) @@ -152,10 +153,10 @@ def main(cfg: "DictConfig"): # noqa: F821 # Linearly decrease the learning rate and clip epsilon alpha = 1 - (num_network_updates / total_network_updates) if cfg.optim.anneal_lr: - for g in optim_actor.param_groups: - g["lr"] = cfg.optim.lr * alpha - for g in optim_critic.param_groups: - g["lr"] = cfg.optim.lr * alpha + for group in optim_actor.param_groups: + group["lr"] = cfg.optim.lr * alpha + for group in optim_critic.param_groups: + group["lr"] = cfg.optim.lr * alpha num_network_updates += 1 # Get a data batch diff --git a/examples/impala/utils.py b/examples/impala/utils.py index e0a60ff6c18..9f48782d2f7 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -5,7 +5,7 @@ import random -import gym +import gymnasium as gym import torch.nn import torch.optim from tensordict.nn import TensorDictModule @@ -26,6 +26,7 @@ ToTensorImage, TransformedEnv, VecNorm, + NoopResetEnv, ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( @@ -43,47 +44,22 @@ # -------------------------------------------------------------------- -class NoopResetEnv(gym.Wrapper): - def __init__(self, env, noop_max=30): - """Sample initial states by taking random number of no-ops on reset.""" - gym.Wrapper.__init__(self, env) - self.noop_max = noop_max - self.override_num_noops = None - self.noop_action = 0 # No-op is assumed to be action 0. 
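The switch to gymnasium above changes the environment API used by the wrappers that follow: reset returns (obs, info) and step returns a five-tuple (obs, reward, terminated, truncated, info) instead of the old four-tuple, which is why the life-loss wrapper now unpacks five values. A quick reference using the standard gymnasium API (CartPole is only an example environment):

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated  # the single done flag of the classic gym API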
- assert env.unwrapped.get_action_meanings()[0] == "NOOP" - - def reset(self, **kwargs): - """Do no-op action for a number of steps in [1, noop_max].""" - self.env.reset(**kwargs) - if self.override_num_noops is not None: - noops = self.override_num_noops - else: - noops = random.randint(1, self.noop_max + 1) - assert noops > 0 - obs = None - for _ in range(noops): - obs, _, done, *other = self.env.step(self.noop_action) - if done: - obs = self.env.reset(**kwargs) - return obs - - class EpisodicLifeEnv(gym.Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game over. - Done by DeepMind for the DQN and co. since it helps value estimation. + Done by DeepMind for the DQN and co. It helps value estimation. """ gym.Wrapper.__init__(self, env) self.lives = 0 def step(self, action): - obs, rew, done, info = self.env.step(action) + obs, rew, done, truncate, info = self.env.step(action) lives = self.env.unwrapped.ale.lives() info["end_of_life"] = False if (lives < self.lives) or done: info["end_of_life"] = True self.lives = lives - return obs, rew, done, info + return obs, rew, done, truncate, info def reset(self, **kwargs): reset_data = self.env.reset(**kwargs) @@ -96,13 +72,15 @@ def make_base_env( ): env = gym.make(env_name) if not is_test: - env = NoopResetEnv(env, noop_max=30) env = EpisodicLifeEnv(env) env = GymWrapper( env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device ) - reader = default_info_dict_reader(["end_of_life"]) - env.set_info_dict_reader(reader) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) return env diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index ebbfd109904..1ca561ed24a 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -47,7 +47,7 @@ def _dv_val( return deltas, clipped_rho -# @_transpose_time # TODO: is this needed? 
+@_transpose_time def vtrace_correction( gamma: float, log_pi: torch.Tensor, @@ -198,6 +198,7 @@ def __init__( self.actor_network = actor_network self._log_prob_key = log_prob_key + import ipdb; ipdb.set_trace() if not isinstance(gamma, torch.Tensor) and gamma.shape != (): raise NotImplementedError("Per-value gamma is not supported yet") @@ -315,6 +316,7 @@ def forward( # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) + import ipdb; ipdb.set_trace() adv, value_target = vtrace_correction( gamma, log_pi, From 888fbcbb9d3d2d7e5eef3e4765561d3f2b388d85 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 10:28:57 +0200 Subject: [PATCH 018/109] fixes --- examples/impala/impala.py | 28 +++++++++++++++++++++++++--- torchrl/objectives/value/vtrace.py | 21 ++++++++++++++------- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 7baeb907bb7..96fc11ad0c6 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -122,9 +122,11 @@ def main(cfg: "DictConfig"): # noqa: F821 num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + sampling_start = time.time() for data in collector: log_info = None + sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) @@ -139,6 +141,7 @@ def main(cfg: "DictConfig"): # noqa: F821 data["next", "done"].copy_(data["next", "end_of_life"]) losses = TensorDict({}, batch_size=[num_mini_batches]) + training_start = time.time() # Compute VTrace with torch.no_grad(): @@ -183,10 +186,17 @@ def main(cfg: "DictConfig"): # noqa: F821 optim_actor.zero_grad() optim_critic.zero_grad() + training_time = time.time() - training_start losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): - logger.log_scalar(key, value.item(), collected_frames) - logger.log_scalar("lr", alpha * cfg.optim.lr, collected_frames) + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) # Test logging with torch.no_grad(), set_exploration_type(ExplorationType.MODE): @@ -194,6 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collected_frames // test_interval ): actor.eval() + eval_start = time.time() test_rewards = [] for _ in range(cfg.logger.num_test_episodes): td_test = test_env.rollout( @@ -206,10 +217,21 @@ def main(cfg: "DictConfig"): # noqa: F821 reward = td_test["next", "episode_reward"][td_test["next", "done"]] test_rewards = np.append(test_rewards, reward.cpu().numpy()) del td_test - logger.log_scalar("reward_test", test_rewards.mean(), collected_frames) + eval_time = time.time() - eval_start + log_info.update( + { + "test/reward": test_rewards.mean(), + "test/eval_time": eval_time, + } + ) actor.train() + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + collector.update_policy_weights_() + sampling_start = time.time() end_time = time.time() execution_time = end_time - start_time diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 1ca561ed24a..510f5d8fda5 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -191,16 +191,23 @@ def __init__( device = 
next(value_network.parameters()).device except (AttributeError, StopIteration): device = torch.device("cpu") - self.register_buffer("gamma", torch.tensor(gamma, device=device)) - self.register_buffer("rho_thresh", torch.tensor(rho_thresh, device=device)) - self.register_buffer("c_thresh", torch.tensor(c_thresh, device=device)) + + if not isinstance(gamma, torch.Tensor): + gamma = torch.tensor(gamma, device=device) + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + self.register_buffer("gamma", gamma) + self.register_buffer("rho_thresh", rho_thresh) + self.register_buffer("c_thresh", c_thresh) self.average_adv = average_adv self.actor_network = actor_network self._log_prob_key = log_prob_key - import ipdb; ipdb.set_trace() - if not isinstance(gamma, torch.Tensor) and gamma.shape != (): - raise NotImplementedError("Per-value gamma is not supported yet") + if isinstance(gamma, torch.Tensor) and gamma.shape != (): + raise NotImplementedError("Per-value gamma is not supported yet. Gamma must be a scalar.") @property def log_prob_key(self): @@ -332,7 +339,7 @@ def forward( if self.average_adv: loc = adv.mean() - scale = adv.std().clamp_min(1e-4) + scale = adv.std().clamp_min(1e-8) adv = adv - loc adv = adv / scale From f3f9832f90e7012997af8b753600135591483634 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 10:31:02 +0200 Subject: [PATCH 019/109] config --- examples/impala/config.yaml | 9 +++++---- examples/impala/impala.py | 27 +++++++++------------------ examples/impala/utils.py | 25 +++++++++++++++++++++++-- torchrl/objectives/value/vtrace.py | 8 ++++---- 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 9d67995938f..042a572375e 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,8 +4,9 @@ env: # collector collector: - frames_per_batch: 80 + frames_per_batch: 128 # 80 total_frames: 40_000_000 + num_workers: 12 # logger logger: @@ -16,8 +17,8 @@ logger: # Optim optim: - lr: 0.0001 - eps: 1.0e-8 + lr: 0.0006 # 0.0001 + eps: 1.0e-5 weight_decay: 0.0 max_grad_norm: 40.0 anneal_lr: True @@ -25,7 +26,7 @@ optim: # loss loss: gamma: 0.99 - mini_batch_size: 80 + mini_batch_size: 128 # 80 critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 96fc11ad0c6..aa781fa9dad 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -25,10 +25,10 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss, ClipPPOLoss + from torchrl.objectives import A2CLoss from torchrl.record.loggers import generate_exp_name, get_logger from torchrl.objectives.value.vtrace import VTrace - from utils import make_parallel_env, make_ppo_models + from utils import make_parallel_env, make_ppo_models, eval_model device = "cpu" if not torch.cuda.is_available() else "cuda" @@ -39,7 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821 mini_batch_size = cfg.loss.mini_batch_size // frame_skip test_interval = cfg.logger.test_interval // frame_skip - # Create models (check utils_atari.py) + # Create models (check utils.py) actor, critic, critic_head = 
make_ppo_models(cfg.env.env_name) actor, critic, critic_head = ( actor.to(device), @@ -58,7 +58,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # sync=False, # ) collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 8, + create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * cfg.collector.num_workers, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -125,7 +125,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() for data in collector: - log_info = None + log_info = {} sampling_time = time.time() - sampling_start frames_in_batch = data.numel() collected_frames += frames_in_batch * frame_skip @@ -134,7 +134,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get train reward episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: - log_info.update({"reward_train": episode_rewards.mean().item()}) + log_info.update({"train/reward": episode_rewards.mean().item()}) # Apply episodic end of life data["done"].copy_(data["end_of_life"]) @@ -205,18 +205,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ): actor.eval() eval_start = time.time() - test_rewards = [] - for _ in range(cfg.logger.num_test_episodes): - td_test = test_env.rollout( - policy=actor, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - max_steps=10_000_000, - ) - reward = td_test["next", "episode_reward"][td_test["next", "done"]] - test_rewards = np.append(test_rewards, reward.cpu().numpy()) - del td_test + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) eval_time = time.time() - eval_start log_info.update( { diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 9f48782d2f7..e431dd82255 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -3,9 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import random import gymnasium as gym +import numpy as np import torch.nn import torch.optim from tensordict.nn import TensorDictModule @@ -99,7 +99,7 @@ def make_parallel_env(env_name, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"])) + # env.append_transform(VecNorm(in_keys=["pixels"])) return env @@ -218,3 +218,24 @@ def make_ppo_models(env_name): del proof_environment return actor, critic, critic_head + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 510f5d8fda5..00e03930169 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -31,7 +31,7 @@ def _c_val( log_mu: torch.Tensor, c: Union[float, torch.Tensor] = 1, ) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp().unsqueeze(-1) # TODO: is unsqueeze needed? 
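The unsqueeze question in the line above comes down to shape alignment between the critic output and the collector's log-probabilities. A quick check with assumed shapes (B=4 trajectories of T=10 steps, value feature dimension of 1):

    import torch

    value = torch.zeros(4, 10, 1)   # [B, T, 1], critic output
    log_ratio = torch.zeros(4, 10)  # [B, T], log pi - log mu stored without a feature dim
    rho = log_ratio.exp()

    try:
        rho * value                 # [4, 10] against [4, 10, 1] is not broadcastable
    except RuntimeError as err:
        print("shape mismatch:", err)
    print((rho.unsqueeze(-1) * value).shape)                  # torch.Size([4, 10, 1])
    print((log_ratio.reshape_as(value).exp() * value).shape)  # torch.Size([4, 10, 1]) as well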
+ return (log_pi - log_mu).clamp_max(math.log(c)).exp() # TODO: is unsqueeze needed? def _dv_val( rewards: torch.Tensor, @@ -313,17 +313,17 @@ def forward( # Make sure we have the log prob computed at collection time if self.log_prob_key not in tensordict.keys(): raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") - log_mu = tensordict.get(self.log_prob_key) + log_mu = tensordict.get(self.log_prob_key).reshape_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network): log_pi = self.actor_network( tensordict.select(self.actor_network.in_keys) - ).get(self.log_prob_key) + ).get(self.log_prob_key).reshape_as(value) # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) - import ipdb; ipdb.set_trace() + adv, value_target = vtrace_correction( gamma, log_pi, From 5b7d6428f7d81cec70888eb64b91e05ec752d434 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 16:33:55 +0200 Subject: [PATCH 020/109] fixes --- examples/impala/config.yaml | 12 ++++--- examples/impala/impala.py | 55 +++++++++++++----------------- torchrl/objectives/value/vtrace.py | 24 ++++++------- 3 files changed, 40 insertions(+), 51 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 042a572375e..f5a5f216203 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -4,9 +4,9 @@ env: # collector collector: - frames_per_batch: 128 # 80 + frames_per_batch: 80 total_frames: 40_000_000 - num_workers: 12 + num_workers: 2 # logger logger: @@ -17,8 +17,10 @@ logger: # Optim optim: - lr: 0.0006 # 0.0001 - eps: 1.0e-5 + lr: 0.0006 + eps: 0.01 + alpha: 0.99 + momentum: 0.0 weight_decay: 0.0 max_grad_norm: 40.0 anneal_lr: True @@ -26,7 +28,7 @@ optim: # loss loss: gamma: 0.99 - mini_batch_size: 128 # 80 + batch_size: 32 critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala/impala.py b/examples/impala/impala.py index aa781fa9dad..8672e96e75d 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -36,7 +36,6 @@ def main(cfg: "DictConfig"): # noqa: F821 frame_skip = 4 total_frames = cfg.collector.total_frames // frame_skip frames_per_batch = cfg.collector.frames_per_batch // frame_skip - mini_batch_size = cfg.loss.mini_batch_size // frame_skip test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils.py) @@ -70,9 +69,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyMemmapStorage(cfg.loss.batch_size), sampler=sampler, - batch_size=mini_batch_size, + batch_size=cfg.loss.batch_size, ) # Create loss and adv modules @@ -91,17 +90,13 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizer - optim_actor = torch.optim.Adam( - actor.parameters(), - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.eps, - ) - optim_critic = torch.optim.Adam( - critic.parameters(), + optim = torch.optim.RMSprop( + loss_module.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, + momentum=cfg.optim.momentum, eps=cfg.optim.eps, + alpha=cfg.optim.alpha, ) # Create logger @@ -119,11 +114,11 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = (total_frames // frames_per_batch) * num_mini_batches + 
total_network_updates = (total_frames // frames_per_batch) sampling_start = time.time() - for data in collector: + + for i, data in enumerate(collector): log_info = {} sampling_time = time.time() - sampling_start @@ -140,25 +135,26 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - losses = TensorDict({}, batch_size=[num_mini_batches]) training_start = time.time() # Compute VTrace with torch.no_grad(): - data = vtrace_module(data) - data_reshape = data.reshape(-1) + data = vtrace_module(data) # TODO: parallelize this # Update the data buffer - data_buffer.extend(data_reshape) + data_buffer.extend(data) + + if i % cfg.loss.batch_size != 0 or i == 0: + continue - for i, batch in enumerate(data_buffer): + for batch in enumerate(data_buffer): + + batch = batch.reshape(-1) # Linearly decrease the learning rate and clip epsilon alpha = 1 - (num_network_updates / total_network_updates) if cfg.optim.anneal_lr: - for group in optim_actor.param_groups: - group["lr"] = cfg.optim.lr * alpha - for group in optim_critic.param_groups: + for group in optim.param_groups: group["lr"] = cfg.optim.lr * alpha num_network_updates += 1 @@ -167,27 +163,22 @@ def main(cfg: "DictConfig"): # noqa: F821 # Forward pass A2C loss loss = loss_module(batch) - losses[i] = loss.select( + losses = loss.select( "loss_critic", "loss_entropy", "loss_objective" ).detach() - loss_actor = loss["loss_objective"] + loss["loss_entropy"] - loss_critic = loss["loss_critic"] + loss_sum = loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] # Backward pass - loss_actor.backward() - loss_critic.backward() + optim.zero_grad() + loss_sum.backward() torch.nn.utils.clip_grad_norm_( list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm ) # Update the networks - optim_actor.step() - optim_critic.step() - optim_actor.zero_grad() - optim_critic.zero_grad() + optim.step() training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py index 00e03930169..8e00d5f3d53 100644 --- a/torchrl/objectives/value/vtrace.py +++ b/torchrl/objectives/value/vtrace.py @@ -37,13 +37,13 @@ def _dv_val( rewards: torch.Tensor, vals: torch.Tensor, next_vals: torch.Tensor, - gamma: Union[float, torch.Tensor], + discount: Union[float, torch.Tensor], rho_thresh: Union[float, torch.Tensor], log_pi: torch.Tensor, log_mu: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: clipped_rho = _c_val(log_pi, log_mu, rho_thresh) - deltas = clipped_rho * (rewards + gamma * next_vals - vals) + deltas = clipped_rho * (rewards + discount * next_vals - vals) return deltas, clipped_rho @@ -88,16 +88,14 @@ def vtrace_correction( dtype = next_state_value.dtype device = state_value.device - deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, gamma, rho_thresh, log_pi, log_mu) - c_thresh = torch.tensor(c_thresh, device=device) - clipped_c = torch.min(c_thresh, clipped_rho) - - ############################################################ - # MAKE THIS PART WORK; THEN WE CAN TRY TO MAKE IT FASTER - not_done = (~done).int() *batch_size, time_steps, lastdim = not_done.shape discounts = gamma * not_done + + deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, discounts, rho_thresh, log_pi, log_mu) + c_thresh = 
c_thresh.to(device) + clipped_c = torch.min(c_thresh, clipped_rho) + vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] for i in reversed(range(time_steps)): discount_t, c_t, delta_t = discounts[..., i, :], clipped_c[..., i, :], deltas[..., i, :] @@ -106,9 +104,7 @@ def vtrace_correction( vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) vs = vs_minus_v_xs + state_value vs_t_plus_1 = torch.cat([vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim) - advantages = clipped_rho * (reward + gamma * vs_t_plus_1 - state_value) - - ############################################################ + advantages = clipped_rho * (reward + discounts * vs_t_plus_1 - state_value) return advantages, vs @@ -172,7 +168,7 @@ def __init__( average_adv: bool = False, differentiable: bool = False, skip_existing: Optional[bool] = None, - log_prob_key: NestedKey = "sample_log_prob", # TODO: should be added to _AcceptedKeys? + log_prob_key: NestedKey = "sample_log_prob", # Consider adding it to _AcceptedKeys? advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, @@ -339,7 +335,7 @@ def forward( if self.average_adv: loc = adv.mean() - scale = adv.std().clamp_min(1e-8) + scale = adv.std().clamp_min(1e-6) adv = adv - loc adv = adv / scale From 7dca35e317607cb3473ac7b9695d2f0844d3a128 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 16:51:31 +0200 Subject: [PATCH 021/109] fixes --- examples/impala/config.yaml | 2 +- examples/impala/impala.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index f5a5f216203..1359fc16cfa 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -6,7 +6,7 @@ env: collector: frames_per_batch: 80 total_frames: 40_000_000 - num_workers: 2 + num_workers: 12 # logger logger: diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 8672e96e75d..059045f98e1 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -147,7 +147,7 @@ def main(cfg: "DictConfig"): # noqa: F821 if i % cfg.loss.batch_size != 0 or i == 0: continue - for batch in enumerate(data_buffer): + for batch in data_buffer: batch = batch.reshape(-1) From e8c35efff25d708440e4374d36328f87f31b8239 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 19 Sep 2023 17:38:08 +0200 Subject: [PATCH 022/109] fixes --- examples/impala/impala.py | 14 +++++++++++++- examples/impala/utils.py | 1 + 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 059045f98e1..6b2578e7f96 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 import tqdm from tensordict import TensorDict - from torchrl.collectors import MultiaSyncDataCollector + from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement @@ -65,6 +65,15 @@ def main(cfg: "DictConfig"): # noqa: F821 storing_device=device, max_frames_per_traj=-1, ) + # collector = SyncDataCollector( + # create_env_fn=make_parallel_env(cfg.env.env_name, device), + # policy=actor, + # frames_per_batch=frames_per_batch, + # total_frames=total_frames, + # device=device, + # storing_device=device, + # max_frames_per_traj=-1, + # ) # Create data 
buffer sampler = SamplerWithoutReplacement() @@ -145,6 +154,9 @@ def main(cfg: "DictConfig"): # noqa: F821 data_buffer.extend(data) if i % cfg.loss.batch_size != 0 or i == 0: + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) continue for batch in data_buffer: diff --git a/examples/impala/utils.py b/examples/impala/utils.py index e431dd82255..4104753f547 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -108,6 +108,7 @@ def make_parallel_env(env_name, device, is_test=False): # -------------------------------------------------------------------- + def make_ppo_modules_pixels(proof_environment): # Define input shape From a4af09fa691ace8efc1fc7bc20a82783a405dfb9 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 20 Sep 2023 11:44:24 +0200 Subject: [PATCH 023/109] fixes --- examples/impala/impala.py | 26 +++++++++++++++++--------- examples/impala/utils.py | 1 + torchrl/objectives/a2c.py | 9 ++++++--- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 6b2578e7f96..1553c7966f0 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -64,6 +64,7 @@ def main(cfg: "DictConfig"): # noqa: F821 device=device, storing_device=device, max_frames_per_traj=-1, + update_at_each_batch=True, ) # collector = SyncDataCollector( # create_env_fn=make_parallel_env(cfg.env.env_name, device), @@ -88,7 +89,7 @@ def main(cfg: "DictConfig"): # noqa: F821 gamma=cfg.loss.gamma, value_network=critic, actor_network=actor, - average_adv=True, + average_adv=False, ) loss_module = A2CLoss( actor=actor, @@ -123,7 +124,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = (total_frames // frames_per_batch) + total_network_updates = (total_frames // (frames_per_batch * cfg.loss.batch_size)) sampling_start = time.time() @@ -138,7 +139,14 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get train reward episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: - log_info.update({"train/reward": episode_rewards.mean().item()}) + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) # Apply episodic end of life data["done"].copy_(data["end_of_life"]) @@ -148,19 +156,22 @@ def main(cfg: "DictConfig"): # noqa: F821 # Compute VTrace with torch.no_grad(): - data = vtrace_module(data) # TODO: parallelize this + # TODO: parallelize this by running it on batch, now returns some vmap error + data = vtrace_module(data) # Update the data buffer data_buffer.extend(data) + # Accumulate data if i % cfg.loss.batch_size != 0 or i == 0: if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) continue - for batch in data_buffer: + for batch in data_buffer: # Only one batch in the buffer from accumulated data + batch = batch.to(device) batch = batch.reshape(-1) # Linearly decrease the learning rate and clip epsilon @@ -170,9 +181,6 @@ def main(cfg: "DictConfig"): # noqa: F821 group["lr"] = cfg.optim.lr * alpha num_network_updates += 1 - # Get a data batch - batch = batch.to(device) - # Forward pass A2C loss loss = loss_module(batch) losses = loss.select( @@ -224,9 +232,9 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in 
log_info.items(): logger.log_scalar(key, value, collected_frames) - collector.update_policy_weights_() sampling_start = time.time() + collector.shutdown() end_time = time.time() execution_time = end_time - start_time print(f"Training took {execution_time:.2f} seconds to finish") diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 4104753f547..cdad28b5a16 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -37,6 +37,7 @@ ProbabilisticActor, TanhNormal, ValueOperator, + LSTMModule, ) # ==================================================================== diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index ea2c715d927..23837d94acc 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -354,14 +354,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = tensordict.get(self.tensor_keys.advantage) log_probs, dist = self._log_probs(tensordict) loss = -(log_probs * advantage) - td_out = TensorDict({"loss_objective": loss.mean()}, []) + # td_out = TensorDict({"loss_objective": loss.mean()}, []) + td_out = TensorDict({"loss_objective": loss.sum()}, []) if self.entropy_bonus: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.mean().detach()) # for logging - td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) + # td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) + td_out.set("loss_entropy", -self.entropy_coef * entropy.sum()) if self.critic_coef: loss_critic = self.loss_critic(tensordict).mean() - td_out.set("loss_critic", loss_critic.mean()) + # td_out.set("loss_critic", loss_critic.mean()) + td_out.set("loss_critic", loss_critic.sum()) return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): From 8bf478755a86f02f9ead9eac3cf5e7452862eefd Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 21 Sep 2023 12:53:26 +0200 Subject: [PATCH 024/109] move vtrace to adv script --- examples/impala/config.yaml | 11 +- examples/impala/impala.py | 77 +++--- examples/impala/utils.py | 23 +- examples/impala2/config.yaml | 36 +++ examples/impala2/impala.py | 232 +++++++++++++++++ examples/impala2/utils.py | 239 +++++++++++++++++ examples/impala3/config.yaml | 36 +++ examples/impala3/impala.py | 239 +++++++++++++++++ examples/impala3/utils.py | 239 +++++++++++++++++ torchrl/objectives/a2c.py | 9 +- torchrl/objectives/value/__init__.py | 1 + torchrl/objectives/value/advantages.py | 244 +++++++++++++++++ torchrl/objectives/value/functional.py | 67 +++++ torchrl/objectives/value/vtrace.py | 346 ------------------------- 14 files changed, 1390 insertions(+), 409 deletions(-) create mode 100644 examples/impala2/config.yaml create mode 100644 examples/impala2/impala.py create mode 100644 examples/impala2/utils.py create mode 100644 examples/impala3/config.yaml create mode 100644 examples/impala3/impala.py create mode 100644 examples/impala3/utils.py delete mode 100644 torchrl/objectives/value/vtrace.py diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 1359fc16cfa..448f42c2f3e 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -6,23 +6,23 @@ env: collector: frames_per_batch: 80 total_frames: 40_000_000 - num_workers: 12 + num_workers: 8 # logger logger: - backend: wandb + backend: null exp_name: Atari_Espeholt18 test_interval: 40_000_000 num_test_episodes: 3 # Optim optim: - lr: 0.0006 + lr: 0.00048 eps: 0.01 alpha: 0.99 momentum: 0.0 weight_decay: 0.0 - max_grad_norm: 40.0 + 
max_grad_norm: 1.0 anneal_lr: True # loss @@ -30,6 +30,5 @@ loss: gamma: 0.99 batch_size: 32 critic_coef: 0.25 - entropy_coef: 0.01 + entropy_coef: 0.05 loss_critic_type: l2 - diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 1553c7966f0..bc0862df89e 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -15,20 +15,18 @@ def main(cfg: "DictConfig"): # noqa: F821 import time - import numpy as np import torch.optim import tqdm - from tensordict import TensorDict from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector from torchrl.collectors.distributed import RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss + from torchrl.objectives import A2CLoss, ClipPPOLoss + from torchrl.objectives.value import GAE, VTrace from torchrl.record.loggers import generate_exp_name, get_logger - from torchrl.objectives.value.vtrace import VTrace - from utils import make_parallel_env, make_ppo_models, eval_model + from utils import eval_model, make_parallel_env, make_ppo_models device = "cpu" if not torch.cuda.is_available() else "cuda" @@ -39,16 +37,15 @@ def main(cfg: "DictConfig"): # noqa: F821 test_interval = cfg.logger.test_interval // frame_skip # Create models (check utils.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = ( actor.to(device), critic.to(device), - critic_head.to(device), ) # Create collector # collector = RPCDataCollector( - # create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * 2, + # create_env_fn=[make_parallel_env(cfg.env.env_name, 1, device)] * 2, # policy=actor, # frames_per_batch=frames_per_batch, # total_frames=total_frames, @@ -57,7 +54,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # sync=False, # ) collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, device)] * cfg.collector.num_workers, + create_env_fn=[make_parallel_env(cfg.env.env_name, 1, device)] + * cfg.collector.num_workers, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -67,7 +65,7 @@ def main(cfg: "DictConfig"): # noqa: F821 update_at_each_batch=True, ) # collector = SyncDataCollector( - # create_env_fn=make_parallel_env(cfg.env.env_name, device), + # create_env_fn=make_parallel_env(cfg.env.env_name, cfg.loss.batch_size, device), # policy=actor, # frames_per_batch=frames_per_batch, # total_frames=total_frames, @@ -79,13 +77,13 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.loss.batch_size), + storage=LazyMemmapStorage(cfg.collector.frames_per_batch * cfg.loss.batch_size), sampler=sampler, - batch_size=cfg.loss.batch_size, + batch_size=cfg.collector.frames_per_batch * cfg.loss.batch_size, ) # Create loss and adv modules - vtrace_module = VTrace( + adv_module = VTrace( gamma=cfg.loss.gamma, value_network=critic, actor_network=actor, @@ -112,11 +110,15 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create logger logger = None if cfg.logger.backend: - exp_name = generate_exp_name("IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}") - logger = get_logger(cfg.logger.backend, logger_name="impala", experiment_name=exp_name) + exp_name = 
generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, logger_name="impala", experiment_name=exp_name + ) # Create test environment - test_env = make_parallel_env(cfg.env.env_name, device, is_test=True) + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) test_env.eval() # Main loop @@ -124,10 +126,9 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = (total_frames // (frames_per_batch * cfg.loss.batch_size)) + total_network_updates = total_frames // (frames_per_batch * cfg.loss.batch_size) sampling_start = time.time() - for i, data in enumerate(collector): log_info = {} @@ -144,7 +145,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -152,44 +153,47 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - training_start = time.time() + # Accumulate data + if (i + 1) % cfg.loss.batch_size != 0: - # Compute VTrace - with torch.no_grad(): - # TODO: parallelize this by running it on batch, now returns some vmap error - data = vtrace_module(data) + # Compute VTrace + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) - # Update the data buffer - data_buffer.extend(data) + # Update the data buffer + data_buffer.extend(data_reshape) - # Accumulate data - if i % cfg.loss.batch_size != 0 or i == 0: if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) continue - for batch in data_buffer: # Only one batch in the buffer from accumulated data + training_start = time.time() - batch = batch.to(device) - batch = batch.reshape(-1) + for k, batch in enumerate(data_buffer): # Linearly decrease the learning rate and clip epsilon - alpha = 1 - (num_network_updates / total_network_updates) + alpha = 1.0 if cfg.optim.anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) for group in optim.param_groups: group["lr"] = cfg.optim.lr * alpha num_network_updates += 1 + # Get a data batch + batch = batch.to(device) + # Forward pass A2C loss loss = loss_module(batch) losses = loss.select( "loss_critic", "loss_entropy", "loss_objective" ).detach() - loss_sum = loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) # Backward pass - optim.zero_grad() loss_sum.backward() torch.nn.utils.clip_grad_norm_( list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm @@ -197,6 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update the networks optim.step() + optim.zero_grad() training_time = time.time() - training_start for key, value in losses.items(): @@ -212,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Test logging with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if (collected_frames - frames_in_batch) // test_interval < ( - collected_frames // test_interval + collected_frames // test_interval ): actor.eval() eval_start = time.time() diff --git a/examples/impala/utils.py b/examples/impala/utils.py index cdad28b5a16..f25fd60613c 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -18,6 +18,7 @@ EnvCreator, ExplorationType, GrayScale, + NoopResetEnv, 
ParallelEnv, Resize, RewardClipping, @@ -26,18 +27,17 @@ ToTensorImage, TransformedEnv, VecNorm, - NoopResetEnv, ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( ActorValueOperator, ConvNet, + LSTMModule, MLP, OneHotCategorical, ProbabilisticActor, TanhNormal, ValueOperator, - LSTMModule, ) # ==================================================================== @@ -69,7 +69,7 @@ def reset(self, **kwargs): def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) if not is_test: @@ -85,8 +85,8 @@ def make_base_env( return env -def make_parallel_env(env_name, device, is_test=False): - num_envs = 1 +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) @@ -100,7 +100,7 @@ def make_parallel_env(env_name, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - # env.append_transform(VecNorm(in_keys=["pixels"])) + env.append_transform(VecNorm(in_keys=["pixels"], eps=0.999)) return env @@ -109,7 +109,6 @@ def make_parallel_env(env_name, device, is_test=False): # -------------------------------------------------------------------- - def make_ppo_modules_pixels(proof_environment): # Define input shape @@ -196,7 +195,7 @@ def make_ppo_modules_pixels(proof_environment): def make_ppo_models(env_name): - proof_environment = make_parallel_env(env_name, device="cpu") + proof_environment = make_parallel_env(env_name, 1, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( proof_environment ) @@ -208,18 +207,12 @@ def make_ppo_models(env_name): value_operator=value_module, ) - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - actor = actor_critic.get_policy_operator() critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() del proof_environment - return actor, critic, critic_head + return actor, critic # ==================================================================== diff --git a/examples/impala2/config.yaml b/examples/impala2/config.yaml new file mode 100644 index 00000000000..6957fd9bddd --- /dev/null +++ b/examples/impala2/config.yaml @@ -0,0 +1,36 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 8 + +# collector +collector: + frames_per_batch: 4096 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 2.5e-4 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 1024 + ppo_epochs: 3 + gae_lambda: 0.95 + clip_epsilon: 0.1 + anneal_clip_epsilon: True + critic_coef: 1.0 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala2/impala.py b/examples/impala2/impala.py new file mode 100644 index 00000000000..8bfa3080a9f --- /dev/null +++ b/examples/impala2/impala.py @@ -0,0 +1,232 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. 
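Note: the refactored helpers above (``make_parallel_env`` now takes an explicit ``num_envs`` and ``make_ppo_models`` returns only an actor/critic pair) can be exercised on their own. A minimal sketch, assuming a working gymnasium/ALE install and that the example's ``utils.py`` is importable; it is an illustration, not part of the patch:

    from utils import make_parallel_env, make_ppo_models

    actor, critic = make_ppo_models("PongNoFrameskip-v4")
    env = make_parallel_env("PongNoFrameskip-v4", num_envs=2, device="cpu")
    rollout = env.rollout(max_steps=4, policy=actor)
    # the ProbabilisticActor stores the behaviour log-probs that VTrace later reads as log_mu
    assert "sample_log_prob" in rollout.keys()
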
+""" +import hydra + + +@hydra.main(config_path=".", config_name="config", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import ClipPPOLoss + from torchrl.objectives.value.vtrace import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( + actor.to(device), + critic.to(device), + critic_head.to(device), + ) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = ClipPPOLoss( + actor=actor, + critic=critic, + clip_epsilon=cfg.loss.clip_epsilon, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + normalize_advantage=True, + ) + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, logger_name="impala", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), 
+ "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + # Apply episodic end of life + data["done"].copy_(data["end_of_life"]) + data["next", "done"].copy_(data["next", "end_of_life"]) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + training_start = time.time() + for j in range(cfg.loss.ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1 - (num_network_updates / total_network_updates) + if cfg.optim.anneal_lr: + for group in optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + if cfg.loss.anneal_clip_epsilon: + loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py new file mode 100644 index 00000000000..305318e0551 --- /dev/null +++ b/examples/impala2/utils.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
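Note: the training loop above anneals the learning rate and the PPO clip range with the same linear factor ``alpha``. A standalone sketch of that schedule, plugging in this example's config values (illustrative numbers only, not part of the patch):

    # assumed config defaults: total_frames=40M, frames_per_batch=4096,
    # mini_batch_size=1024, ppo_epochs=3, lr=2.5e-4, clip_epsilon=0.1, frame_skip=4
    total_frames = 40_000_000 // 4
    frames_per_batch = 4096 // 4
    mini_batch_size = 1024 // 4
    num_mini_batches = frames_per_batch // mini_batch_size
    total_network_updates = (total_frames // frames_per_batch) * 3 * num_mini_batches

    base_lr, base_clip = 2.5e-4, 0.1
    for num_network_updates in (0, total_network_updates // 2, total_network_updates - 1):
        alpha = 1 - num_network_updates / total_network_updates
        print(f"lr={base_lr * alpha:.2e}  clip_epsilon={base_clip * alpha:.3f}")
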
+ +import gymnasium as gym +import numpy as np +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + default_info_dict_reader, + DoubleToFloat, + EnvCreator, + ExplorationType, + GrayScale, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. It helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + + def step(self, action): + obs, rew, done, truncated, info = self.env.step(action) + lives = self.env.unwrapped.ale.lives() + info["end_of_life"] = False + if (lives < self.lives) or done: + info["end_of_life"] = True + self.lives = lives + return obs, rew, done, truncated, info + + def reset(self, **kwargs): + reset_data = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() + return reset_data + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = gym.make(env_name) + if not is_test: + env = EpisodicLifeEnv(env) + env = GymWrapper( + env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) + return env + + +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage(from_int=False)) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + # env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + 
MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, 1, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/examples/impala3/config.yaml b/examples/impala3/config.yaml new file mode 100644 index 00000000000..dacfdcb1bf3 --- /dev/null +++ b/examples/impala3/config.yaml @@ -0,0 +1,36 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 8 + +# collector +collector: + frames_per_batch: 2560 + total_frames: 40_000_000 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 40_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1.0e-6 + weight_decay: 0.0 + max_grad_norm: 0.5 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + mini_batch_size: 2560 + ppo_epochs: 1 + gae_lambda: 0.95 + clip_epsilon: 0.1 + 
anneal_clip_epsilon: True + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala3/impala.py b/examples/impala3/impala.py new file mode 100644 index 00000000000..15cb7808332 --- /dev/null +++ b/examples/impala3/impala.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. +""" +import hydra + + +@hydra.main(config_path=".", config_name="config", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss, ClipPPOLoss + from torchrl.objectives.value.vtrace import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( + actor.to(device), + critic.to(device), + critic_head.to(device), + ) + + # Create collector + collector = SyncDataCollector( + create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + storing_device=device, + max_frames_per_traj=-1, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch), + sampler=sampler, + batch_size=mini_batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=True, + ) + # loss_module = ClipPPOLoss( + # actor=actor, + # critic=critic, + # clip_epsilon=cfg.loss.clip_epsilon, + # loss_critic_type=cfg.loss.loss_critic_type, + # entropy_coef=cfg.loss.entropy_coef, + # critic_coef=cfg.loss.critic_coef, + # normalize_advantage=True, + # ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + + # Create optimizer + optim = torch.optim.Adam( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, logger_name="impala", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + 
num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) + + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + # Apply episodic end of life + data["done"].copy_(data["end_of_life"]) + data["next", "done"].copy_(data["next", "end_of_life"]) + + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + training_start = time.time() + for j in range(cfg.loss.ppo_epochs): + + # Compute GAE + with torch.no_grad(): + data = adv_module(data) + data_reshape = data.reshape(-1) + + # Update the data buffer + data_buffer.extend(data_reshape) + + for k, batch in enumerate(data_buffer): + + # Linearly decrease the learning rate and clip epsilon + alpha = 1 - (num_network_updates / total_network_updates) + if cfg.optim.anneal_lr: + for group in optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + # if cfg.loss.anneal_clip_epsilon: + # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * cfg.optim.lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_rewards = eval_model( + actor, test_env, num_episodes=cfg.logger.num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala3/utils.py 
b/examples/impala3/utils.py new file mode 100644 index 00000000000..ccf6978a340 --- /dev/null +++ b/examples/impala3/utils.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import gymnasium as gym +import numpy as np +import torch.nn +import torch.optim +from tensordict.nn import TensorDictModule +from torchrl.data import CompositeSpec +from torchrl.data.tensor_specs import DiscreteBox +from torchrl.envs import ( + CatFrames, + default_info_dict_reader, + DoubleToFloat, + EnvCreator, + ExplorationType, + GrayScale, + NoopResetEnv, + ParallelEnv, + Resize, + RewardClipping, + RewardSum, + StepCounter, + ToTensorImage, + TransformedEnv, + VecNorm, +) +from torchrl.envs.libs.gym import GymWrapper +from torchrl.modules import ( + ActorValueOperator, + ConvNet, + MLP, + OneHotCategorical, + ProbabilisticActor, + TanhNormal, + ValueOperator, +) + +# ==================================================================== +# Environment utils +# -------------------------------------------------------------------- + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. It helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + + def step(self, action): + obs, rew, done, truncated, info = self.env.step(action) + lives = self.env.unwrapped.ale.lives() + info["end_of_life"] = False + if (lives < self.lives) or done: + info["end_of_life"] = True + self.lives = lives + return obs, rew, done, truncated, info + + def reset(self, **kwargs): + reset_data = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() + return reset_data + + +def make_base_env( + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False +): + env = gym.make(env_name) + if not is_test: + env = EpisodicLifeEnv(env) + env = GymWrapper( + env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + ) + env = TransformedEnv(env) + env.append_transform(NoopResetEnv(noops=30, random=True)) + if not is_test: + reader = default_info_dict_reader(["end_of_life"]) + env.set_info_dict_reader(reader) + return env + + +def make_parallel_env(env_name, num_envs, device, is_test=False): + env = ParallelEnv( + num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) + ) + env = TransformedEnv(env) + env.append_transform(ToTensorImage()) + env.append_transform(GrayScale()) + env.append_transform(Resize(84, 84)) + env.append_transform(CatFrames(N=4, dim=-3)) + env.append_transform(RewardSum()) + env.append_transform(StepCounter(max_steps=4500)) + if not is_test: + env.append_transform(RewardClipping(-1, 1)) + env.append_transform(DoubleToFloat()) + env.append_transform(VecNorm(in_keys=["pixels"])) + return env + + +# ==================================================================== +# Model utils +# -------------------------------------------------------------------- + + +def make_ppo_modules_pixels(proof_environment): + + # Define input shape + input_shape = proof_environment.observation_spec["pixels"].shape + + # Define distribution class and kwargs + if isinstance(proof_environment.action_spec.space, DiscreteBox): + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} + else: # is ContinuousBox + num_outputs = 
proof_environment.action_spec.shape + distribution_class = TanhNormal + distribution_kwargs = { + "min": proof_environment.action_spec.space.minimum, + "max": proof_environment.action_spec.space.maximum, + } + + # Define input keys + in_keys = ["pixels"] + + # Define a shared Module and TensorDictModule (CNN + MLP) + common_cnn = ConvNet( + activation_class=torch.nn.ReLU, + num_cells=[32, 64, 64], + kernel_sizes=[8, 4, 3], + strides=[4, 2, 1], + ) + common_cnn_output = common_cnn(torch.ones(input_shape)) + common_mlp = MLP( + in_features=common_cnn_output.shape[-1], + activation_class=torch.nn.ReLU, + activate_last_layer=True, + out_features=512, + num_cells=[], + ) + common_mlp_output = common_mlp(common_cnn_output) + + # Define shared net as TensorDictModule + common_module = TensorDictModule( + module=torch.nn.Sequential(common_cnn, common_mlp), + in_keys=in_keys, + out_keys=["common_features"], + ) + + # Define on head for the policy + policy_net = MLP( + in_features=common_mlp_output.shape[-1], + out_features=num_outputs, + activation_class=torch.nn.ReLU, + num_cells=[], + ) + policy_module = TensorDictModule( + module=policy_net, + in_keys=["common_features"], + out_keys=["logits"], + ) + + # Add probabilistic sampling of the actions + policy_module = ProbabilisticActor( + policy_module, + in_keys=["logits"], + spec=CompositeSpec(action=proof_environment.action_spec), + distribution_class=distribution_class, + distribution_kwargs=distribution_kwargs, + return_log_prob=True, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define another head for the value + value_net = MLP( + activation_class=torch.nn.ReLU, + in_features=common_mlp_output.shape[-1], + out_features=1, + num_cells=[], + ) + value_module = ValueOperator( + value_net, + in_keys=["common_features"], + ) + + return common_module, policy_module, value_module + + +def make_ppo_models(env_name): + + proof_environment = make_parallel_env(env_name, 1, device="cpu") + common_module, policy_module, value_module = make_ppo_modules_pixels( + proof_environment + ) + + # Wrap modules in a single ActorCritic operator + actor_critic = ActorValueOperator( + common_operator=common_module, + policy_operator=policy_module, + value_operator=value_module, + ) + + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + + actor = actor_critic.get_policy_operator() + critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() + + del proof_environment + + return actor, critic, critic_head + + +# ==================================================================== +# Evaluation utils +# -------------------------------------------------------------------- + + +def eval_model(actor, test_env, num_episodes=3): + test_rewards = [] + for _ in range(num_episodes): + td_test = test_env.rollout( + policy=actor, + auto_reset=True, + auto_cast_to_device=True, + break_when_any_done=True, + max_steps=10_000_000, + ) + reward = td_test["next", "episode_reward"][td_test["next", "done"]] + test_rewards = np.append(test_rewards, reward.cpu().numpy()) + del td_test + return test_rewards.mean() diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 23837d94acc..ea2c715d927 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -354,17 +354,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = tensordict.get(self.tensor_keys.advantage) log_probs, dist = self._log_probs(tensordict) 
loss = -(log_probs * advantage) - # td_out = TensorDict({"loss_objective": loss.mean()}, []) - td_out = TensorDict({"loss_objective": loss.sum()}, []) + td_out = TensorDict({"loss_objective": loss.mean()}, []) if self.entropy_bonus: entropy = self.get_entropy_bonus(dist) td_out.set("entropy", entropy.mean().detach()) # for logging - # td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) - td_out.set("loss_entropy", -self.entropy_coef * entropy.sum()) + td_out.set("loss_entropy", -self.entropy_coef * entropy.mean()) if self.critic_coef: loss_critic = self.loss_critic(tensordict).mean() - # td_out.set("loss_critic", loss_critic.mean()) - td_out.set("loss_critic", loss_critic.sum()) + td_out.set("loss_critic", loss_critic.mean()) return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 11ae2e6d9e2..51496986153 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -12,4 +12,5 @@ TDLambdaEstimate, TDLambdaEstimator, ValueEstimatorBase, + VTrace, ) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 31c8c291c5b..6ea06b01fe7 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + + import abc import functools import warnings @@ -32,8 +34,10 @@ vec_generalized_advantage_estimate, vec_td1_return_estimate, vec_td_lambda_return_estimate, + vtrace_advantage_estimate, ) + try: from torch import vmap except ImportError as err: @@ -1260,6 +1264,246 @@ def value_estimate( return value_target +class VTrace(ValueEstimatorBase): + """A class wrapper around V-Trace estimate functional. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + value_network (TensorDictModule): value operator used to retrieve the value estimates. + actor_network (TensorDictModule, optional): actor operator used to retrieve the log prob. + average_adv (bool): if ``True``, the resulting advantage values will be standardized. + Default is ``False``. + differentiable (bool, optional): if ``True``, gradients are propagated through + the computation of the value function. Default is ``False``. + + .. note:: + The proper way to make the function call non-differentiable is to + decorate it in a `torch.no_grad()` context manager/decorator or + pass detached parameters for functional modules. + skip_existing (bool, optional): if ``True``, the value network will skip + modules which outputs are already present in the tensordict. + Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` + is not affected. + Defaults to "state_value". + advantage_key (str or tuple of str, optional): [Deprecated] the key of + the advantage entry. Defaults to ``"advantage"``. + value_target_key (str or tuple of str, optional): [Deprecated] the key + of the advantage entry. Defaults to ``"value_target"``. + value_key (str or tuple of str, optional): [Deprecated] the value key to + read from the input tensordict. Defaults to ``"state_value"``. + shifted (bool, optional): if ``True``, the value and next value are + estimated with a single call to the value network. 
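For context, the quantities this estimator produces are the V-Trace targets of Espeholt et al. 2018. With behaviour policy $\mu$ (the collection-time log-probs), target policy $\pi$ (the current actor) and truncation levels $\bar\rho$ (``rho_thresh``) and $\bar c$ (``c_thresh``):

$$ v_s = V(x_s) + \sum_{t \ge s} \gamma^{\,t-s} \Big(\prod_{i=s}^{t-1} c_i\Big)\, \delta_t V, \qquad \delta_t V = \rho_t \big(r_t + \gamma V(x_{t+1}) - V(x_t)\big), $$
$$ \rho_t = \min\!\Big(\bar\rho,\ \tfrac{\pi(a_t\mid x_t)}{\mu(a_t\mid x_t)}\Big), \qquad c_i = \min\!\Big(\bar c,\ \tfrac{\pi(a_i\mid x_i)}{\mu(a_i\mid x_i)}\Big), $$

with the sum truncated at the end of the trajectory. The advantage written to the ``"advantage"`` entry is $\rho_t\big(r_t + \gamma v_{t+1} - V(x_t)\big)$, and $v_s$ is written to ``"value_target"``.
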
This is faster + but is only valid whenever (1) the ``"next"`` value is shifted by + only one time step (which is not the case with multi-step value + estimation, for instance) and (2) when the parameters used at time + ``t`` and ``t+1`` are identical (which is not the case when target + parameters are to be used). Defaults to ``False``. + + VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also + return a :obj:`"value_target"` entry with the V-Trace target value. + + .. note:: + As other advantage functions do, if the ``value_key`` is already present + in the input tensordict, the VTrace module will ignore the calls to the value + network (if any) and use the provided value instead. + + """ + + def __init__( + self, + *, + gamma: Union[float, torch.Tensor], + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + actor_network: TensorDictModule = None, + value_network: TensorDictModule, + average_adv: bool = False, + differentiable: bool = False, + skip_existing: Optional[bool] = None, + log_prob_key: NestedKey = "sample_log_prob", # Consider adding it to _AcceptedKeys? + advantage_key: NestedKey = None, + value_target_key: NestedKey = None, + value_key: NestedKey = None, + shifted: bool = False, + ): + super().__init__( + shifted=shifted, + value_network=value_network, + differentiable=differentiable, + advantage_key=advantage_key, + value_target_key=value_target_key, + value_key=value_key, + skip_existing=skip_existing, + ) + try: + device = next(value_network.parameters()).device + except (AttributeError, StopIteration): + device = torch.device("cpu") + + if not isinstance(gamma, torch.Tensor): + gamma = torch.tensor(gamma, device=device) + if not isinstance(rho_thresh, torch.Tensor): + rho_thresh = torch.tensor(rho_thresh, device=device) + if not isinstance(c_thresh, torch.Tensor): + c_thresh = torch.tensor(c_thresh, device=device) + + self.register_buffer("gamma", gamma) + self.register_buffer("rho_thresh", rho_thresh) + self.register_buffer("c_thresh", c_thresh) + self.average_adv = average_adv + self.actor_network = actor_network + self._log_prob_key = log_prob_key + + if isinstance(gamma, torch.Tensor) and gamma.shape != (): + raise NotImplementedError( + "Per-value gamma is not supported yet. Gamma must be a scalar." + ) + + @property + def log_prob_key(self): + return self._log_prob_key + + @_self_set_skip_existing + @_self_set_grad_enabled + @dispatch + def forward( + self, + tensordict: TensorDictBase, + *unused_args, + params: Optional[List[Tensor]] = None, + target_params: Optional[List[Tensor]] = None, + ) -> TensorDictBase: + """Computes the V-Trace correction given the data in tensordict. + + If a functional module is provided, a nested TensorDict containing the parameters + (and if relevant the target parameters) can be passed to the module. + + Args: + tensordict (TensorDictBase): A TensorDict containing the data + (an observation key, "action", "reward", "done" and "next" tensordict state + as returned by the environment) necessary to compute the value estimates and the GAE. + The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are + the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). + params (TensorDictBase, optional): A nested TensorDict containing the params + to be passed to the functional value network module. 
+ target_params (TensorDictBase, optional): A nested TensorDict containing the + target params to be passed to the functional value network module. + + Returns: + An updated TensorDict with an advantage and a value_error keys as defined in the constructor. + + Examples: + >>> from tensordict import TensorDict + >>> value_net = TensorDictModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) + >>> _ = module(tensordict) + >>> assert "advantage" in tensordict.keys() + + The module supports non-tensordict (i.e. unpacked tensordict) inputs too: + + Examples: + >>> value_net = TensorDictModule( + ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + ... ) + >>> module = VTrace( + ... gamma=0.98, + ... value_network=value_net, + ... differentiable=False, + ... ) + >>> obs, next_obs = torch.randn(2, 1, 10, 3) + >>> reward = torch.randn(1, 10, 1) + >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + + """ + if tensordict.batch_dims < 1: + raise RuntimeError( + "Expected input tensordict to have at least one dimensions, got " + f"tensordict.batch_size = {tensordict.batch_size}" + ) + reward = tensordict.get(("next", self.tensor_keys.reward)) + device = reward.device + gamma = self.gamma.to(device) + steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) + if steps_to_next_obs is not None: + gamma = gamma ** steps_to_next_obs.view_as(reward) + + # Make sure we have the value and next value + if self.value_network is not None: + if params is not None: + params = params.detach() + if target_params is None: + target_params = params.clone(False) + with hold_out_net(self.value_network): + # we may still need to pass gradient, but we don't want to assign grads to + # value net params + value, next_value = _call_value_nets( + value_net=self.value_network, + data=tensordict, + params=params, + next_params=target_params, + single_call=self.shifted, + value_key=self.tensor_keys.value, + detach_next=True, + ) + else: + value = tensordict.get(self.tensor_keys.value) + next_value = tensordict.get(("next", self.tensor_keys.value)) + + # Make sure we have the log prob computed at collection time + if self.log_prob_key not in tensordict.keys(): + raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") + log_mu = tensordict.get(self.log_prob_key).view_as(value) + + # Compute log prob with current policy + with hold_out_net(self.actor_network): + log_pi = ( + self.actor_network(tensordict.select(self.actor_network.in_keys)) + .get(self.log_prob_key) + .view_as(value) + ) + + # Compute the V-Trace correction + done = tensordict.get(("next", self.tensor_keys.done)) + + adv, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + value, + next_value, + reward, + done, + rho_thresh=self.rho_thresh, + c_thresh=self.c_thresh, + time_dim=tensordict.ndim - 1, + ) + + if self.average_adv: + loc = adv.mean() + scale = adv.std().clamp_min(1e-5) + adv = adv - loc + adv = adv / scale + + tensordict.set(self.tensor_keys.advantage, adv) + tensordict.set(self.tensor_keys.value_target, value_target) + + 
return tensordict + + def _deprecate_class(cls, new_cls): @wraps(cls.__init__) def new_init(self, *args, **kwargs): diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index ccd0966bbf6..f9c7af3a2ea 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -24,6 +24,7 @@ "vec_td_lambda_return_estimate", "td_lambda_advantage_estimate", "vec_td_lambda_advantage_estimate", + "vtrace_advantage_estimate", ] from torchrl.objectives.value.utils import ( @@ -1052,6 +1053,72 @@ def vec_td_lambda_advantage_estimate( ) +@_transpose_time +def vtrace_advantage_estimate( + gamma: float, + log_pi: torch.Tensor, + log_mu: torch.Tensor, + state_value: torch.Tensor, + next_state_value: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, + time_dim: int = -2, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes V-Trace off-policy actor critic targets. + + Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" + https://arxiv.org/abs/1802.01561 for more context. + + Args: + gamma (scalar): exponential mean discount. + log_pi (Tensor): collection actor log probability of taking actions in the environment. + log_mu (Tensor): current actor log probability of taking actions in the environment. + state_value (Tensor): value function result with state input. + next_state_value (Tensor): value function result with next_state input. + reward (Tensor): reward of taking actions in the environment. + done (Tensor): boolean flag for end of episode. + rho_thresh (Union[float, Tensor]): clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): clipping parameter for importance weights. + time_dim (int): dimension where the time is unrolled. Defaults to -2. + + All tensors (values, reward and done) must have shape + ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. + """ + if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): + raise RuntimeError(SHAPE_ERR) + + device = state_value.device + + not_done = (~done).int() + *batch_size, time_steps, lastdim = not_done.shape + discounts = gamma * not_done + + clipped_rho = (log_pi - log_mu).exp().clamp_max(rho_thresh) + deltas = clipped_rho * (reward + discounts * next_state_value - state_value) + c_thresh = c_thresh.to(device) + clipped_c = torch.clamp(c_thresh, max=clipped_rho) + + vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] + for i in reversed(range(time_steps)): + discount_t, c_t, delta_t = ( + discounts[..., i, :], + clipped_c[..., i, :], + deltas[..., i, :], + ) + vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1]) + vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim) + vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) + vs = vs_minus_v_xs + state_value + vs_t_plus_1 = torch.cat( + [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim + ) + advantages = clipped_rho * (reward + discounts * vs_t_plus_1 - state_value) + + return advantages, vs + + ######################################################################## # Reward to go # ------------ diff --git a/torchrl/objectives/value/vtrace.py b/torchrl/objectives/value/vtrace.py deleted file mode 100644 index 8e00d5f3d53..00000000000 --- a/torchrl/objectives/value/vtrace.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
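Note: the new ``vtrace_advantage_estimate`` added to ``torchrl/objectives/value/functional.py`` above can be smoke-tested on random data. A minimal sketch (shapes follow the documented ``[*B, T, F]`` layout; thresholds are passed as tensors, mirroring how the ``VTrace`` module calls the function):

    import torch
    from torchrl.objectives.value.functional import vtrace_advantage_estimate

    B, T = 2, 5
    reward = torch.randn(B, T, 1)
    done = torch.zeros(B, T, 1, dtype=torch.bool)
    state_value = torch.randn(B, T, 1)
    next_state_value = torch.randn(B, T, 1)
    log_pi = torch.randn(B, T, 1)  # current-policy log-probs
    log_mu = torch.randn(B, T, 1)  # behaviour-policy log-probs (collection time)

    adv, value_target = vtrace_advantage_estimate(
        0.99,
        log_pi,
        log_mu,
        state_value,
        next_state_value,
        reward,
        done,
        rho_thresh=torch.tensor(1.0),
        c_thresh=torch.tensor(1.0),
    )
    assert adv.shape == value_target.shape == (B, T, 1)
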
-# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import math -from typing import List, Optional, Union, Tuple - -import torch -from tensordict.nn import ( - dispatch, - is_functional, - set_skip_existing, - TensorDictModule, - TensorDictModuleBase, -) -from tensordict.tensordict import TensorDictBase -from tensordict.utils import NestedKey -from torch import nn, Tensor -from torchrl.objectives.utils import hold_out_net -from torchrl.objectives.value.advantages import ( - ValueEstimatorBase, - _self_set_skip_existing, - _self_set_grad_enabled, - _call_value_nets) -from torchrl.objectives.value.functional import _transpose_time, SHAPE_ERR - - -def _c_val( - log_pi: torch.Tensor, - log_mu: torch.Tensor, - c: Union[float, torch.Tensor] = 1, -) -> torch.Tensor: - return (log_pi - log_mu).clamp_max(math.log(c)).exp() # TODO: is unsqueeze needed? - -def _dv_val( - rewards: torch.Tensor, - vals: torch.Tensor, - next_vals: torch.Tensor, - discount: Union[float, torch.Tensor], - rho_thresh: Union[float, torch.Tensor], - log_pi: torch.Tensor, - log_mu: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - clipped_rho = _c_val(log_pi, log_mu, rho_thresh) - deltas = clipped_rho * (rewards + discount * next_vals - vals) - return deltas, clipped_rho - - -@_transpose_time -def vtrace_correction( - gamma: float, - log_pi: torch.Tensor, - log_mu: torch.Tensor, - state_value: torch.Tensor, - next_state_value: torch.Tensor, - reward: torch.Tensor, - done: torch.Tensor, - rho_thresh: Union[float, torch.Tensor] = 1.0, - c_thresh: Union[float, torch.Tensor] = 1.0, - time_dim: int = -2, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Computes V-Trace off-policy actor critic targets. - - Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" - https://arxiv.org/abs/1802.01561 for more context. - - Args: - gamma (scalar): exponential mean discount. - log_pi (Tensor): collection actor log probability of taking actions in the environment. - log_mu (Tensor): current actor log probability of taking actions in the environment. - state_value (Tensor): value function result with state input. - next_state_value (Tensor): value function result with next_state input. - reward (Tensor): reward of taking actions in the environment. - done (Tensor): boolean flag for end of episode. - rho_thresh (Union[float, Tensor]): clipping parameter for importance weights. - c_thresh (Union[float, Tensor]): clipping parameter for importance weights. - time_dim (int): dimension where the time is unrolled. Defaults to -2. - - All tensors (values, reward and done) must have shape - ``[*Batch x TimeSteps x *F]``, with ``*F`` feature dimensions. 
- - """ - - if not (next_state_value.shape == state_value.shape == reward.shape == done.shape): - raise RuntimeError(SHAPE_ERR) - - dtype = next_state_value.dtype - device = state_value.device - - not_done = (~done).int() - *batch_size, time_steps, lastdim = not_done.shape - discounts = gamma * not_done - - deltas, clipped_rho = _dv_val(reward, state_value, next_state_value, discounts, rho_thresh, log_pi, log_mu) - c_thresh = c_thresh.to(device) - clipped_c = torch.min(c_thresh, clipped_rho) - - vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] - for i in reversed(range(time_steps)): - discount_t, c_t, delta_t = discounts[..., i, :], clipped_c[..., i, :], deltas[..., i, :] - vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1]) - vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:], dim=time_dim) - vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[time_dim]) - vs = vs_minus_v_xs + state_value - vs_t_plus_1 = torch.cat([vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim) - advantages = clipped_rho * (reward + discounts * vs_t_plus_1 - state_value) - - return advantages, vs - - -class VTrace(ValueEstimatorBase): - """A class wrapper around V-Trace estimate functional. - - Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" - https://arxiv.org/abs/1802.01561 for more context. - - Args: - gamma (scalar): exponential mean discount. - value_network (TensorDictModule): value operator used to retrieve the value estimates. - actor_network (TensorDictModule, optional): actor operator used to retrieve the log prob. - average_adv (bool): if ``True``, the resulting advantage values will be standardized. - Default is ``False``. - differentiable (bool, optional): if ``True``, gradients are propagated through - the computation of the value function. Default is ``False``. - - .. note:: - The proper way to make the function call non-differentiable is to - decorate it in a `torch.no_grad()` context manager/decorator or - pass detached parameters for functional modules. - skip_existing (bool, optional): if ``True``, the value network will skip - modules which outputs are already present in the tensordict. - Defaults to ``None``, ie. the value of :func:`tensordict.nn.skip_existing()` - is not affected. - Defaults to "state_value". - advantage_key (str or tuple of str, optional): [Deprecated] the key of - the advantage entry. Defaults to ``"advantage"``. - value_target_key (str or tuple of str, optional): [Deprecated] the key - of the advantage entry. Defaults to ``"value_target"``. - value_key (str or tuple of str, optional): [Deprecated] the value key to - read from the input tensordict. Defaults to ``"state_value"``. - shifted (bool, optional): if ``True``, the value and next value are - estimated with a single call to the value network. This is faster - but is only valid whenever (1) the ``"next"`` value is shifted by - only one time step (which is not the case with multi-step value - estimation, for instance) and (2) when the parameters used at time - ``t`` and ``t+1`` are identical (which is not the case when target - parameters are to be used). Defaults to ``False``. - - VTrace will return an :obj:`"advantage"` entry containing the advantage value. It will also - return a :obj:`"value_target"` entry with the V-Trace target value. - - .. 
note:: - As other advantage functions do, if the ``value_key`` is already present - in the input tensordict, the VTrace module will ignore the calls to the value - network (if any) and use the provided value instead. - - """ - - def __init__( - self, - *, - gamma: Union[float, torch.Tensor], - rho_thresh: Union[float, torch.Tensor] = 1.0, - c_thresh: Union[float, torch.Tensor] = 1.0, - actor_network: TensorDictModule = None, - value_network: TensorDictModule, - average_adv: bool = False, - differentiable: bool = False, - skip_existing: Optional[bool] = None, - log_prob_key: NestedKey = "sample_log_prob", # Consider adding it to _AcceptedKeys? - advantage_key: NestedKey = None, - value_target_key: NestedKey = None, - value_key: NestedKey = None, - shifted: bool = False, - ): - super().__init__( - shifted=shifted, - value_network=value_network, - differentiable=differentiable, - advantage_key=advantage_key, - value_target_key=value_target_key, - value_key=value_key, - skip_existing=skip_existing, - ) - try: - device = next(value_network.parameters()).device - except (AttributeError, StopIteration): - device = torch.device("cpu") - - if not isinstance(gamma, torch.Tensor): - gamma = torch.tensor(gamma, device=device) - if not isinstance(rho_thresh, torch.Tensor): - rho_thresh = torch.tensor(rho_thresh, device=device) - if not isinstance(c_thresh, torch.Tensor): - c_thresh = torch.tensor(c_thresh, device=device) - - self.register_buffer("gamma", gamma) - self.register_buffer("rho_thresh", rho_thresh) - self.register_buffer("c_thresh", c_thresh) - self.average_adv = average_adv - self.actor_network = actor_network - self._log_prob_key = log_prob_key - - if isinstance(gamma, torch.Tensor) and gamma.shape != (): - raise NotImplementedError("Per-value gamma is not supported yet. Gamma must be a scalar.") - - @property - def log_prob_key(self): - return self._log_prob_key - - @_self_set_skip_existing - @_self_set_grad_enabled - @dispatch - def forward( - self, - tensordict: TensorDictBase, - *unused_args, - params: Optional[List[Tensor]] = None, - target_params: Optional[List[Tensor]] = None, - ) -> TensorDictBase: - """Computes the V-Trace correction given the data in tensordict. - - If a functional module is provided, a nested TensorDict containing the parameters - (and if relevant the target parameters) can be passed to the module. - - Args: - tensordict (TensorDictBase): A TensorDict containing the data - (an observation key, "action", "reward", "done" and "next" tensordict state - as returned by the environment) necessary to compute the value estimates and the GAE. - The data passed to this module should be structured as :obj:`[*B, T, F]` where :obj:`B` are - the batch size, :obj:`T` the time dimension and :obj:`F` the feature dimension(s). - params (TensorDictBase, optional): A nested TensorDict containing the params - to be passed to the functional value network module. - target_params (TensorDictBase, optional): A nested TensorDict containing the - target params to be passed to the functional value network module. - - Returns: - An updated TensorDict with an advantage and a value_error keys as defined in the constructor. - - Examples: - >>> from tensordict import TensorDict - >>> value_net = TensorDictModule( - ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] - ... ) - >>> module = VTrace( - ... gamma=0.98, - ... value_network=value_net, - ... differentiable=False, - ... 
) - >>> obs, next_obs = torch.randn(2, 1, 10, 3) - >>> reward = torch.randn(1, 10, 1) - >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) - >>> _ = module(tensordict) - >>> assert "advantage" in tensordict.keys() - - The module supports non-tensordict (i.e. unpacked tensordict) inputs too: - - Examples: - >>> value_net = TensorDictModule( - ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] - ... ) - >>> module = VTrace( - ... gamma=0.98, - ... value_network=value_net, - ... differentiable=False, - ... ) - >>> obs, next_obs = torch.randn(2, 1, 10, 3) - >>> reward = torch.randn(1, 10, 1) - >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) - - """ - if tensordict.batch_dims < 1: - raise RuntimeError( - "Expected input tensordict to have at least one dimensions, got " - f"tensordict.batch_size = {tensordict.batch_size}" - ) - reward = tensordict.get(("next", self.tensor_keys.reward)) - device = reward.device - gamma = self.gamma.to(device) - steps_to_next_obs = tensordict.get(self.tensor_keys.steps_to_next_obs, None) - if steps_to_next_obs is not None: - gamma = gamma ** steps_to_next_obs.view_as(reward) - - # Make sure we have the value and next value - if self.value_network is not None: - if params is not None: - params = params.detach() - if target_params is None: - target_params = params.clone(False) - with hold_out_net(self.value_network): - # we may still need to pass gradient, but we don't want to assign grads to - # value net params - value, next_value = _call_value_nets( - value_net=self.value_network, - data=tensordict, - params=params, - next_params=target_params, - single_call=self.shifted, - value_key=self.tensor_keys.value, - detach_next=True, - ) - else: - value = tensordict.get(self.tensor_keys.value) - next_value = tensordict.get(("next", self.tensor_keys.value)) - - # Make sure we have the log prob computed at collection time - if self.log_prob_key not in tensordict.keys(): - raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") - log_mu = tensordict.get(self.log_prob_key).reshape_as(value) - - # Compute log prob with current policy - with hold_out_net(self.actor_network): - log_pi = self.actor_network( - tensordict.select(self.actor_network.in_keys) - ).get(self.log_prob_key).reshape_as(value) - - # Compute the V-Trace correction - done = tensordict.get(("next", self.tensor_keys.done)) - - adv, value_target = vtrace_correction( - gamma, - log_pi, - log_mu, - value, - next_value, - reward, - done, - rho_thresh=self.rho_thresh, - c_thresh=self.c_thresh, - time_dim=tensordict.ndim - 1, - ) - - if self.average_adv: - loc = adv.mean() - scale = adv.std().clamp_min(1e-6) - adv = adv - loc - adv = adv / scale - - tensordict.set(self.tensor_keys.advantage, adv) - tensordict.set(self.tensor_keys.value_target, value_target) - - return tensordict - From 8648e104d797a1575fd381a9504b2b4ca30df499 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 21 Sep 2023 20:13:30 +0200 Subject: [PATCH 025/109] tests --- examples/impala2/config.yaml | 4 +- examples/impala2/impala.py | 7 +- examples/impala2/utils.py | 2 +- test/test_cost.py | 94 +++++++++++++++++++++----- torchrl/objectives/a2c.py | 4 +- torchrl/objectives/common.py | 4 ++ torchrl/objectives/ppo.py | 4 +- torchrl/objectives/utils.py | 1 + torchrl/objectives/value/advantages.py | 25 +++---- 
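A note on the ``VTrace`` estimator shown above: its ``forward`` reads the reward and
done flags from the ``"next"`` sub-tensordict and also expects a ``"sample_log_prob"``
entry together with an ``actor_network``, so a call that satisfies that layout would
presumably look closer to the sketch below (the actor, key names and shapes are
illustrative only, not taken from the patch):

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.modules import NormalParamWrapper, ProbabilisticActor, TanhNormal
    # Import path as of this revision; later commits in this series move VTrace
    # into torchrl.objectives.value.advantages.
    from torchrl.objectives.value.vtrace import VTrace

    # Critic writing "state_value", as in the docstring example.
    value_net = TensorDictModule(
        nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
    )

    # Minimal stochastic actor that returns "sample_log_prob" (return_log_prob=True),
    # since forward() recomputes log_pi by running actor_network on the observations.
    actor_backbone = TensorDictModule(
        NormalParamWrapper(nn.Linear(3, 2)), in_keys=["obs"], out_keys=["loc", "scale"]
    )
    actor = ProbabilisticActor(
        actor_backbone,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    )

    module = VTrace(gamma=0.98, value_network=value_net, actor_network=actor)

    obs, next_obs = torch.randn(2, 1, 10, 3)
    reward = torch.randn(1, 10, 1)
    done = torch.zeros(1, 10, 1, dtype=torch.bool)
    sample_log_prob = torch.randn(1, 10)  # behaviour log-probs stored at collection time
    td = TensorDict(
        {
            "obs": obs,
            "sample_log_prob": sample_log_prob,
            # reward and done live under "next", which is where forward() reads them
            "next": {"obs": next_obs, "reward": reward, "done": done},
        },
        batch_size=[1, 10],
    )
    td = module(td)  # adds "advantage" and "value_target" entries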
torchrl/objectives/value/functional.py | 4 +- 10 files changed, 107 insertions(+), 42 deletions(-) diff --git a/examples/impala2/config.yaml b/examples/impala2/config.yaml index 6957fd9bddd..7eb1cd612a8 100644 --- a/examples/impala2/config.yaml +++ b/examples/impala2/config.yaml @@ -5,7 +5,7 @@ env: # collector collector: - frames_per_batch: 4096 + frames_per_batch: 1024 total_frames: 40_000_000 # logger @@ -27,7 +27,7 @@ optim: loss: gamma: 0.99 mini_batch_size: 1024 - ppo_epochs: 3 + ppo_epochs: 1 gae_lambda: 0.95 clip_epsilon: 0.1 anneal_clip_epsilon: True diff --git a/examples/impala2/impala.py b/examples/impala2/impala.py index 8bfa3080a9f..acb79379744 100644 --- a/examples/impala2/impala.py +++ b/examples/impala2/impala.py @@ -19,7 +19,7 @@ def main(cfg: "DictConfig"): # noqa: F821 import tqdm from tensordict import TensorDict - from torchrl.collectors import SyncDataCollector + from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -46,14 +46,15 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), + collector = MultiaSyncDataCollector( + create_env_fn=[make_parallel_env(cfg.env.env_name, 8, device)] * 4, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, device=device, storing_device=device, max_frames_per_traj=-1, + update_at_each_batch=True, ) # Create data buffer diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py index 305318e0551..2ce03e50c35 100644 --- a/examples/impala2/utils.py +++ b/examples/impala2/utils.py @@ -97,7 +97,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - # env.append_transform(VecNorm(in_keys=["pixels"])) + env.append_transform(VecNorm(in_keys=["pixels"])) return env diff --git a/test/test_cost.py b/test/test_cost.py index 62f8a123792..be07b19924b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -127,6 +127,7 @@ from torchrl.objectives.value.advantages import ( _call_value_nets, GAE, + VTrace, TD1Estimator, TDLambdaEstimator, ) @@ -432,7 +433,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -887,7 +888,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1336,7 +1337,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1907,7 +1908,7 @@ def test_td3( 
delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est is (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2523,7 +2524,7 @@ def test_sac( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3197,7 +3198,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3745,7 +3746,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4086,7 +4087,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4103,7 +4104,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4509,7 +4510,7 @@ def test_cql( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4918,7 +4919,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -4931,6 +4932,10 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -4984,7 +4989,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): actor.zero_grad() @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace","td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -4997,6 +5002,12 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": 
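+            # Unlike GAE, V-trace also needs the actor: the importance weights are
+            # built from the behaviour log-probs recorded at collection time
+            # ("sample_log_prob") and the log-probs under the current policy.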
+ advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5058,6 +5069,7 @@ def test_ppo_shared(self, loss_class, device, advantage): "advantage", ( "gae", + "vtrace" "td", "td_lambda", ), @@ -5077,6 +5089,12 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5128,7 +5146,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5142,6 +5160,10 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5204,6 +5226,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -5271,6 +5294,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) + if advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5559,7 +5589,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -5572,6 +5602,8 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace(gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5685,7 +5717,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if 
pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5699,6 +5731,8 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace(gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5752,6 +5786,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -5920,7 +5955,7 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) - @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) + @pytest.mark.parametrize("advantage", ["gae", "vtrace", "td", "td_lambda", None]) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 @@ -5946,6 +5981,13 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est value_network=get_functional(value_net), differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=get_functional(value_net), + actor_network=get_functional(actor_net), + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=gamma, @@ -6029,6 +6071,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6623,7 +6666,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -7333,7 +7376,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -9267,6 +9310,7 @@ class TestAdv: "adv,kwargs", [ [GAE, {"lmbda": 0.95}], + [VTrace, {}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], ], @@ -9279,6 +9323,20 @@ def test_dispatch( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) + if adv == VTrace: + n_obs = 3 + n_act = 5 + net = NormalParamWrapper(nn.Linear(n_obs, 2 * n_act)) + module = TensorDictModule( + net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor_net = ProbabilisticActor( + module, + distribution_class=TanhNormal, + return_log_prob=True, + in_keys=["loc", "scale"], + spec=UnboundedContinuousTensorSpec(n_act), + ) module = adv( gamma=0.98, value_network=value_net, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index ea2c715d927..fa1225a3ff2 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,7 +20,7 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, 
TDLambdaEstimator, VTrace class A2CLoss(LossModule): @@ -380,6 +380,8 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + self._value_estimator = VTrace(value_network=self.critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index a5fa8c592d5..a2561359ee4 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -447,6 +447,10 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." ) + elif value_type == ValueEstimators.VTrace: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." + ) elif value_type == ValueEstimators.TDLambda: raise NotImplementedError( f"Value type {value_type} it not implemented for loss {type(self)}." diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 06ccea7ff30..ac9df698df2 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -22,7 +22,7 @@ ) from .common import LossModule -from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace class PPOLoss(LossModule): @@ -462,6 +462,8 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + self._value_estimator = VTrace(value_network=self.critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index bc678ed0154..1dd3cfc5f35 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -39,6 +39,7 @@ class ValueEstimators(Enum): TD1 = "TD(1) (infinity-step return)" TDLambda = "TD(lambda)" GAE = "Generalized advantage estimate" + VTrace = "V-trace" def default_value_kwargs(value_type: ValueEstimators): diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 6ea06b01fe7..0f607b40d24 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -192,6 +192,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" steps_to_next_obs: NestedKey = "steps_to_next_obs" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() value_network: Union[TensorDictModule, Callable] @@ -333,7 +334,7 @@ def set_keys(self, **kwargs) -> None: raise ValueError("tensordict keys cannot be None") if key not in self._AcceptedKeys.__dict__: raise KeyError( - f"{key} it not an accepted tensordict key for advantages" + f"{key} it not an acceptedaccepted tensordict key for advantages" ) if ( key == "value" @@ -1273,7 +1274,9 @@ class VTrace(ValueEstimatorBase): Args: gamma (scalar): exponential mean discount. value_network (TensorDictModule): value operator used to retrieve the value estimates. - actor_network (TensorDictModule, optional): actor operator used to retrieve the log prob. 
+ actor_network (TensorDictModule): actor operator used to retrieve the log prob. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. average_adv (bool): if ``True``, the resulting advantage values will be standardized. Default is ``False``. differentiable (bool, optional): if ``True``, gradients are propagated through @@ -1316,14 +1319,13 @@ def __init__( self, *, gamma: Union[float, torch.Tensor], + actor_network: TensorDictModule, + value_network: TensorDictModule, rho_thresh: Union[float, torch.Tensor] = 1.0, c_thresh: Union[float, torch.Tensor] = 1.0, - actor_network: TensorDictModule = None, - value_network: TensorDictModule, average_adv: bool = False, differentiable: bool = False, skip_existing: Optional[bool] = None, - log_prob_key: NestedKey = "sample_log_prob", # Consider adding it to _AcceptedKeys? advantage_key: NestedKey = None, value_target_key: NestedKey = None, value_key: NestedKey = None, @@ -1355,17 +1357,12 @@ def __init__( self.register_buffer("c_thresh", c_thresh) self.average_adv = average_adv self.actor_network = actor_network - self._log_prob_key = log_prob_key if isinstance(gamma, torch.Tensor) and gamma.shape != (): raise NotImplementedError( "Per-value gamma is not supported yet. Gamma must be a scalar." ) - @property - def log_prob_key(self): - return self._log_prob_key - @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1464,15 +1461,15 @@ def forward( next_value = tensordict.get(("next", self.tensor_keys.value)) # Make sure we have the log prob computed at collection time - if self.log_prob_key not in tensordict.keys(): - raise ValueError(f"Expected {self.log_prob_key} to be in tensordict") - log_mu = tensordict.get(self.log_prob_key).view_as(value) + if self.tensor_keys.sample_log_prob not in tensordict.keys(): + raise ValueError(f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict") + log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) # Compute log prob with current policy with hold_out_net(self.actor_network): log_pi = ( self.actor_network(tensordict.select(self.actor_network.in_keys)) - .get(self.log_prob_key) + .get(self.tensor_keys.sample_log_prob) .view_as(value) ) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index f9c7af3a2ea..1f281274e1a 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1079,8 +1079,8 @@ def vtrace_advantage_estimate( next_state_value (Tensor): value function result with next_state input. reward (Tensor): reward of taking actions in the environment. done (Tensor): boolean flag for end of episode. - rho_thresh (Union[float, Tensor]): clipping parameter for importance weights. - c_thresh (Union[float, Tensor]): clipping parameter for importance weights. + rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. time_dim (int): dimension where the time is unrolled. Defaults to -2. 
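        .. note::
          Following the IMPALA paper, ``rho_thresh`` determines the fixed point the
          value estimate converges to (lower values bias it away from the target
          policy towards the behaviour policy), whereas ``c_thresh`` only affects the
          contraction speed of the operator, i.e. the variance/speed trade-off of the
          correction, not the solution itself.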
All tensors (values, reward and done) must have shape From dfc1c82653047644576088461f9e5cd05602ac45 Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 22 Sep 2023 14:22:07 +0200 Subject: [PATCH 026/109] tests --- examples/impala/config.yaml | 24 +-- examples/impala/impala.py | 151 ++++++++-------- examples/impala/utils.py | 21 ++- examples/impala2/config.yaml | 8 +- examples/impala2/impala.py | 24 ++- examples/impala2/utils.py | 4 +- examples/impala3/config.yaml | 36 ---- examples/impala3/impala.py | 239 ------------------------- examples/impala3/utils.py | 239 ------------------------- test/test_cost.py | 154 ++++++++++++++++ torchrl/objectives/value/functional.py | 4 + 11 files changed, 271 insertions(+), 633 deletions(-) delete mode 100644 examples/impala3/config.yaml delete mode 100644 examples/impala3/impala.py delete mode 100644 examples/impala3/utils.py diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 448f42c2f3e..783fe37bfc8 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -1,26 +1,24 @@ # Environment env: env_name: PongNoFrameskip-v4 + num_envs: 32 # collector collector: - frames_per_batch: 80 + frames_per_batch: 2560 total_frames: 40_000_000 - num_workers: 8 # logger logger: - backend: null - exp_name: Atari_Espeholt18 - test_interval: 40_000_000 + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 200_000_000 num_test_episodes: 3 # Optim optim: - lr: 0.00048 - eps: 0.01 - alpha: 0.99 - momentum: 0.0 + lr: 0.001 + eps: 1.0e-6 weight_decay: 0.0 max_grad_norm: 1.0 anneal_lr: True @@ -28,7 +26,11 @@ optim: # loss loss: gamma: 0.99 - batch_size: 32 + mini_batch_size: 2560 + ppo_epochs: 2 + gae_lambda: 0.95 + clip_epsilon: 0.1 + anneal_clip_epsilon: True critic_coef: 0.25 - entropy_coef: 0.05 + entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala/impala.py b/examples/impala/impala.py index bc0862df89e..3b34c9885a9 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -18,44 +18,36 @@ def main(cfg: "DictConfig"): # noqa: F821 import torch.optim import tqdm - from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector - from torchrl.collectors.distributed import RPCDataCollector + from tensordict import TensorDict + from torchrl.collectors import MultiaSyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss, ClipPPOLoss - from torchrl.objectives.value import GAE, VTrace + from torchrl.objectives import ClipPPOLoss, A2CLoss + from torchrl.objectives.value.vtrace import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models - device = "cpu" if not torch.cuda.is_available() else "cuda" + device = "cpu" if not torch.cuda.device_count() else "cuda" # Correct for frame_skip frame_skip = 4 total_frames = cfg.collector.total_frames // frame_skip frames_per_batch = cfg.collector.frames_per_batch // frame_skip + mini_batch_size = cfg.loss.mini_batch_size // frame_skip test_interval = cfg.logger.test_interval // frame_skip - # Create models (check utils.py) - actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = ( + # Create models (check utils_atari.py) + actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic, critic_head = ( actor.to(device), 
critic.to(device), + critic_head.to(device), ) # Create collector - # collector = RPCDataCollector( - # create_env_fn=[make_parallel_env(cfg.env.env_name, 1, device)] * 2, - # policy=actor, - # frames_per_batch=frames_per_batch, - # total_frames=total_frames, - # storing_device="cpu", - # max_frames_per_traj=-1, - # sync=False, - # ) collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, 1, device)] - * cfg.collector.num_workers, + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 1, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -64,22 +56,13 @@ def main(cfg: "DictConfig"): # noqa: F821 max_frames_per_traj=-1, update_at_each_batch=True, ) - # collector = SyncDataCollector( - # create_env_fn=make_parallel_env(cfg.env.env_name, cfg.loss.batch_size, device), - # policy=actor, - # frames_per_batch=frames_per_batch, - # total_frames=total_frames, - # device=device, - # storing_device=device, - # max_frames_per_traj=-1, - # ) # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(cfg.collector.frames_per_batch * cfg.loss.batch_size), + storage=LazyMemmapStorage(frames_per_batch), sampler=sampler, - batch_size=cfg.collector.frames_per_batch * cfg.loss.batch_size, + batch_size=mini_batch_size, ) # Create loss and adv modules @@ -89,6 +72,15 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_network=actor, average_adv=False, ) + # loss_module = ClipPPOLoss( + # actor=actor, + # critic=critic, + # clip_epsilon=cfg.loss.clip_epsilon, + # loss_critic_type=cfg.loss.loss_critic_type, + # entropy_coef=cfg.loss.entropy_coef, + # critic_coef=cfg.loss.critic_coef, + # normalize_advantage=True, + # ) loss_module = A2CLoss( actor=actor, critic=critic, @@ -98,13 +90,11 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizer - optim = torch.optim.RMSprop( + optim = torch.optim.Adam( loss_module.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, - momentum=cfg.optim.momentum, eps=cfg.optim.eps, - alpha=cfg.optim.alpha, ) # Create logger @@ -126,7 +116,10 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = total_frames // (frames_per_batch * cfg.loss.batch_size) + num_mini_batches = frames_per_batch // mini_batch_size + total_network_updates = ( + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + ) sampling_start = time.time() for i, data in enumerate(collector): @@ -137,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collected_frames += frames_in_batch * frame_skip pbar.update(data.numel()) - # Get train reward + # Get training rewards and episode lengths episode_rewards = data["next", "episode_reward"][data["next", "done"]] if len(episode_rewards) > 0: episode_length = data["next", "step_count"][data["next", "done"]] @@ -145,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -153,10 +146,11 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - # Accumulate data - if (i + 1) % cfg.loss.batch_size != 0: + losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + training_start = time.time() + for j in 
range(cfg.loss.ppo_epochs): - # Compute VTrace + # Compute adv with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) @@ -164,46 +158,42 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update the data buffer data_buffer.extend(data_reshape) - if logger: - for key, value in log_info.items(): - logger.log_scalar(key, value, collected_frames) - continue - - training_start = time.time() - - for k, batch in enumerate(data_buffer): + for k, batch in enumerate(data_buffer): - # Linearly decrease the learning rate and clip epsilon - alpha = 1.0 - if cfg.optim.anneal_lr: + # Linearly decrease the learning rate and clip epsilon alpha = 1 - (num_network_updates / total_network_updates) - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Get a data batch - batch = batch.to(device) - - # Forward pass A2C loss - loss = loss_module(batch) - losses = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) + if cfg.optim.anneal_lr: + for group in optim.param_groups: + group["lr"] = cfg.optim.lr * alpha + # if cfg.loss.anneal_clip_epsilon: + # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass PPO loss + loss = loss_module(batch) + losses[j, k] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + ) - # Update the networks - optim.step() - optim.zero_grad() + # Update the networks + optim.step() + optim.zero_grad() + # Get training losses and times training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) for key, value in losses.items(): log_info.update({f"train/{key}": value.item()}) log_info.update( @@ -211,14 +201,15 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, + # "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, } ) - # Test logging + # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if (collected_frames - frames_in_batch) // test_interval < ( - collected_frames // test_interval - ): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: actor.eval() eval_start = time.time() test_rewards = eval_model( @@ -227,8 +218,8 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_time = time.time() - eval_start log_info.update( { - "test/reward": test_rewards.mean(), - "test/eval_time": eval_time, + "eval/reward": test_rewards.mean(), + "eval/time": eval_time, } ) actor.train() @@ -237,9 +228,9 @@ def main(cfg: "DictConfig"): # noqa: F821 for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) + collector.update_policy_weights_() sampling_start = time.time() - collector.shutdown() end_time = time.time() execution_time = end_time - start_time print(f"Training took {execution_time:.2f} seconds to finish") diff --git a/examples/impala/utils.py 
b/examples/impala/utils.py index f25fd60613c..b43075013dc 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - import gymnasium as gym import numpy as np import torch.nn @@ -32,7 +31,6 @@ from torchrl.modules import ( ActorValueOperator, ConvNet, - LSTMModule, MLP, OneHotCategorical, ProbabilisticActor, @@ -54,13 +52,13 @@ def __init__(self, env): self.lives = 0 def step(self, action): - obs, rew, done, truncate, info = self.env.step(action) + obs, rew, done, truncated, info = self.env.step(action) lives = self.env.unwrapped.ale.lives() info["end_of_life"] = False if (lives < self.lives) or done: info["end_of_life"] = True self.lives = lives - return obs, rew, done, truncate, info + return obs, rew, done, truncated, info def reset(self, **kwargs): reset_data = self.env.reset(**kwargs) @@ -69,7 +67,7 @@ def reset(self, **kwargs): def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) if not is_test: @@ -86,12 +84,11 @@ def make_base_env( def make_parallel_env(env_name, num_envs, device, is_test=False): - env = ParallelEnv( num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) env = TransformedEnv(env) - env.append_transform(ToTensorImage()) + env.append_transform(ToTensorImage(from_int=True)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -100,7 +97,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"], eps=0.999)) + env.append_transform(VecNorm(in_keys=["pixels"], decay=0.99999, eps=1e-2)) return env @@ -207,12 +204,18 @@ def make_ppo_models(env_name): value_operator=value_module, ) + with torch.no_grad(): + td = proof_environment.rollout(max_steps=100, break_when_any_done=False) + td = actor_critic(td) + del td + actor = actor_critic.get_policy_operator() critic = actor_critic.get_value_operator() + critic_head = actor_critic.get_value_head() del proof_environment - return actor, critic + return actor, critic, critic_head # ==================================================================== diff --git a/examples/impala2/config.yaml b/examples/impala2/config.yaml index 7eb1cd612a8..9462d2e52ed 100644 --- a/examples/impala2/config.yaml +++ b/examples/impala2/config.yaml @@ -1,7 +1,7 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 8 + num_envs: 2 # collector collector: @@ -12,7 +12,7 @@ collector: logger: backend: wandb exp_name: Atari_Schulman17 - test_interval: 40_000_000 + test_interval: 200_000_000 num_test_episodes: 3 # Optim @@ -27,10 +27,10 @@ optim: loss: gamma: 0.99 mini_batch_size: 1024 - ppo_epochs: 1 + ppo_epochs: 3 gae_lambda: 0.95 clip_epsilon: 0.1 anneal_clip_epsilon: True - critic_coef: 1.0 + critic_coef: 0.5 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala2/impala.py b/examples/impala2/impala.py index acb79379744..69923864c97 100644 --- a/examples/impala2/impala.py +++ b/examples/impala2/impala.py @@ -23,7 +23,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import 
SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import ClipPPOLoss + from torchrl.objectives import ClipPPOLoss, A2CLoss from torchrl.objectives.value.vtrace import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models @@ -47,7 +47,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, 8, device)] * 4, + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 4, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -72,14 +72,12 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_network=actor, average_adv=False, ) - loss_module = ClipPPOLoss( + loss_module = A2CLoss( actor=actor, critic=critic, - clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, - normalize_advantage=True, ) # Create optimizer @@ -111,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches ) sampling_start = time.time() @@ -131,7 +129,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -143,7 +141,7 @@ def main(cfg: "DictConfig"): # noqa: F821 training_start = time.time() for j in range(cfg.loss.ppo_epochs): - # Compute GAE + # Compute adv with torch.no_grad(): data = adv_module(data) data_reshape = data.reshape(-1) @@ -158,8 +156,8 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optim.anneal_lr: for group in optim.param_groups: group["lr"] = cfg.optim.lr * alpha - if cfg.loss.anneal_clip_epsilon: - loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) + # if cfg.loss.anneal_clip_epsilon: + # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) num_network_updates += 1 # Get a data batch @@ -171,7 +169,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) # Backward pass @@ -194,14 +192,14 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, - "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, + # "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, } ) # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip + i * frames_in_batch * frame_skip ) // test_interval: actor.eval() eval_start = time.time() diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py index 2ce03e50c35..9d3af001416 100644 --- a/examples/impala2/utils.py +++ b/examples/impala2/utils.py @@ -88,7 +88,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) env = TransformedEnv(env) - 
env.append_transform(ToTensorImage(from_int=False)) + env.append_transform(ToTensorImage(from_int=True)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -97,7 +97,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"])) + env.append_transform(VecNorm(in_keys=["pixels"], decay=0.9999, eps=1e-3)) return env diff --git a/examples/impala3/config.yaml b/examples/impala3/config.yaml deleted file mode 100644 index dacfdcb1bf3..00000000000 --- a/examples/impala3/config.yaml +++ /dev/null @@ -1,36 +0,0 @@ -# Environment -env: - env_name: PongNoFrameskip-v4 - num_envs: 8 - -# collector -collector: - frames_per_batch: 2560 - total_frames: 40_000_000 - -# logger -logger: - backend: wandb - exp_name: Atari_Schulman17 - test_interval: 40_000_000 - num_test_episodes: 3 - -# Optim -optim: - lr: 0.0006 - eps: 1.0e-6 - weight_decay: 0.0 - max_grad_norm: 0.5 - anneal_lr: True - -# loss -loss: - gamma: 0.99 - mini_batch_size: 2560 - ppo_epochs: 1 - gae_lambda: 0.95 - clip_epsilon: 0.1 - anneal_clip_epsilon: True - critic_coef: 0.5 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/impala3/impala.py b/examples/impala3/impala.py deleted file mode 100644 index 15cb7808332..00000000000 --- a/examples/impala3/impala.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -This script reproduces the IMPALA Algorithm -results from Espeholt et al. 2018 for the on Atari Environments. 
-""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import time - - import torch.optim - import tqdm - - from tensordict import TensorDict - from torchrl.collectors import SyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer - from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement - from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss, ClipPPOLoss - from torchrl.objectives.value.vtrace import VTrace - from torchrl.record.loggers import generate_exp_name, get_logger - from utils import eval_model, make_parallel_env, make_ppo_models - - device = "cpu" if not torch.cuda.device_count() else "cuda" - - # Correct for frame_skip - frame_skip = 4 - total_frames = cfg.collector.total_frames // frame_skip - frames_per_batch = cfg.collector.frames_per_batch // frame_skip - mini_batch_size = cfg.loss.mini_batch_size // frame_skip - test_interval = cfg.logger.test_interval // frame_skip - - # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = SyncDataCollector( - create_env_fn=make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device), - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - ) - - # Create data buffer - sampler = SamplerWithoutReplacement() - data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), - sampler=sampler, - batch_size=mini_batch_size, - ) - - # Create loss and adv modules - adv_module = VTrace( - gamma=cfg.loss.gamma, - value_network=critic, - actor_network=actor, - average_adv=True, - ) - # loss_module = ClipPPOLoss( - # actor=actor, - # critic=critic, - # clip_epsilon=cfg.loss.clip_epsilon, - # loss_critic_type=cfg.loss.loss_critic_type, - # entropy_coef=cfg.loss.entropy_coef, - # critic_coef=cfg.loss.critic_coef, - # normalize_advantage=True, - # ) - loss_module = A2CLoss( - actor=actor, - critic=critic, - loss_critic_type=cfg.loss.loss_critic_type, - entropy_coef=cfg.loss.entropy_coef, - critic_coef=cfg.loss.critic_coef, - ) - - # Create optimizer - optim = torch.optim.Adam( - loss_module.parameters(), - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.eps, - ) - - # Create logger - logger = None - if cfg.logger.backend: - exp_name = generate_exp_name( - "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" - ) - logger = get_logger( - cfg.logger.backend, logger_name="impala", experiment_name=exp_name - ) - - # Create test environment - test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) - test_env.eval() - - # Main loop - collected_frames = 0 - num_network_updates = 0 - start_time = time.time() - pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches - ) - - sampling_start = time.time() - for i, data in enumerate(collector): - - log_info = {} - sampling_time = time.time() - sampling_start - frames_in_batch = data.numel() - collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) - - # Get training rewards and episode lengths - episode_rewards = data["next", 
"episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) - - # Apply episodic end of life - data["done"].copy_(data["end_of_life"]) - data["next", "done"].copy_(data["next", "end_of_life"]) - - losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) - training_start = time.time() - for j in range(cfg.loss.ppo_epochs): - - # Compute GAE - with torch.no_grad(): - data = adv_module(data) - data_reshape = data.reshape(-1) - - # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Linearly decrease the learning rate and clip epsilon - alpha = 1 - (num_network_updates / total_network_updates) - if cfg.optim.anneal_lr: - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - # if cfg.loss.anneal_clip_epsilon: - # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) - num_network_updates += 1 - - # Get a data batch - batch = batch.to(device) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() - - # Get training losses and times - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) - for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( - { - "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, - "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, - } - ) - - # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip - ) // test_interval: - actor.eval() - eval_start = time.time() - test_rewards = eval_model( - actor, test_env, num_episodes=cfg.logger.num_test_episodes - ) - eval_time = time.time() - eval_start - log_info.update( - { - "eval/reward": test_rewards.mean(), - "eval/time": eval_time, - } - ) - actor.train() - - if logger: - for key, value in log_info.items(): - logger.log_scalar(key, value, collected_frames) - - collector.update_policy_weights_() - sampling_start = time.time() - - end_time = time.time() - execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") - - -if __name__ == "__main__": - main() diff --git a/examples/impala3/utils.py b/examples/impala3/utils.py deleted file mode 100644 index ccf6978a340..00000000000 --- a/examples/impala3/utils.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -import gymnasium as gym -import numpy as np -import torch.nn -import torch.optim -from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - default_info_dict_reader, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ParallelEnv, - Resize, - RewardClipping, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, - VecNorm, -) -from torchrl.envs.libs.gym import GymWrapper -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) - -# ==================================================================== -# Environment utils -# -------------------------------------------------------------------- - - -class EpisodicLifeEnv(gym.Wrapper): - def __init__(self, env): - """Make end-of-life == end-of-episode, but only reset on true game over. - Done by DeepMind for the DQN and co. It helps value estimation. - """ - gym.Wrapper.__init__(self, env) - self.lives = 0 - - def step(self, action): - obs, rew, done, truncated, info = self.env.step(action) - lives = self.env.unwrapped.ale.lives() - info["end_of_life"] = False - if (lives < self.lives) or done: - info["end_of_life"] = True - self.lives = lives - return obs, rew, done, truncated, info - - def reset(self, **kwargs): - reset_data = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return reset_data - - -def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False -): - env = gym.make(env_name) - if not is_test: - env = EpisodicLifeEnv(env) - env = GymWrapper( - env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device - ) - env = TransformedEnv(env) - env.append_transform(NoopResetEnv(noops=30, random=True)) - if not is_test: - reader = default_info_dict_reader(["end_of_life"]) - env.set_info_dict_reader(reader) - return env - - -def make_parallel_env(env_name, num_envs, device, is_test=False): - env = ParallelEnv( - num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) - ) - env = TransformedEnv(env) - env.append_transform(ToTensorImage()) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter(max_steps=4500)) - if not is_test: - env.append_transform(RewardClipping(-1, 1)) - env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"])) - return env - - -# ==================================================================== -# Model utils -# -------------------------------------------------------------------- - - -def make_ppo_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn 
= ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - activation_class=torch.nn.ReLU, - num_cells=[], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - activation_class=torch.nn.ReLU, - in_features=common_mlp_output.shape[-1], - out_features=1, - num_cells=[], - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -def make_ppo_models(env_name): - - proof_environment = make_parallel_env(env_name, 1, device="cpu") - common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ) - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() - - del proof_environment - - return actor, critic, critic_head - - -# ==================================================================== -# Evaluation utils -# -------------------------------------------------------------------- - - -def eval_model(actor, test_env, num_episodes=3): - test_rewards = [] - for _ in range(num_episodes): - td_test = test_env.rollout( - policy=actor, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - max_steps=10_000_000, - ) - reward = td_test["next", "episode_reward"][td_test["next", "done"]] - test_rewards = np.append(test_rewards, reward.cpu().numpy()) - del td_test - return test_rewards.mean() diff --git a/test/test_cost.py b/test/test_cost.py index be07b19924b..5eb93681183 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -134,6 +134,7 @@ from torchrl.objectives.value.functional import ( _transpose_time, generalized_advantage_estimate, + vtrace_advantage_estimate, td0_advantage_estimate, td1_advantage_estimate, td_lambda_advantage_estimate, @@ -8568,6 +8569,158 @@ def test_gae_multidim( torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("rho_thresh", [1.0, 0.5]) 
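For reference, the quantity these tests exercise is the V-trace target from Espeholt et al. (2018): importance ratios rho = exp(log_pi - log_mu) are clipped at rho_thresh when forming the temporal-difference terms and at c_thresh when forming the trace coefficients, and the corrected values are accumulated by a backward recursion over the trajectory. The snippet below is a minimal single-trajectory sketch under assumed shapes [T, 1]; the function name, shapes and return convention are illustrative only and need not match the exact torchrl signature.

    import torch

    def vtrace_reference(gamma, log_pi, log_mu, value, next_value, reward, done,
                         rho_thresh=1.0, c_thresh=1.0):
        # All tensors have shape [T, 1]; done is boolean.
        not_done = (~done).float()
        discounts = gamma * not_done
        rho = (log_pi - log_mu).exp()
        clipped_rho = rho.clamp_max(rho_thresh)   # rho_bar in the paper
        clipped_c = rho.clamp_max(c_thresh)       # c_bar in the paper
        deltas = clipped_rho * (reward + discounts * next_value - value)
        # Backward recursion: vs_t - V(x_t) = delta_t + gamma_t * c_t * (vs_{t+1} - V(x_{t+1}))
        vs_minus_v = torch.zeros_like(value)
        acc = torch.zeros_like(value[-1])
        for t in reversed(range(value.shape[0])):
            acc = deltas[t] + discounts[t] * clipped_c[t] * acc
            vs_minus_v[t] = acc
        vs = vs_minus_v + value
        # Bootstrap the last step with V(x_{T+1}), then form the policy-gradient advantage.
        vs_next = torch.cat([vs[1:], next_value[-1:]], dim=0)
        advantage = clipped_rho * (reward + discounts * vs_next - value)
        return advantage, vs

A call such as advantage, vs = vtrace_reference(0.99, log_pi, log_mu, v, next_v, r, done) can then be compared elementwise against a vectorized implementation.
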
+ @pytest.mark.parametrize("c_thresh", [1.0, 0.5]) + @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) + @pytest.mark.parametrize("T", [200, 5, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("has_done", [True, False]) + def test_vtrace(self, device, gamma, N, T, dtype, has_done, rho_thresh, c_thresh): + torch.manual_seed(0) + + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + if has_done: + done = done.bernoulli_(0.1) + reward = torch.randn(*N, T, 1, device=device, dtype=dtype) + state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + log_pi = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + + r1 = vtrace_advantage_estimate( + gamma, log_pi, log_mu, state_value, next_state_value, reward, done, rho_thresh, c_thresh + ) + + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("N", [(1,), (8,), (7, 3)]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("has_done", [True, False]) + @pytest.mark.parametrize( + "gamma_tensor", ["scalar", "tensor", "tensor_single_element"] + ) + @pytest.mark.parametrize( + "rho_thresh_tensor", ["scalar", "tensor", "tensor_single_element"] + ) + @pytest.mark.parametrize( + "c_thresh_tensor", ["scalar", "tensor", "tensor_single_element"] + ) + def test_vtrace_param_as_tensor( + self, device, N, dtype, has_done, gamma_tensor, rho_thresh_tensor, c_thresh_tensor + ): + torch.manual_seed(0) + + gamma = 0.95 + lmbda = 0.90 + T = 200 + + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + if has_done: + done = done.bernoulli_(0.1) + reward = torch.randn(*N, T, 1, device=device, dtype=dtype) + state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + log_pi = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + + if gamma_tensor == "tensor": + gamma_vec = torch.full_like(reward, gamma) + elif gamma_tensor == "tensor_single_element": + gamma_vec = torch.as_tensor([gamma], device=device) + else: + gamma_vec = gamma + + if rho_thresh_tensor == "tensor": + rho_thresh_tensor_vec = torch.full_like(reward, rho_thresh_tensor) + elif rho_thresh_tensor == "tensor_single_element": + rho_thresh_tensor_vec = torch.as_tensor([rho_thresh_tensor], device=device) + else: + rho_thresh_tensor_vec = rho_thresh_tensor + + if c_thresh_tensor == "tensor": + c_thresh_tensor_vec = torch.full_like(reward, c_thresh_tensor) + elif c_thresh_tensor == "tensor_single_element": + c_thresh_tensor_vec = torch.as_tensor([c_thresh_tensor], device=device) + else: + c_thresh_tensor_vec = c_thresh_tensor + + r1 = vtrace_advantage_estimate( + gamma_vec, log_pi, log_mu, state_value, next_state_value, reward, done, rho_thresh_tensor_vec, c_thresh_tensor_vec + ) + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(3,), (7, 3)]) + @pytest.mark.parametrize("T", [100, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("feature_dim", [[5], [2, 5]]) + @pytest.mark.parametrize("has_done", [True, False]) + def test_vtrace_multidim( + self, device, gamma, N, T, dtype, has_done, feature_dim + ): + D = 
feature_dim + time_dim = -1 - len(D) + + torch.manual_seed(0) + + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + if has_done: + done = done.bernoulli_(0.1) + reward = torch.randn(*N, T, *D, device=device, dtype=dtype) + state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + log_pi = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.randn(*N, T, 1, device=device, dtype=dtype)) + + r1 = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done, + time_dim=time_dim, + ) + + if len(D) == 2: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1, j], + log_mu[..., i : i + 1, j], + state_value[..., i : i + 1, j], + next_state_value[..., i : i + 1, j], + reward[..., i : i + 1, j], + done[..., i : i + 1, j], + time_dim=-2, + ) + for i in range(D[0]) + for j in range(D[1]) + ] + else: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1], + log_mu[..., i : i + 1], + state_value[..., i : i + 1], + next_state_value[..., i : i + 1], + reward[..., i : i + 1], + done[..., i : i + 1], + time_dim=-2, + ) + for i in range(D[0]) + ] + + list2 = list(zip(*r2)) + r2 = [torch.cat(list2[0], -1), torch.cat(list2[1], -1)] + if len(D) == 2: + r2 = [r2[0].unflatten(-1, D), r2[1].unflatten(-1, D)] + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @@ -9337,6 +9490,7 @@ def test_dispatch( in_keys=["loc", "scale"], spec=UnboundedContinuousTensorSpec(n_act), ) + kwargs["actor_network"] = actor_net module = adv( gamma=0.98, value_network=value_net, diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 1f281274e1a..7c6314117a5 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1053,6 +1053,10 @@ def vec_td_lambda_advantage_estimate( ) +######################################################################## +# V-Trace +# ----- + @_transpose_time def vtrace_advantage_estimate( gamma: float, From a568378fa83bac00785d42595daa841e56169614 Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 22 Sep 2023 18:15:28 +0200 Subject: [PATCH 027/109] fix --- examples/impala/config.yaml | 13 ++++---- examples/impala/impala.py | 17 ++-------- examples/impala/utils.py | 5 +-- examples/impala2/config.yaml | 23 +++++++------ examples/impala2/impala.py | 45 ++++++++++++++------------ examples/impala2/utils.py | 7 ++-- torchrl/objectives/value/functional.py | 5 +-- 7 files changed, 55 insertions(+), 60 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index 783fe37bfc8..ad5bd4d08f2 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -6,7 +6,7 @@ env: # collector collector: frames_per_batch: 2560 - total_frames: 40_000_000 + total_frames: 200_000_000 # logger logger: @@ -17,9 +17,11 @@ logger: # Optim optim: - lr: 0.001 - eps: 1.0e-6 + lr: 0.0006 + eps: 0.01 weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 max_grad_norm: 1.0 anneal_lr: True @@ -27,10 +29,7 @@ optim: loss: gamma: 0.99 mini_batch_size: 2560 - ppo_epochs: 2 - gae_lambda: 0.95 - clip_epsilon: 0.1 - anneal_clip_epsilon: True + ppo_epochs: 1 critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala/impala.py 
b/examples/impala/impala.py index 3b34c9885a9..9e0790bbd2c 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -24,7 +24,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import ClipPPOLoss, A2CLoss - from torchrl.objectives.value.vtrace import VTrace + from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models @@ -72,15 +72,6 @@ def main(cfg: "DictConfig"): # noqa: F821 actor_network=actor, average_adv=False, ) - # loss_module = ClipPPOLoss( - # actor=actor, - # critic=critic, - # clip_epsilon=cfg.loss.clip_epsilon, - # loss_critic_type=cfg.loss.loss_critic_type, - # entropy_coef=cfg.loss.entropy_coef, - # critic_coef=cfg.loss.critic_coef, - # normalize_advantage=True, - # ) loss_module = A2CLoss( actor=actor, critic=critic, @@ -90,11 +81,12 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizer - optim = torch.optim.Adam( + optim = torch.optim.RMSprop( loss_module.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + alpha=cfg.optim.alpha, ) # Create logger @@ -165,8 +157,6 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optim.anneal_lr: for group in optim.param_groups: group["lr"] = cfg.optim.lr * alpha - # if cfg.loss.anneal_clip_epsilon: - # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) num_network_updates += 1 # Get a data batch @@ -201,7 +191,6 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, - # "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, } ) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index b43075013dc..6d2c686aec6 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -26,6 +26,7 @@ ToTensorImage, TransformedEnv, VecNorm, + ObservationNorm ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( @@ -88,7 +89,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) env = TransformedEnv(env) - env.append_transform(ToTensorImage(from_int=True)) + env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -97,7 +98,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"], decay=0.99999, eps=1e-2)) + env.append_transform(ObservationNorm(in_keys=['pixels'], scale=1/255., loc=-0.5, standard_normal=False)) return env diff --git a/examples/impala2/config.yaml b/examples/impala2/config.yaml index 9462d2e52ed..a0e6e3b4651 100644 --- a/examples/impala2/config.yaml +++ b/examples/impala2/config.yaml @@ -1,12 +1,12 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 2 + num_envs: 4 # collector collector: - frames_per_batch: 1024 - total_frames: 40_000_000 + frames_per_batch: 320 + total_frames: 200_000_000 # logger logger: @@ -17,20 +17,19 @@ logger: # Optim optim: - lr: 2.5e-4 - eps: 1.0e-6 + lr: 0.0006 + eps: 0.01 weight_decay: 0.0 - max_grad_norm: 0.5 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 1.0 
anneal_lr: True # loss loss: gamma: 0.99 - mini_batch_size: 1024 - ppo_epochs: 3 - gae_lambda: 0.95 - clip_epsilon: 0.1 - anneal_clip_epsilon: True - critic_coef: 0.5 + batch_size: 8 + sgd_updates: 1 + critic_coef: 0.25 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala2/impala.py b/examples/impala2/impala.py index 69923864c97..298af078187 100644 --- a/examples/impala2/impala.py +++ b/examples/impala2/impala.py @@ -24,7 +24,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import ClipPPOLoss, A2CLoss - from torchrl.objectives.value.vtrace import VTrace + from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models @@ -47,7 +47,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 4, + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 8, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -60,9 +60,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), + storage=LazyMemmapStorage(frames_per_batch * cfg.loss.batch_size), sampler=sampler, - batch_size=mini_batch_size, + batch_size=mini_batch_size * cfg.loss.batch_size, ) # Create loss and adv modules @@ -81,11 +81,12 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizer - optim = torch.optim.Adam( + optim = torch.optim.RMSprop( loss_module.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay, eps=cfg.optim.eps, + alpha=cfg.optim.alpha, ) # Create logger @@ -107,11 +108,9 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches - ) + total_network_updates = (total_frames // frames_per_batch * cfg.loss.batch_size) * cfg.loss.sgd_updates + accumulator = [] sampling_start = time.time() for i, data in enumerate(collector): @@ -137,17 +136,25 @@ def main(cfg: "DictConfig"): # noqa: F821 data["done"].copy_(data["end_of_life"]) data["next", "done"].copy_(data["next", "end_of_life"]) - losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) + if len(accumulator) < cfg.loss.batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[cfg.loss.sgd_updates, 1]) training_start = time.time() - for j in range(cfg.loss.ppo_epochs): + for j in range(cfg.loss.sgd_updates): + + for acc_data in accumulator: - # Compute adv - with torch.no_grad(): - data = adv_module(data) - data_reshape = data.reshape(-1) + with torch.no_grad(): + acc_data = adv_module(acc_data) + acc_data_reshape = acc_data.reshape(-1) - # Update the data buffer - data_buffer.extend(data_reshape) + # Update the data buffer + data_buffer.extend(acc_data_reshape) for k, batch in enumerate(data_buffer): @@ -156,8 +163,6 @@ def main(cfg: "DictConfig"): # noqa: F821 if cfg.optim.anneal_lr: 
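The training-loop change in this commit follows an accumulate-then-update pattern: each batch coming out of the asynchronous collector is appended to an accumulator, and only once batch_size rollouts have been gathered are advantages recomputed with the current policy and sgd_updates optimization passes run, after which the accumulator is cleared. A minimal stand-alone sketch of that control flow is given below; the function name is illustrative, random tensors stand in for collector output, and a squared-mean placeholder replaces the A2C loss, so nothing in it touches TorchRL.

    import torch

    def accumulate_then_update(num_iters=6, batch_size=3, sgd_updates=1):
        accumulator = []
        for i in range(num_iters):
            rollout = torch.randn(80, 4)              # one collector batch
            accumulator.append(rollout)
            if len(accumulator) < batch_size:
                continue                              # keep collecting
            for _ in range(sgd_updates):
                # In the real script, the V-trace module is applied here with the
                # current policy before computing the loss on each accumulated rollout.
                batch = torch.cat(accumulator)        # all accumulated rollouts
                loss = (batch ** 2).mean()            # placeholder loss
                print(f"iter {i}: update on {batch.shape[0]} frames, loss {loss.item():.3f}")
            accumulator = []                          # start a new accumulation window

    accumulate_then_update()
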
for group in optim.param_groups: group["lr"] = cfg.optim.lr * alpha - # if cfg.loss.anneal_clip_epsilon: - # loss_module.clip_epsilon.copy_(cfg.loss.clip_epsilon * alpha) num_network_updates += 1 # Get a data batch @@ -192,7 +197,6 @@ def main(cfg: "DictConfig"): # noqa: F821 "train/lr": alpha * cfg.optim.lr, "train/sampling_time": sampling_time, "train/training_time": training_time, - # "train/clip_epsilon": alpha * cfg.loss.clip_epsilon, } ) @@ -221,6 +225,7 @@ def main(cfg: "DictConfig"): # noqa: F821 collector.update_policy_weights_() sampling_start = time.time() + accumulator = [] end_time = time.time() execution_time = end_time - start_time diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py index 9d3af001416..6d2c686aec6 100644 --- a/examples/impala2/utils.py +++ b/examples/impala2/utils.py @@ -26,6 +26,7 @@ ToTensorImage, TransformedEnv, VecNorm, + ObservationNorm ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( @@ -67,7 +68,7 @@ def reset(self, **kwargs): def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) if not is_test: @@ -88,7 +89,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) ) env = TransformedEnv(env) - env.append_transform(ToTensorImage(from_int=True)) + env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) @@ -97,7 +98,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"], decay=0.9999, eps=1e-3)) + env.append_transform(ObservationNorm(in_keys=['pixels'], scale=1/255., loc=-0.5, standard_normal=False)) return env diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 7c6314117a5..b5dde0cec79 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1099,10 +1099,11 @@ def vtrace_advantage_estimate( *batch_size, time_steps, lastdim = not_done.shape discounts = gamma * not_done - clipped_rho = (log_pi - log_mu).exp().clamp_max(rho_thresh) + rho = (log_pi - log_mu).exp() + clipped_rho = rho.clamp_max(rho_thresh) deltas = clipped_rho * (reward + discounts * next_state_value - state_value) c_thresh = c_thresh.to(device) - clipped_c = torch.clamp(c_thresh, max=clipped_rho) + clipped_c = torch.clamp(rho, max=c_thresh) vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] for i in reversed(range(time_steps)): From 596c6cc4357b90698378d655a76b49992a6f8aa6 Mon Sep 17 00:00:00 2001 From: albert bou Date: Fri, 22 Sep 2023 18:26:41 +0200 Subject: [PATCH 028/109] format --- examples/impala/impala.py | 13 +++++++------ examples/impala/utils.py | 11 +++++++---- examples/impala2/impala.py | 15 +++++++++------ examples/impala2/utils.py | 11 +++++++---- torchrl/objectives/a2c.py | 8 +++++++- torchrl/objectives/value/advantages.py | 4 +++- torchrl/objectives/value/functional.py | 3 ++- 7 files changed, 42 insertions(+), 23 deletions(-) diff --git a/examples/impala/impala.py b/examples/impala/impala.py index 9e0790bbd2c..4b00c6cdd58 100644 --- a/examples/impala/impala.py +++ b/examples/impala/impala.py @@ -23,7 +23,7 @@ def main(cfg: 
"DictConfig"): # noqa: F821 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import ClipPPOLoss, A2CLoss + from torchrl.objectives import A2CLoss from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models @@ -47,7 +47,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 1, + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] + * 1, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -110,7 +111,7 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar = tqdm.tqdm(total=total_frames) num_mini_batches = frames_per_batch // mini_batch_size total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches + (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches ) sampling_start = time.time() @@ -130,7 +131,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -168,7 +169,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) # Backward pass @@ -197,7 +198,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip + i * frames_in_batch * frame_skip ) // test_interval: actor.eval() eval_start = time.time() diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 6d2c686aec6..a7d478e54b1 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -18,6 +18,7 @@ ExplorationType, GrayScale, NoopResetEnv, + ObservationNorm, ParallelEnv, Resize, RewardClipping, @@ -25,8 +26,6 @@ StepCounter, ToTensorImage, TransformedEnv, - VecNorm, - ObservationNorm ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( @@ -68,7 +67,7 @@ def reset(self, **kwargs): def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) if not is_test: @@ -98,7 +97,11 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(ObservationNorm(in_keys=['pixels'], scale=1/255., loc=-0.5, standard_normal=False)) + env.append_transform( + ObservationNorm( + in_keys=["pixels"], scale=1 / 255.0, loc=-0.5, standard_normal=False + ) + ) return env diff --git a/examples/impala2/impala.py b/examples/impala2/impala.py index 298af078187..da920d347bc 100644 --- a/examples/impala2/impala.py +++ b/examples/impala2/impala.py @@ -23,7 +23,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from 
torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import ClipPPOLoss, A2CLoss + from torchrl.objectives import A2CLoss from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_parallel_env, make_ppo_models @@ -47,7 +47,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] * 8, + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] + * 8, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -108,7 +109,9 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = (total_frames // frames_per_batch * cfg.loss.batch_size) * cfg.loss.sgd_updates + total_network_updates = ( + total_frames // frames_per_batch * cfg.loss.batch_size + ) * cfg.loss.sgd_updates accumulator = [] sampling_start = time.time() @@ -128,7 +131,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -174,7 +177,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) # Backward pass @@ -203,7 +206,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip + i * frames_in_batch * frame_skip ) // test_interval: actor.eval() eval_start = time.time() diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py index 6d2c686aec6..a7d478e54b1 100644 --- a/examples/impala2/utils.py +++ b/examples/impala2/utils.py @@ -18,6 +18,7 @@ ExplorationType, GrayScale, NoopResetEnv, + ObservationNorm, ParallelEnv, Resize, RewardClipping, @@ -25,8 +26,6 @@ StepCounter, ToTensorImage, TransformedEnv, - VecNorm, - ObservationNorm ) from torchrl.envs.libs.gym import GymWrapper from torchrl.modules import ( @@ -68,7 +67,7 @@ def reset(self, **kwargs): def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False + env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) if not is_test: @@ -98,7 +97,11 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(ObservationNorm(in_keys=['pixels'], scale=1/255., loc=-0.5, standard_normal=False)) + env.append_transform( + ObservationNorm( + in_keys=["pixels"], scale=1 / 255.0, loc=-0.5, standard_normal=False + ) + ) return env diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index fa1225a3ff2..b0e9ea3d4d3 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -20,7 +20,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + 
TD1Estimator, + TDLambdaEstimator, + VTrace, +) class A2CLoss(LossModule): diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 0f607b40d24..dc3524f6ad5 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1462,7 +1462,9 @@ def forward( # Make sure we have the log prob computed at collection time if self.tensor_keys.sample_log_prob not in tensordict.keys(): - raise ValueError(f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict") + raise ValueError( + f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict" + ) log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value) # Compute log prob with current policy diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index b5dde0cec79..a44e045d059 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1057,6 +1057,7 @@ def vec_td_lambda_advantage_estimate( # V-Trace # ----- + @_transpose_time def vtrace_advantage_estimate( gamma: float, @@ -1103,7 +1104,7 @@ def vtrace_advantage_estimate( clipped_rho = rho.clamp_max(rho_thresh) deltas = clipped_rho * (reward + discounts * next_state_value - state_value) c_thresh = c_thresh.to(device) - clipped_c = torch.clamp(rho, max=c_thresh) + clipped_c = rho.clamp_max(c_thresh) vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] for i in reversed(range(time_steps)): From ee692f50a7a8902ce9fbe442b78e9845abcdb8bc Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 17:03:18 +0200 Subject: [PATCH 029/109] working impala script --- examples/impala/README.md | 0 examples/impala/config.yaml | 13 +- examples/impala/impala.py | 230 ----------------- .../impala_single_node.py} | 64 ++--- examples/impala/utils.py | 70 +++-- examples/impala2/config.yaml | 35 --- examples/impala2/utils.py | 243 ------------------ 7 files changed, 74 insertions(+), 581 deletions(-) create mode 100644 examples/impala/README.md delete mode 100644 examples/impala/impala.py rename examples/{impala2/impala.py => impala/impala_single_node.py} (81%) delete mode 100644 examples/impala2/config.yaml delete mode 100644 examples/impala2/utils.py diff --git a/examples/impala/README.md b/examples/impala/README.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index ad5bd4d08f2..e55827560c8 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -1,12 +1,13 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 32 + num_envs: 4 # collector collector: - frames_per_batch: 2560 + frames_per_batch: 320 total_frames: 200_000_000 + num_workers: 8 # logger logger: @@ -18,7 +19,7 @@ logger: # Optim optim: lr: 0.0006 - eps: 0.01 + eps: 1e-8 weight_decay: 0.0 momentum: 0.0 alpha: 0.99 @@ -28,8 +29,8 @@ optim: # loss loss: gamma: 0.99 - mini_batch_size: 2560 - ppo_epochs: 1 - critic_coef: 0.25 + batch_size: 8 + sgd_updates: 1 + critic_coef: 0.5 entropy_coef: 0.01 loss_critic_type: l2 diff --git a/examples/impala/impala.py b/examples/impala/impala.py deleted file mode 100644 index 4b00c6cdd58..00000000000 --- a/examples/impala/impala.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
- -""" -This script reproduces the IMPALA Algorithm -results from Espeholt et al. 2018 for the on Atari Environments. -""" -import hydra - - -@hydra.main(config_path=".", config_name="config", version_base="1.1") -def main(cfg: "DictConfig"): # noqa: F821 - - import time - - import torch.optim - import tqdm - - from tensordict import TensorDict - from torchrl.collectors import MultiaSyncDataCollector - from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer - from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement - from torchrl.envs import ExplorationType, set_exploration_type - from torchrl.objectives import A2CLoss - from torchrl.objectives.value import VTrace - from torchrl.record.loggers import generate_exp_name, get_logger - from utils import eval_model, make_parallel_env, make_ppo_models - - device = "cpu" if not torch.cuda.device_count() else "cuda" - - # Correct for frame_skip - frame_skip = 4 - total_frames = cfg.collector.total_frames // frame_skip - frames_per_batch = cfg.collector.frames_per_batch // frame_skip - mini_batch_size = cfg.loss.mini_batch_size // frame_skip - test_interval = cfg.logger.test_interval // frame_skip - - # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) - - # Create collector - collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] - * 1, - policy=actor, - frames_per_batch=frames_per_batch, - total_frames=total_frames, - device=device, - storing_device=device, - max_frames_per_traj=-1, - update_at_each_batch=True, - ) - - # Create data buffer - sampler = SamplerWithoutReplacement() - data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch), - sampler=sampler, - batch_size=mini_batch_size, - ) - - # Create loss and adv modules - adv_module = VTrace( - gamma=cfg.loss.gamma, - value_network=critic, - actor_network=actor, - average_adv=False, - ) - loss_module = A2CLoss( - actor=actor, - critic=critic, - loss_critic_type=cfg.loss.loss_critic_type, - entropy_coef=cfg.loss.entropy_coef, - critic_coef=cfg.loss.critic_coef, - ) - - # Create optimizer - optim = torch.optim.RMSprop( - loss_module.parameters(), - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.eps, - alpha=cfg.optim.alpha, - ) - - # Create logger - logger = None - if cfg.logger.backend: - exp_name = generate_exp_name( - "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" - ) - logger = get_logger( - cfg.logger.backend, logger_name="impala", experiment_name=exp_name - ) - - # Create test environment - test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) - test_env.eval() - - # Main loop - collected_frames = 0 - num_network_updates = 0 - start_time = time.time() - pbar = tqdm.tqdm(total=total_frames) - num_mini_batches = frames_per_batch // mini_batch_size - total_network_updates = ( - (total_frames // frames_per_batch) * cfg.loss.ppo_epochs * num_mini_batches - ) - - sampling_start = time.time() - for i, data in enumerate(collector): - - log_info = {} - sampling_time = time.time() - sampling_start - frames_in_batch = data.numel() - collected_frames += frames_in_batch * frame_skip - pbar.update(data.numel()) - - # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] - if len(episode_rewards) > 0: - episode_length = 
data["next", "step_count"][data["next", "done"]] - log_info.update( - { - "train/reward": episode_rewards.mean().item(), - "train/episode_length": episode_length.sum().item() - / len(episode_length), - } - ) - - # Apply episodic end of life - data["done"].copy_(data["end_of_life"]) - data["next", "done"].copy_(data["next", "end_of_life"]) - - losses = TensorDict({}, batch_size=[cfg.loss.ppo_epochs, num_mini_batches]) - training_start = time.time() - for j in range(cfg.loss.ppo_epochs): - - # Compute adv - with torch.no_grad(): - data = adv_module(data) - data_reshape = data.reshape(-1) - - # Update the data buffer - data_buffer.extend(data_reshape) - - for k, batch in enumerate(data_buffer): - - # Linearly decrease the learning rate and clip epsilon - alpha = 1 - (num_network_updates / total_network_updates) - if cfg.optim.anneal_lr: - for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha - num_network_updates += 1 - - # Get a data batch - batch = batch.to(device) - - # Forward pass PPO loss - loss = loss_module(batch) - losses[j, k] = loss.select( - "loss_critic", "loss_entropy", "loss_objective" - ).detach() - loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] - ) - - # Backward pass - loss_sum.backward() - torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm - ) - - # Update the networks - optim.step() - optim.zero_grad() - - # Get training losses and times - training_time = time.time() - training_start - losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) - for key, value in losses.items(): - log_info.update({f"train/{key}": value.item()}) - log_info.update( - { - "train/lr": alpha * cfg.optim.lr, - "train/sampling_time": sampling_time, - "train/training_time": training_time, - } - ) - - # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): - if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip - ) // test_interval: - actor.eval() - eval_start = time.time() - test_rewards = eval_model( - actor, test_env, num_episodes=cfg.logger.num_test_episodes - ) - eval_time = time.time() - eval_start - log_info.update( - { - "eval/reward": test_rewards.mean(), - "eval/time": eval_time, - } - ) - actor.train() - - if logger: - for key, value in log_info.items(): - logger.log_scalar(key, value, collected_frames) - - collector.update_policy_weights_() - sampling_start = time.time() - - end_time = time.time() - execution_time = end_time - start_time - print(f"Training took {execution_time:.2f} seconds to finish") - - -if __name__ == "__main__": - main() diff --git a/examples/impala2/impala.py b/examples/impala/impala_single_node.py similarity index 81% rename from examples/impala2/impala.py rename to examples/impala/impala_single_node.py index da920d347bc..78bcef15d9b 100644 --- a/examples/impala2/impala.py +++ b/examples/impala/impala_single_node.py @@ -34,21 +34,30 @@ def main(cfg: "DictConfig"): # noqa: F821 frame_skip = 4 total_frames = cfg.collector.total_frames // frame_skip frames_per_batch = cfg.collector.frames_per_batch // frame_skip - mini_batch_size = cfg.loss.mini_batch_size // frame_skip test_interval = cfg.logger.test_interval // frame_skip + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates 
= cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // frames_per_batch * batch_size + ) * cfg.loss.sgd_updates + # Create models (check utils_atari.py) actor, critic, critic_head = make_ppo_models(cfg.env.env_name) - actor, critic, critic_head = ( - actor.to(device), - critic.to(device), - critic_head.to(device), - ) + actor, critic = (actor.to(device), critic.to(device)) # Create collector collector = MultiaSyncDataCollector( create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] - * 8, + * num_workers, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, @@ -61,9 +70,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( - storage=LazyMemmapStorage(frames_per_batch * cfg.loss.batch_size), + storage=LazyMemmapStorage(frames_per_batch * batch_size), sampler=sampler, - batch_size=mini_batch_size * cfg.loss.batch_size, + batch_size=frames_per_batch * batch_size, ) # Create loss and adv modules @@ -109,10 +118,6 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates = 0 start_time = time.time() pbar = tqdm.tqdm(total=total_frames) - total_network_updates = ( - total_frames // frames_per_batch * cfg.loss.batch_size - ) * cfg.loss.sgd_updates - accumulator = [] sampling_start = time.time() for i, data in enumerate(collector): @@ -136,19 +141,19 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Apply episodic end of life - data["done"].copy_(data["end_of_life"]) - data["next", "done"].copy_(data["next", "end_of_life"]) + # data["done"].copy_(data["end_of_life"]) + # data["next", "done"].copy_(data["next", "end_of_life"]) - if len(accumulator) < cfg.loss.batch_size: + if len(accumulator) < batch_size: accumulator.append(data) if logger: for key, value in log_info.items(): logger.log_scalar(key, value, collected_frames) continue - losses = TensorDict({}, batch_size=[cfg.loss.sgd_updates, 1]) + losses = TensorDict({}, batch_size=[sgd_updates]) training_start = time.time() - for j in range(cfg.loss.sgd_updates): + for j in range(sgd_updates): for acc_data in accumulator: @@ -159,21 +164,22 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update the data buffer data_buffer.extend(acc_data_reshape) - for k, batch in enumerate(data_buffer): + for batch in data_buffer: # Linearly decrease the learning rate and clip epsilon - alpha = 1 - (num_network_updates / total_network_updates) - if cfg.optim.anneal_lr: + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) for group in optim.param_groups: - group["lr"] = cfg.optim.lr * alpha + group["lr"] = lr * alpha num_network_updates += 1 # Get a data batch batch = batch.to(device) - # Forward pass PPO loss + # Forward pass loss loss = loss_module(batch) - losses[j, k] = loss.select( + losses[j] = loss.select( "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( @@ -183,7 +189,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Backward pass loss_sum.backward() torch.nn.utils.clip_grad_norm_( - list(loss_module.parameters()), max_norm=cfg.optim.max_grad_norm + list(loss_module.parameters()), max_norm=max_grad_norm ) # Update the networks @@ -197,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821 log_info.update({f"train/{key}": value.item()}) log_info.update( { - "train/lr": alpha * cfg.optim.lr, + "train/lr": alpha * lr, "train/sampling_time": sampling_time, 
"train/training_time": training_time, } @@ -210,13 +216,13 @@ def main(cfg: "DictConfig"): # noqa: F821 ) // test_interval: actor.eval() eval_start = time.time() - test_rewards = eval_model( - actor, test_env, num_episodes=cfg.logger.num_test_episodes + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes ) eval_time = time.time() - eval_start log_info.update( { - "eval/reward": test_rewards.mean(), + "eval/reward": test_reward, "eval/time": eval_time, } ) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index a7d478e54b1..5e5b8736b2e 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -4,15 +4,13 @@ # LICENSE file in the root directory of this source tree. import gymnasium as gym -import numpy as np import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec +from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec from torchrl.data.tensor_specs import DiscreteBox from torchrl.envs import ( CatFrames, - default_info_dict_reader, DoubleToFloat, EnvCreator, ExplorationType, @@ -25,6 +23,7 @@ RewardSum, StepCounter, ToTensorImage, + Transform, TransformedEnv, ) from torchrl.envs.libs.gym import GymWrapper @@ -43,43 +42,44 @@ # -------------------------------------------------------------------- -class EpisodicLifeEnv(gym.Wrapper): - def __init__(self, env): - """Make end-of-life == end-of-episode, but only reset on true game over. - Done by DeepMind for the DQN and co. It helps value estimation. - """ - gym.Wrapper.__init__(self, env) - self.lives = 0 - - def step(self, action): - obs, rew, done, truncated, info = self.env.step(action) - lives = self.env.unwrapped.ale.lives() - info["end_of_life"] = False - if (lives < self.lives) or done: - info["end_of_life"] = True - self.lives = lives - return obs, rew, done, truncated, info - - def reset(self, **kwargs): - reset_data = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return reset_data +class EndOfLifeTransform(Transform): + def _step(self, tensordict, next_tensordict): + lives = self.parent.base_env._env.unwrapped.ale.lives() + end_of_life = torch.tensor( + [tensordict["lives"] < lives], device=self.parent.device + ) + end_of_life = end_of_life | next_tensordict.get("done") + next_tensordict.set("eol", end_of_life) + next_tensordict.set("lives", lives) + return next_tensordict + + def reset(self, tensordict): + lives = self.parent.base_env._env.unwrapped.ale.lives() + end_of_life = False + tensordict.set("eol", [end_of_life]) + tensordict.set("lives", lives) + return tensordict + + def transform_observation_spec(self, observation_spec): + full_done_spec = self.parent.output_spec["full_done_spec"] + observation_spec["eol"] = full_done_spec["done"].clone() + observation_spec["lives"] = UnboundedDiscreteTensorSpec( + self.parent.batch_size, device=self.parent.device + ) + return observation_spec def make_base_env( env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): env = gym.make(env_name) - if not is_test: - env = EpisodicLifeEnv(env) env = GymWrapper( env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: - reader = default_info_dict_reader(["end_of_life"]) - env.set_info_dict_reader(reader) + env.append_transform(EndOfLifeTransform()) return env @@ -208,18 +208,12 @@ def make_ppo_models(env_name): value_operator=value_module, ) - with 
torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - actor = actor_critic.get_policy_operator() critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() del proof_environment - return actor, critic, critic_head + return actor, critic # ==================================================================== @@ -228,8 +222,8 @@ def make_ppo_models(env_name): def eval_model(actor, test_env, num_episodes=3): - test_rewards = [] - for _ in range(num_episodes): + test_rewards = torch.zeros(num_episodes, dtype=torch.float32) + for i in range(num_episodes): td_test = test_env.rollout( policy=actor, auto_reset=True, @@ -238,6 +232,6 @@ def eval_model(actor, test_env, num_episodes=3): max_steps=10_000_000, ) reward = td_test["next", "episode_reward"][td_test["next", "done"]] - test_rewards = np.append(test_rewards, reward.cpu().numpy()) + test_rewards[i] = reward.sum() del td_test return test_rewards.mean() diff --git a/examples/impala2/config.yaml b/examples/impala2/config.yaml deleted file mode 100644 index a0e6e3b4651..00000000000 --- a/examples/impala2/config.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Environment -env: - env_name: PongNoFrameskip-v4 - num_envs: 4 - -# collector -collector: - frames_per_batch: 320 - total_frames: 200_000_000 - -# logger -logger: - backend: wandb - exp_name: Atari_Schulman17 - test_interval: 200_000_000 - num_test_episodes: 3 - -# Optim -optim: - lr: 0.0006 - eps: 0.01 - weight_decay: 0.0 - momentum: 0.0 - alpha: 0.99 - max_grad_norm: 1.0 - anneal_lr: True - -# loss -loss: - gamma: 0.99 - batch_size: 8 - sgd_updates: 1 - critic_coef: 0.25 - entropy_coef: 0.01 - loss_critic_type: l2 diff --git a/examples/impala2/utils.py b/examples/impala2/utils.py deleted file mode 100644 index a7d478e54b1..00000000000 --- a/examples/impala2/utils.py +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import gymnasium as gym -import numpy as np -import torch.nn -import torch.optim -from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec -from torchrl.data.tensor_specs import DiscreteBox -from torchrl.envs import ( - CatFrames, - default_info_dict_reader, - DoubleToFloat, - EnvCreator, - ExplorationType, - GrayScale, - NoopResetEnv, - ObservationNorm, - ParallelEnv, - Resize, - RewardClipping, - RewardSum, - StepCounter, - ToTensorImage, - TransformedEnv, -) -from torchrl.envs.libs.gym import GymWrapper -from torchrl.modules import ( - ActorValueOperator, - ConvNet, - MLP, - OneHotCategorical, - ProbabilisticActor, - TanhNormal, - ValueOperator, -) - -# ==================================================================== -# Environment utils -# -------------------------------------------------------------------- - - -class EpisodicLifeEnv(gym.Wrapper): - def __init__(self, env): - """Make end-of-life == end-of-episode, but only reset on true game over. - Done by DeepMind for the DQN and co. It helps value estimation. 
- """ - gym.Wrapper.__init__(self, env) - self.lives = 0 - - def step(self, action): - obs, rew, done, truncated, info = self.env.step(action) - lives = self.env.unwrapped.ale.lives() - info["end_of_life"] = False - if (lives < self.lives) or done: - info["end_of_life"] = True - self.lives = lives - return obs, rew, done, truncated, info - - def reset(self, **kwargs): - reset_data = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return reset_data - - -def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False -): - env = gym.make(env_name) - if not is_test: - env = EpisodicLifeEnv(env) - env = GymWrapper( - env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device - ) - env = TransformedEnv(env) - env.append_transform(NoopResetEnv(noops=30, random=True)) - if not is_test: - reader = default_info_dict_reader(["end_of_life"]) - env.set_info_dict_reader(reader) - return env - - -def make_parallel_env(env_name, num_envs, device, is_test=False): - env = ParallelEnv( - num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) - ) - env = TransformedEnv(env) - env.append_transform(ToTensorImage(from_int=False)) - env.append_transform(GrayScale()) - env.append_transform(Resize(84, 84)) - env.append_transform(CatFrames(N=4, dim=-3)) - env.append_transform(RewardSum()) - env.append_transform(StepCounter(max_steps=4500)) - if not is_test: - env.append_transform(RewardClipping(-1, 1)) - env.append_transform(DoubleToFloat()) - env.append_transform( - ObservationNorm( - in_keys=["pixels"], scale=1 / 255.0, loc=-0.5, standard_normal=False - ) - ) - return env - - -# ==================================================================== -# Model utils -# -------------------------------------------------------------------- - - -def make_ppo_modules_pixels(proof_environment): - - # Define input shape - input_shape = proof_environment.observation_spec["pixels"].shape - - # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, - } - - # Define input keys - in_keys = ["pixels"] - - # Define a shared Module and TensorDictModule (CNN + MLP) - common_cnn = ConvNet( - activation_class=torch.nn.ReLU, - num_cells=[32, 64, 64], - kernel_sizes=[8, 4, 3], - strides=[4, 2, 1], - ) - common_cnn_output = common_cnn(torch.ones(input_shape)) - common_mlp = MLP( - in_features=common_cnn_output.shape[-1], - activation_class=torch.nn.ReLU, - activate_last_layer=True, - out_features=512, - num_cells=[], - ) - common_mlp_output = common_mlp(common_cnn_output) - - # Define shared net as TensorDictModule - common_module = TensorDictModule( - module=torch.nn.Sequential(common_cnn, common_mlp), - in_keys=in_keys, - out_keys=["common_features"], - ) - - # Define on head for the policy - policy_net = MLP( - in_features=common_mlp_output.shape[-1], - out_features=num_outputs, - activation_class=torch.nn.ReLU, - num_cells=[], - ) - policy_module = TensorDictModule( - module=policy_net, - in_keys=["common_features"], - out_keys=["logits"], - ) - - # Add probabilistic sampling of the actions - policy_module = ProbabilisticActor( - 
policy_module, - in_keys=["logits"], - spec=CompositeSpec(action=proof_environment.action_spec), - distribution_class=distribution_class, - distribution_kwargs=distribution_kwargs, - return_log_prob=True, - default_interaction_type=ExplorationType.RANDOM, - ) - - # Define another head for the value - value_net = MLP( - activation_class=torch.nn.ReLU, - in_features=common_mlp_output.shape[-1], - out_features=1, - num_cells=[], - ) - value_module = ValueOperator( - value_net, - in_keys=["common_features"], - ) - - return common_module, policy_module, value_module - - -def make_ppo_models(env_name): - - proof_environment = make_parallel_env(env_name, 1, device="cpu") - common_module, policy_module, value_module = make_ppo_modules_pixels( - proof_environment - ) - - # Wrap modules in a single ActorCritic operator - actor_critic = ActorValueOperator( - common_operator=common_module, - policy_operator=policy_module, - value_operator=value_module, - ) - - with torch.no_grad(): - td = proof_environment.rollout(max_steps=100, break_when_any_done=False) - td = actor_critic(td) - del td - - actor = actor_critic.get_policy_operator() - critic = actor_critic.get_value_operator() - critic_head = actor_critic.get_value_head() - - del proof_environment - - return actor, critic, critic_head - - -# ==================================================================== -# Evaluation utils -# -------------------------------------------------------------------- - - -def eval_model(actor, test_env, num_episodes=3): - test_rewards = [] - for _ in range(num_episodes): - td_test = test_env.rollout( - policy=actor, - auto_reset=True, - auto_cast_to_device=True, - break_when_any_done=True, - max_steps=10_000_000, - ) - reward = td_test["next", "episode_reward"][td_test["next", "done"]] - test_rewards = np.append(test_rewards, reward.cpu().numpy()) - del td_test - return test_rewards.mean() From a5eb8b6878f478510130c6ad33ee21ac2592b75a Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 17:11:12 +0200 Subject: [PATCH 030/109] working impala script --- examples/impala/config.yaml | 8 ++++---- examples/impala/impala_single_node.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index e55827560c8..b43d1e9401f 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -1,13 +1,13 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 4 + num_envs: 1 # collector collector: - frames_per_batch: 320 + frames_per_batch: 80 total_frames: 200_000_000 - num_workers: 8 + num_workers: 12 # logger logger: @@ -29,7 +29,7 @@ optim: # loss loss: gamma: 0.99 - batch_size: 8 + batch_size: 32 sgd_updates: 1 critic_coef: 0.5 entropy_coef: 0.01 diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 78bcef15d9b..b5faa22a20b 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -51,7 +51,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) * cfg.loss.sgd_updates # Create models (check utils_atari.py) - actor, critic, critic_head = make_ppo_models(cfg.env.env_name) + actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = (actor.to(device), critic.to(device)) # Create collector From b9e81d22ca46355c72547c1c0f518bb7ced23e19 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 17:52:29 +0200 Subject: [PATCH 031/109] test offpolicy losses --- examples/impala/config.yaml | 2 +- test/test_cost.py | 22 ++++++++++++---------- 2 
files changed, 13 insertions(+), 11 deletions(-) diff --git a/examples/impala/config.yaml b/examples/impala/config.yaml index b43d1e9401f..b777e0db9aa 100644 --- a/examples/impala/config.yaml +++ b/examples/impala/config.yaml @@ -23,7 +23,7 @@ optim: weight_decay: 0.0 momentum: 0.0 alpha: 0.99 - max_grad_norm: 1.0 + max_grad_norm: 40.0 anneal_lr: True # loss diff --git a/test/test_cost.py b/test/test_cost.py index 6b940331f7e..1e655ee9a6d 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -129,6 +129,7 @@ GAE, TD1Estimator, TDLambdaEstimator, + # VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -139,6 +140,7 @@ vec_generalized_advantage_estimate, vec_td1_advantage_estimate, vec_td_lambda_advantage_estimate, + # vtrace_advantage_estimate, ) from torchrl.objectives.value.utils import ( _custom_conv1d, @@ -432,7 +434,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -900,7 +902,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1363,7 +1365,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1957,7 +1959,7 @@ def test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2624,7 +2626,7 @@ def test_sac( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3351,7 +3353,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3944,7 +3946,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4311,7 +4313,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4328,7 +4330,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if 
td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4734,7 +4736,7 @@ def test_cql( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return From d04d050fd74dcaab34268f0123188691f43b798b Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 18:00:27 +0200 Subject: [PATCH 032/109] minor script fixes --- examples/impala/impala_single_node.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index b5faa22a20b..edbbe23260e 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -47,12 +47,12 @@ def main(cfg: "DictConfig"): # noqa: F821 max_grad_norm = cfg.optim.max_grad_norm num_test_episodes = cfg.logger.num_test_episodes total_network_updates = ( - total_frames // frames_per_batch * batch_size + total_frames // (frames_per_batch * batch_size) ) * cfg.loss.sgd_updates # Create models (check utils_atari.py) actor, critic = make_ppo_models(cfg.env.env_name) - actor, critic = (actor.to(device), critic.to(device)) + actor, critic = actor.to(device), critic.to(device) # Create collector collector = MultiaSyncDataCollector( @@ -236,6 +236,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampling_start = time.time() accumulator = [] + collector.shutdown() end_time = time.time() execution_time = end_time - start_time print(f"Training took {execution_time:.2f} seconds to finish") From 2a15708048a2c891b20de291fb8147d055ae771b Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 18:48:35 +0200 Subject: [PATCH 033/109] test onpolicy losses --- test/test_cost.py | 178 +++++++++++++++++-------- torchrl/objectives/value/advantages.py | 2 +- 2 files changed, 121 insertions(+), 59 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1e655ee9a6d..aad731cc39b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -129,7 +129,7 @@ GAE, TD1Estimator, TDLambdaEstimator, - # VTrace, + VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -5195,7 +5195,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -5208,6 +5208,13 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5274,7 +5281,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + 
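The parenthesisation fix to `total_network_updates` in the impala_single_node.py hunk above changes the count materially; a small worked example with illustrative numbers (not taken from the config):

    total_frames, frames_per_batch, batch_size, sgd_updates = 40_000, 80, 32, 1
    # old expression divides by frames_per_batch only: 40_000 // 80 * 32 = 16_000 updates
    old = total_frames // frames_per_batch * batch_size * sgd_updates
    # fixed expression divides by the frames consumed per optimisation batch: 40_000 // 2_560 = 15 updates
    new = total_frames // (frames_per_batch * batch_size) * sgd_updates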
@pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -5287,6 +5294,12 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5418,7 +5431,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5432,6 +5445,13 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5494,6 +5514,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -5533,7 +5554,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" @@ -5561,6 +5582,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5805,10 +5833,8 @@ def _create_seq_mock_data_a2c( action_dim=4, atoms=None, device="cpu", + sample_log_prob_key="sample_log_prob", action_key="action", - observation_key="observation", - reward_key="reward", - done_key="done", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -5822,24 +5848,23 @@ def _create_seq_mock_data_a2c( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) + mask = torch.ones(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 td = TensorDict( batch_size=(batch, T), source={ - observation_key: 
obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "next": { - observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), - done_key: done, - reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "done": done, + "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), - "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( - ~mask, 0.0 - ) - / 10, + sample_log_prob_key: ( + torch.randn_like(action[..., 1]) / 10 + ).masked_fill_(~mask, 0.0), "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, @@ -5849,7 +5874,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -5862,6 +5887,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5986,7 +6018,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -6000,6 +6032,13 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6053,6 +6092,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6090,51 +6130,76 @@ def test_a2c_tensordict_keys(self, td_est): } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device): + def test_a2c_tensordict_keys_run(self, device, advantage, td_est): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True - advantage_key = "advantage_test" - value_target_key = "value_target_test" - value_key 
= "state_value_test" - action_key = "action_test" - reward_key = "reward_test" - done_key = ("done", "test") + tensor_keys = { + "advantage": "advantage_test", + "value_target": "value_target_test", + "value": "state_value_test", + "sample_log_prob": "sample_log_prob_test", + "action": "action_test", + } td = self._create_seq_mock_data_a2c( + sample_log_prob_key=tensor_keys["sample_log_prob"], + action_key=tensor_keys["action"], device=device, - action_key=action_key, - reward_key=reward_key, - done_key=done_key, ) actor = self._create_mock_actor(device=device) - value = self._create_mock_value(device=device, out_keys=[value_key]) - advantage = GAE( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) - advantage.set_keys( - advantage=advantage_key, - value_target=value_target_key, - value=value_key, - reward=reward_key, - done=done_key, - ) - loss_fn = A2CLoss(actor, value, loss_critic_type="l2") - loss_fn.set_keys( - advantage=advantage_key, - value_target=value_target_key, - value=value_key, - action=action_key, - reward=reward_key, - done=done_key, - ) + value = self._create_mock_value(device=device, out_keys=[tensor_keys["value"]]) - advantage(td) + if advantage == "gae": + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + differentiable=gradient_mode, + ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) + elif advantage == "td": + advantage = TD1Estimator( + gamma=0.9, + value_network=value, + differentiable=gradient_mode, + ) + elif advantage == "td_lambda": + advantage = TDLambdaEstimator( + gamma=0.9, + lmbda=0.9, + value_network=value, + differentiable=gradient_mode, + ) + elif advantage is None: + pass + else: + raise NotImplementedError + + loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + loss_fn.set_keys(**tensor_keys) + if advantage is not None: + # collect tensordict key names for the advantage module + adv_keys = { + key: value + for key, value in tensor_keys.items() + if key in asdict(GAE._AcceptedKeys()).keys() + } + advantage.set_keys(**adv_keys) + advantage(td) + else: + if td_est is not None: + loss_fn.make_value_estimator(td_est) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -6175,9 +6240,6 @@ def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_ke value = self._create_mock_value(observation_key=observation_key) td = self._create_seq_mock_data_a2c( action_key=action_key, - observation_key=observation_key, - reward_key=reward_key, - done_key=done_key, ) loss = A2CLoss(actor, value) @@ -6924,7 +6986,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -7656,7 +7718,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index dc3524f6ad5..274b6a1b520 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -584,7 +584,7 @@ def value_estimate( if self.average_rewards: reward 
= reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code From a5c204601ccf98e7743c3d190d20358e182b7b5b Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 19:18:56 +0200 Subject: [PATCH 034/109] test fix --- test/test_cost.py | 9 ++++++++- torchrl/objectives/reinforce.py | 10 +++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index aad731cc39b..241835d2682 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6283,7 +6283,7 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) - @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) + @pytest.mark.parametrize("advantage", ["gae", "vtrace", "td", "td_lambda", None]) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 @@ -6309,6 +6309,13 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est value_network=get_functional(value_net), differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=get_functional(value_net), + actor_network=actor_net, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=gamma, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 7c314bace36..050bb67f747 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -18,7 +18,13 @@ distance_loss, ValueEstimators, ) -from torchrl.objectives.value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator +from torchrl.objectives.value import ( + GAE, + TD0Estimator, + TD1Estimator, + TDLambdaEstimator, + VTrace, +) class ReinforceLoss(LossModule): @@ -332,6 +338,8 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = GAE(value_network=self.critic, **hp) elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + elif value_type == ValueEstimators.VTrace: + self._value_estimator = VTrace(value_network=self.critic, **hp) else: raise NotImplementedError(f"Unknown value type {value_type}") From dbde27c0282e23b124dc17bc066998a58dd61236 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 25 Sep 2023 19:49:44 +0200 Subject: [PATCH 035/109] test fix --- test/test_cost.py | 2 ++ torchrl/objectives/a2c.py | 4 +++- torchrl/objectives/ppo.py | 4 +++- torchrl/objectives/reinforce.py | 4 +++- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 241835d2682..d03dea93011 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6349,6 +6349,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est "done": torch.zeros(batch, 1, dtype=torch.bool), }, "action": torch.randn(batch, n_act), + "sample_log_prob": torch.randn(batch, 1), }, [batch], names=["time"], @@ -6399,6 +6400,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) diff --git a/torchrl/objectives/a2c.py 
b/torchrl/objectives/a2c.py index b0e9ea3d4d3..403b4601a45 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -387,7 +387,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: - self._value_estimator = VTrace(value_network=self.critic, **hp) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=self.actor, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index ac9df698df2..999d0f2f99d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -463,7 +463,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: - self._value_estimator = VTrace(value_network=self.critic, **hp) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=self.actor, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 050bb67f747..2afb0d118ef 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -339,7 +339,9 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: - self._value_estimator = VTrace(value_network=self.critic, **hp) + self._value_estimator = VTrace( + value_network=self.critic, actor_network=self.actor, **hp + ) else: raise NotImplementedError(f"Unknown value type {value_type}") From 74112351662d53f4489f2642039d0b16810a5780 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 26 Sep 2023 09:31:20 +0200 Subject: [PATCH 036/109] test fix --- torchrl/objectives/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 1dd3cfc5f35..b8ec5ec7c32 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -62,6 +62,8 @@ def default_value_kwargs(value_type: ValueEstimators): return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} elif value_type == ValueEstimators.TDLambda: return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} + elif value_type == ValueEstimators.VTrace: + return {"gamma": 0.99, "differentiable": True} else: raise NotImplementedError(f"Unknown value type {value_type}.") From fa5f835c15e8d8621f67231eeb93b8564e35cbed Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 26 Sep 2023 09:33:05 +0200 Subject: [PATCH 037/109] test fix --- torchrl/objectives/reinforce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 2afb0d118ef..5aaffe56702 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -340,7 +340,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: self._value_estimator = VTrace( - value_network=self.critic, actor_network=self.actor, **hp + 
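With the hunks above, VTrace can be selected straight from the loss modules. A minimal usage sketch, assuming `actor` and `critic` are the modules built in examples/impala/utils.py:

    from torchrl.objectives import A2CLoss
    from torchrl.objectives.utils import ValueEstimators

    loss_module = A2CLoss(actor=actor, critic=critic)
    # dispatches to VTrace(value_network=critic, actor_network=actor, ...) under the hood
    loss_module.make_value_estimator(ValueEstimators.VTrace, gamma=0.99)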
value_network=self.critic, actor_network=self.actor_network, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") From 30e0cc132a6ecf59e8edd5df037fe4fee2672a18 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 12:01:52 +0200 Subject: [PATCH 038/109] test fix --- .../{config.yaml => config_multi_node.yaml} | 0 examples/impala/config_single_node.yaml | 36 +++ examples/impala/impala_multi_node.py | 251 ++++++++++++++++++ examples/impala/impala_single_node.py | 6 +- test/test_cost.py | 2 +- 5 files changed, 291 insertions(+), 4 deletions(-) rename examples/impala/{config.yaml => config_multi_node.yaml} (100%) create mode 100644 examples/impala/config_single_node.yaml create mode 100644 examples/impala/impala_multi_node.py diff --git a/examples/impala/config.yaml b/examples/impala/config_multi_node.yaml similarity index 100% rename from examples/impala/config.yaml rename to examples/impala/config_multi_node.yaml diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml new file mode 100644 index 00000000000..b777e0db9aa --- /dev/null +++ b/examples/impala/config_single_node.yaml @@ -0,0 +1,36 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + num_envs: 1 + +# collector +collector: + frames_per_batch: 80 + total_frames: 200_000_000 + num_workers: 12 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 200_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1e-8 + weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py new file mode 100644 index 00000000000..c815e0a453f --- /dev/null +++ b/examples/impala/impala_multi_node.py @@ -0,0 +1,251 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. 
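For reference, the V-trace targets of Espeholt et al. 2018 that this script relies on can be written as a short backward recursion over a single trajectory. The sketch below is illustrative only: it assumes 1-D tensors, ignores done/terminated masking, and is not the torchrl implementation.

    import torch

    def vtrace_sketch(reward, value, next_value, log_pi, log_mu,
                      gamma=0.99, rho_clip=1.0, c_clip=1.0):
        # reward, value, next_value, log_pi (current policy), log_mu (behaviour): shape [T]
        ratio = (log_pi - log_mu).exp()
        rho = ratio.clamp(max=rho_clip)      # clipped importance weights
        c = ratio.clamp(max=c_clip)          # trace-cutting coefficients
        delta = rho * (reward + gamma * next_value - value)
        vs_minus_v = torch.zeros_like(value)
        acc = torch.zeros(())
        for t in reversed(range(reward.shape[0])):
            acc = delta[t] + gamma * c[t] * acc
            vs_minus_v[t] = acc
        vs = vs_minus_v + value                               # V-trace value targets
        vs_next = torch.cat([vs[1:], next_value[-1:]])        # bootstrap with the last next_value
        advantage = rho * (reward + gamma * vs_next - value)  # policy-gradient advantage
        return vs, advantage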
+""" +import hydra + + +@hydra.main(config_path=".", config_name="config", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors.distributed import RPCDataCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_parallel_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils_atari.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + collector = RPCDataCollector( + create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] + * 2, + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + storing_device=device, + max_frames_per_traj=-1, + sync=False, + slurm_kwargs={ + "timeout_min": 10, + "slurm_partition": "1080", + "slurm_cpus_per_task": 1, + "slurm_gpus_per_node": 1, + } + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + + # Create optimizer + optim = torch.optim.RMSprop( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, logger_name="impala", experiment_name=exp_name + ) + + # Create test environment + test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += 
frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + # Apply episodic end of life + data["done"].copy_(data["eol"]) + data["next", "done"].copy_(data["next", "eol"]) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + for acc_data in accumulator: + + with torch.no_grad(): + acc_data = adv_module(acc_data) + acc_data_reshape = acc_data.reshape(-1) + + # Update the data buffer + data_buffer.extend(acc_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index edbbe23260e..c7fb53b260d 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -10,7 +10,7 @@ import hydra -@hydra.main(config_path=".", config_name="config", version_base="1.1") +@hydra.main(config_path=".", config_name="config_single_node", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time @@ -141,8 +141,8 @@ def main(cfg: "DictConfig"): # noqa: 
F821 ) # Apply episodic end of life - # data["done"].copy_(data["end_of_life"]) - # data["next", "done"].copy_(data["next", "end_of_life"]) + data["done"].copy_(data["eol"]) + data["next", "done"].copy_(data["next", "eol"]) if len(accumulator) < batch_size: accumulator.append(data) diff --git a/test/test_cost.py b/test/test_cost.py index d03dea93011..30b69783f1c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6284,7 +6284,7 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "vtrace", "td", "td_lambda", None]) - @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("td_est", list(ValueEstimators)) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 n_act = 5 From 0b9ed5cc3d1088bc2cb0df1abcec5af18e111958 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 12:38:03 +0200 Subject: [PATCH 039/109] fixes --- examples/impala/impala_multi_node.py | 5 +- examples/impala/impala_single_node.py | 5 +- examples/impala/utils.py | 8 +-- test/test_cost.py | 85 ++++++++++++++++++++++----- 4 files changed, 75 insertions(+), 28 deletions(-) diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index c815e0a453f..a0b08543a24 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -94,6 +94,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) + loss_module.set_keys(done="eol") # Create optimizer optim = torch.optim.RMSprop( @@ -145,10 +146,6 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - # Apply episodic end of life - data["done"].copy_(data["eol"]) - data["next", "done"].copy_(data["next", "eol"]) - if len(accumulator) < batch_size: accumulator.append(data) if logger: diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index c7fb53b260d..bc47518a3c3 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -89,6 +89,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) + loss_module.set_keys(done="eol") # Create optimizer optim = torch.optim.RMSprop( @@ -140,10 +141,6 @@ def main(cfg: "DictConfig"): # noqa: F821 } ) - # Apply episodic end of life - data["done"].copy_(data["eol"]) - data["next", "done"].copy_(data["next", "eol"]) - if len(accumulator) < batch_size: accumulator.append(data) if logger: diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 5e5b8736b2e..73037fef9a8 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -16,7 +16,7 @@ ExplorationType, GrayScale, NoopResetEnv, - ObservationNorm, + VecNorm, ParallelEnv, Resize, RewardClipping, @@ -97,11 +97,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): if not is_test: env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform( - ObservationNorm( - in_keys=["pixels"], scale=1 / 255.0, loc=-0.5, standard_normal=False - ) - ) + env.append_transform(VecNorm(in_keys=["pixels"])) return env diff --git a/test/test_cost.py b/test/test_cost.py index 30b69783f1c..e7b31610924 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5730,12 +5730,12 @@ class TestA2C(LossModuleTestBase): seed = 0 def _create_mock_actor( - 
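The change above swaps overwriting `data["done"]` in place for pointing the loss at the end-of-life flag through its accepted keys. The same remapping mechanism exists on the estimator side; a short sketch, assuming the `actor`, `critic` and VTrace `adv_module` from this script:

    from torchrl.objectives import A2CLoss

    loss_module = A2CLoss(actor=actor, critic=critic)
    loss_module.set_keys(done="eol")   # read episodic termination from the "eol" entry
    adv_module.set_keys(done="eol")    # the value estimator accepts the same remapping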
self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - observation_key="observation", + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", ): # Actor action_spec = BoundedTensorSpec( @@ -5747,20 +5747,20 @@ def _create_mock_actor( ) actor = ProbabilisticActor( module=module, + distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, - distribution_class=TanhNormal, ) return actor.to(device) def _create_mock_value( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - out_keys=None, - observation_key="observation", + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + observation_key="observation", ): module = nn.Linear(obs_dim, 1) value = ValueOperator( @@ -5770,6 +5770,63 @@ def _create_mock_value( ) return value.to(device) + def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + base_layer = nn.Linear(obs_dim, 5) + net = NormalParamWrapper( + nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) + ) + module = TensorDictModule( + net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + in_keys=["loc", "scale"], + spec=action_spec, + ) + module = nn.Sequential(base_layer, nn.Linear(5, 1)) + value = ValueOperator( + module=module, + in_keys=["observation"], + ) + return actor.to(device), value.to(device) + + def _create_mock_actor_value_shared( + self, batch=2, obs_dim=3, action_dim=4, device="cpu" + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + base_layer = nn.Linear(obs_dim, 5) + common = TensorDictModule( + base_layer, in_keys=["observation"], out_keys=["hidden"] + ) + net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor()) + module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"]) + actor_head = ProbabilisticActor( + module=module, + distribution_class=TanhNormal, + in_keys=["loc", "scale"], + spec=action_spec, + ) + module = nn.Linear(5, 1) + value_head = ValueOperator( + module=module, + in_keys=["hidden"], + ) + model = ActorValueOperator(common, actor_head, value_head).to(device) + return model, model.get_policy_operator(), model.get_value_operator() + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 + ): + raise NotImplementedError + def _create_mock_common_layer_setup( self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 ): From b3c0c9ede46ea03f6de051a936323016bc068e66 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 17:59:23 +0200 Subject: [PATCH 040/109] multi node --- examples/impala/impala_multi_node.py | 38 ++++++++++++++------------- examples/impala/impala_single_node.py | 6 ++--- examples/impala/utils.py | 21 ++++----------- 3 files changed, 28 insertions(+), 37 deletions(-) diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index b7cad535984..78c6169a985 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -10,7 +10,7 @@ import hydra -@hydra.main(config_path=".", config_name="config", version_base="1.1") +@hydra.main(config_path=".", config_name="config_multi_node", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time @@ -19,14 +19,15 @@ def 
main(cfg: "DictConfig"): # noqa: F821 import tqdm from tensordict import TensorDict - from torchrl.collectors.distributed import RPCDataCollector + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import RayCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type from torchrl.objectives import A2CLoss from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger - from utils import eval_model, make_parallel_env, make_ppo_models + from utils import eval_model, make_env, make_ppo_models device = "cpu" if not torch.cuda.device_count() else "cuda" @@ -38,9 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Extract other config parameters batch_size = cfg.loss.batch_size # Number of rollouts per batch - num_workers = ( - cfg.collector.num_workers - ) # Number of parallel workers collecting rollouts + num_workers = cfg.collector.num_workers # Number of parallel workers collecting rollouts lr = cfg.optim.lr anneal_lr = cfg.optim.anneal_lr sgd_updates = cfg.loss.sgd_updates @@ -55,21 +54,24 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = actor.to(device), critic.to(device) # Create collector - collector = RPCDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] - * 2, + remote_config = { + "num_cpus": 1, + "num_gpus": 1.0 // num_workers, + "memory": 1024**3, + "object_store_memory": 1024**3, + } + collector = RayCollector( + create_env_fn=[make_env(cfg.env.env_name, cfg.env.num_envs, device)] + * num_workers, policy=actor, + collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, total_frames=total_frames, - storing_device=device, max_frames_per_traj=-1, + remote_configs=remote_config, + num_collectors=1, sync=False, - slurm_kwargs={ - "timeout_min": 10, - "slurm_partition": "1080", - "slurm_cpus_per_task": 1, - "slurm_gpus_per_node": 1, - } + update_after_each_batch=True, ) # Create data buffer @@ -94,7 +96,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) - loss_module.set_keys(done="eol", terminated="eol") + # loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.RMSprop( @@ -116,7 +118,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env = make_env(cfg.env.env_name, 1, device, is_test=True) test_env.eval() # Main loop diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index f850d468dab..f40958bbe96 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -26,7 +26,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.objectives import A2CLoss from torchrl.objectives.value import VTrace from torchrl.record.loggers import generate_exp_name, get_logger - from utils import eval_model, make_parallel_env, make_ppo_models + from utils import eval_model, make_env, make_ppo_models device = "cpu" if not torch.cuda.device_count() else "cuda" @@ -56,7 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_parallel_env(cfg.env.env_name, cfg.env.num_envs, device)] + create_env_fn=[make_env(cfg.env.env_name, cfg.env.num_envs, device)] 
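One detail of the Ray resource spec above: `"num_gpus": 1.0 // num_workers` floor-divides and therefore evaluates to 0.0 whenever more than one worker is requested (later commits in this series pin `num_workers` to 1 with `num_gpus: 1.0`). A fractional per-collector share would presumably use true division, e.g.:

    remote_config = {
        "num_cpus": 1,
        "num_gpus": 1.0 / num_workers,  # // would round a fractional share down to 0.0
        "memory": 1024**3,
    }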
* num_workers, policy=actor, frames_per_batch=frames_per_batch, @@ -111,7 +111,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_parallel_env(cfg.env.env_name, 1, device, is_test=True) + test_env = make_env(cfg.env.env_name, 1, device, is_test=True) test_env.eval() # Main loop diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 882c57a74f8..cb06768fa8d 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -68,34 +68,23 @@ def transform_observation_spec(self, observation_spec): return observation_spec -def make_base_env( - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False -): +def make_env(env_name, device, frame_skip=4, is_test=False): env = GymEnv( env_name, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: - env.append_transform(EndOfLifeTransform()) - return env - - -def make_parallel_env(env_name, num_envs, device, is_test=False): - env = ParallelEnv( - num_envs, EnvCreator(lambda: make_base_env(env_name, device=device)) - ) - env = TransformedEnv(env) + # env.append_transform(EndOfLifeTransform()) + env.append_transform(RewardClipping(-1, 1)) env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) env.append_transform(CatFrames(N=4, dim=-3)) env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) - if not is_test: - env.append_transform(RewardClipping(-1, 1)) env.append_transform(DoubleToFloat()) - env.append_transform(VecNorm(in_keys=["pixels"])) + # env.append_transform(VecNorm(in_keys=["pixels"])) return env @@ -190,7 +179,7 @@ def make_ppo_modules_pixels(proof_environment): def make_ppo_models(env_name): - proof_environment = make_parallel_env(env_name, 1, device="cpu") + proof_environment = make_env(env_name, device="cpu") common_module, policy_module, value_module = make_ppo_modules_pixels( proof_environment ) From c634112e33b1710ce09c7f53ffdd998522471f82 Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 18:08:06 +0200 Subject: [PATCH 041/109] multi node --- examples/impala/config_multi_node.yaml | 1 - examples/impala/config_single_node.yaml | 1 - examples/impala/impala_multi_node.py | 8 +++----- examples/impala/impala_single_node.py | 4 ++-- examples/impala/utils.py | 4 ++-- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/impala/config_multi_node.yaml b/examples/impala/config_multi_node.yaml index b777e0db9aa..86a11d6b40c 100644 --- a/examples/impala/config_multi_node.yaml +++ b/examples/impala/config_multi_node.yaml @@ -1,7 +1,6 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 1 # collector collector: diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index b777e0db9aa..86a11d6b40c 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -1,7 +1,6 @@ # Environment env: env_name: PongNoFrameskip-v4 - num_envs: 1 # collector collector: diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 78c6169a985..7988cdf01fe 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -57,11 +57,10 @@ def main(cfg: "DictConfig"): # noqa: F821 remote_config = { "num_cpus": 1, "num_gpus": 1.0 // num_workers, - "memory": 1024**3, - "object_store_memory": 
1024**3, + "memory": 2 * 1024**3, } collector = RayCollector( - create_env_fn=[make_env(cfg.env.env_name, cfg.env.num_envs, device)] + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, collector_class=SyncDataCollector, @@ -69,7 +68,6 @@ def main(cfg: "DictConfig"): # noqa: F821 total_frames=total_frames, max_frames_per_traj=-1, remote_configs=remote_config, - num_collectors=1, sync=False, update_after_each_batch=True, ) @@ -118,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env.env_name, 1, device, is_test=True) + test_env = make_env(cfg.env.env_name, device, is_test=True) test_env.eval() # Main loop diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index f40958bbe96..9b487e32f9d 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -56,7 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_env(cfg.env.env_name, cfg.env.num_envs, device)] + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, frames_per_batch=frames_per_batch, @@ -111,7 +111,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env.env_name, 1, device, is_test=True) + test_env = make_env(cfg.env.env_name, device, is_test=True) test_env.eval() # Main loop diff --git a/examples/impala/utils.py b/examples/impala/utils.py index cb06768fa8d..570fa4100cb 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -68,9 +68,9 @@ def transform_observation_spec(self, observation_spec): return observation_spec -def make_env(env_name, device, frame_skip=4, is_test=False): +def make_env(env_name, device, is_test=False): env = GymEnv( - env_name, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) From df16acea9f089c2448550c36873118c5db45123a Mon Sep 17 00:00:00 2001 From: albert bou Date: Mon, 2 Oct 2023 18:24:44 +0200 Subject: [PATCH 042/109] multi node --- examples/impala/config_multi_node.yaml | 2 +- examples/impala/impala_multi_node.py | 9 ++++---- examples/impala/impala_single_node.py | 3 +-- examples/impala/utils.py | 8 +++---- test/test_cost.py | 30 +++++++++++++------------- torchrl/envs/transforms/transforms.py | 16 +++++++++++++- 6 files changed, 40 insertions(+), 28 deletions(-) diff --git a/examples/impala/config_multi_node.yaml b/examples/impala/config_multi_node.yaml index 86a11d6b40c..7ac584bcac6 100644 --- a/examples/impala/config_multi_node.yaml +++ b/examples/impala/config_multi_node.yaml @@ -6,7 +6,7 @@ env: collector: frames_per_batch: 80 total_frames: 200_000_000 - num_workers: 12 + num_workers: 1 # logger logger: diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 7988cdf01fe..939acd8ea49 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -39,7 +39,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Extract other config parameters batch_size = cfg.loss.batch_size # Number of rollouts per batch - num_workers = cfg.collector.num_workers # Number of parallel workers collecting rollouts + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts lr = cfg.optim.lr anneal_lr = 
cfg.optim.anneal_lr sgd_updates = cfg.loss.sgd_updates @@ -56,12 +58,11 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector remote_config = { "num_cpus": 1, - "num_gpus": 1.0 // num_workers, + "num_gpus": 1.0, "memory": 2 * 1024**3, } collector = RayCollector( - create_env_fn=[make_env(cfg.env.env_name, device)] - * num_workers, + create_env_fn=[make_env(cfg.env.env_name, device)] * 1, policy=actor, collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 9b487e32f9d..ed0100925d4 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -56,8 +56,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector collector = MultiaSyncDataCollector( - create_env_fn=[make_env(cfg.env.env_name, device)] - * num_workers, + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, frames_per_batch=frames_per_batch, total_frames=total_frames, diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 570fa4100cb..967f07d4fbe 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -11,12 +11,10 @@ from torchrl.envs import ( CatFrames, DoubleToFloat, - EnvCreator, ExplorationType, GrayScale, + GymEnv, NoopResetEnv, - VecNorm, - ParallelEnv, Resize, RewardClipping, RewardSum, @@ -24,7 +22,7 @@ ToTensorImage, Transform, TransformedEnv, - GymEnv + VecNorm, ) from torchrl.modules import ( ActorValueOperator, @@ -84,7 +82,7 @@ def make_env(env_name, device, is_test=False): env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) env.append_transform(DoubleToFloat()) - # env.append_transform(VecNorm(in_keys=["pixels"])) + env.append_transform(VecNorm(in_keys=["pixels"])) return env diff --git a/test/test_cost.py b/test/test_cost.py index e7b31610924..2fbed5626eb 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5730,12 +5730,12 @@ class TestA2C(LossModuleTestBase): seed = 0 def _create_mock_actor( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - observation_key="observation", + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + observation_key="observation", ): # Actor action_spec = BoundedTensorSpec( @@ -5754,13 +5754,13 @@ def _create_mock_actor( return actor.to(device) def _create_mock_value( - self, - batch=2, - obs_dim=3, - action_dim=4, - device="cpu", - out_keys=None, - observation_key="observation", + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + observation_key="observation", ): module = nn.Linear(obs_dim, 1) value = ValueOperator( @@ -5796,7 +5796,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu return actor.to(device), value.to(device) def _create_mock_actor_value_shared( - self, batch=2, obs_dim=3, action_dim=4, device="cpu" + self, batch=2, obs_dim=3, action_dim=4, device="cpu" ): # Actor action_spec = BoundedTensorSpec( @@ -5823,7 +5823,7 @@ def _create_mock_actor_value_shared( return model, model.get_policy_operator(), model.get_value_operator() def _create_mock_distributional_actor( - self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 + self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 ): raise NotImplementedError diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 879432569f7..75b244f88d4 100644 --- a/torchrl/envs/transforms/transforms.py +++ 
b/torchrl/envs/transforms/transforms.py @@ -11,7 +11,7 @@ from copy import copy from functools import wraps from textwrap import indent -from typing import Any, List, Optional, OrderedDict, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, OrderedDict, Sequence, Tuple, Union import numpy as np @@ -4337,6 +4337,20 @@ def __repr__(self) -> str: f"eps={self.eps:4.4f}, keys={self.in_keys})" ) + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + _lock = state.pop("lock", None) + if _lock is not None: + state["lock_placeholder"] = None + return state + + def __setstate__(self, state: Dict[str, Any]): + if "lock_placeholder" in state: + state.pop("lock_placeholder") + _lock = mp.Lock() + state["lock"] = _lock + self.__dict__.update(state) + class RewardSum(Transform): """Tracks episode cumulative rewards. From 3403e29eff881edd49be8e0fe611927fad862ec0 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 10:10:35 +0200 Subject: [PATCH 043/109] fix tests --- test/test_cost.py | 8 ++++++++ torchrl/objectives/a2c.py | 1 + torchrl/objectives/common.py | 2 +- torchrl/objectives/ppo.py | 1 + torchrl/objectives/reinforce.py | 1 + torchrl/objectives/value/advantages.py | 10 ++++++++-- 6 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 2fbed5626eb..ca6a67dd72a 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5750,6 +5750,7 @@ def _create_mock_actor( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) return actor.to(device) @@ -5787,6 +5788,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Sequential(base_layer, nn.Linear(5, 1)) value = ValueOperator( @@ -5813,6 +5815,7 @@ def _create_mock_actor_value_shared( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Linear(5, 1) value_head = ValueOperator( @@ -6166,8 +6169,11 @@ def test_a2c_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "sample_log_prob": "sample_log_prob", } + import ipdb; ipdb.set_trace() + self.tensordict_keys_test( loss_fn, default_keys=default_keys, @@ -6184,7 +6190,9 @@ def test_a2c_tensordict_keys(self, td_est): "value": ("value", "value_state_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "sample_log_prob": ("sample_log_prob", "sample_log_prob_test"), } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index ec8f2d3ee66..fd17c88f6f9 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -409,5 +409,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index a8b80c98b63..37c5e820d23 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -138,7 +138,7 @@ def set_keys(self, **kwargs) -> None: """ for key, value in kwargs.items(): if key not in 
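The `__getstate__`/`__setstate__` pair added to VecNorm above is needed because the transform holds a `multiprocessing` lock, and such locks refuse to be pickled, which breaks shipping the transform to remote or asynchronous collectors. A minimal, TorchRL-independent illustration of the failure mode:

    import multiprocessing as mp
    import pickle

    lock = mp.Lock()
    try:
        pickle.dumps(lock)
    except RuntimeError:
        # mp locks cannot be serialized directly; dropping the lock in __getstate__
        # and recreating it in __setstate__ is one way around this.
        pass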
self._AcceptedKeys.__dict__: - raise ValueError(f"{key} it not an accepted tensordict key") + raise ValueError(f"{key} is not an accepted tensordict key") if value is not None: setattr(self.tensor_keys, key, value) else: diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 801735cf060..d54472764f1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -483,6 +483,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 7e459d71e72..04389eed89d 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -360,5 +360,6 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams "reward": self.tensor_keys.reward, "done": self.tensor_keys.done, "terminated": self.tensor_keys.terminated, + "sample_log_prob": self.tensor_keys.sample_log_prob, } self._value_estimator.set_keys(**tensor_keys) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 61f1b029a33..e86fff25de1 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -183,9 +183,11 @@ class _AcceptedKeys: whether a trajectory is done. Defaults to ``"done"``. terminated (NestedKey): The key in the input TensorDict that indicates whether a trajectory is terminated. Defaults to ``"terminated"``. - steps_to_next_obs_key (NestedKey): The key in the input tensordict + steps_to_next_obs (NestedKey): The key in the input tensordict that indicates the number of steps to the next observation. Defaults to ``"steps_to_next_obs"``. + sample_log_prob (NestedKey): The key in the input tensordict that + indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``. 
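Threading `sample_log_prob` through the losses above matters because the V-trace importance weights compare the behaviour policy's stored log-probability with the current actor's, so the data-collecting actor has to write that entry. The test mocks in this series do it with `return_log_prob=True`; a compact, self-contained sketch along the same lines:

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.modules import NormalParamWrapper, ProbabilisticActor, TanhNormal

    net = NormalParamWrapper(nn.Linear(3, 2 * 4))
    module = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
    actor = ProbabilisticActor(
        module=module,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,  # writes td["sample_log_prob"] alongside the action
    )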
""" advantage: NestedKey = "advantage" @@ -228,6 +230,10 @@ def terminated_key(self): def steps_to_next_obs_key(self): return self.tensor_keys.steps_to_next_obs + @property + def sample_log_prob_key(self): + return self.tensor_keys.sample_log_prob + @abc.abstractmethod def forward( self, @@ -346,7 +352,7 @@ def set_keys(self, **kwargs) -> None: raise ValueError("tensordict keys cannot be None") if key not in self._AcceptedKeys.__dict__: raise KeyError( - f"{key} it not an acceptedaccepted tensordict key for advantages" + f"{key} is not an accepted tensordict key for advantages" ) if ( key == "value" From 9da24ebb7899a1f36bf5ec6d5f641adc01c79fda Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 10:15:00 +0200 Subject: [PATCH 044/109] fix tests --- torchrl/objectives/a2c.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index fd17c88f6f9..f9ef9d521f7 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -208,6 +208,7 @@ class _AcceptedKeys: reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + sample_log_prob: NestedKey = "sample_log_prob" default_keys = _AcceptedKeys() default_value_estimator: ValueEstimators = ValueEstimators.GAE From 795620f53bc5f0dfce2aa5f8e79bb27e4d1c8a8f Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 10:43:15 +0200 Subject: [PATCH 045/109] fix tests --- test/test_cost.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index ca6a67dd72a..aab9fa946e6 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5736,6 +5736,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5751,6 +5752,7 @@ def _create_mock_actor( in_keys=["loc", "scale"], spec=action_spec, return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -6150,10 +6152,10 @@ def test_a2c_diff(self, device, gradient_mode, advantage): "td_est", [ ValueEstimators.TD1, - ValueEstimators.TD0, - ValueEstimators.GAE, - ValueEstimators.VTrace, - ValueEstimators.TDLambda, + # ValueEstimators.TD0, + # ValueEstimators.GAE, + # ValueEstimators.VTrace, + # ValueEstimators.TDLambda, ], ) def test_a2c_tensordict_keys(self, td_est): @@ -6172,8 +6174,6 @@ def test_a2c_tensordict_keys(self, td_est): "sample_log_prob": "sample_log_prob", } - import ipdb; ipdb.set_trace() - self.tensordict_keys_test( loss_fn, default_keys=default_keys, @@ -6216,7 +6216,7 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): device=device, ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor(device=device, sample_log_prob_key=tensor_keys["sample_log_prob"]) value = self._create_mock_value(device=device, out_keys=[tensor_keys["value"]]) if advantage == "gae": From 53aceba5f5f7afe2dc1b9cfee48c6fb492b141e5 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 12:01:43 +0200 Subject: [PATCH 046/109] merge main --- test/test_cost.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index aab9fa946e6..ca9015062fd 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6152,10 +6152,10 @@ def test_a2c_diff(self, device, gradient_mode, advantage): "td_est", [ ValueEstimators.TD1, - # ValueEstimators.TD0, - # ValueEstimators.GAE, - # ValueEstimators.VTrace, - # 
ValueEstimators.TDLambda, + ValueEstimators.TD0, + ValueEstimators.GAE, + ValueEstimators.VTrace, + ValueEstimators.TDLambda, ], ) def test_a2c_tensordict_keys(self, td_est): From 02cebf683adab3c706e8460f3f35b012a7086748 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 12:31:23 +0200 Subject: [PATCH 047/109] multinode script --- examples/impala/impala_multi_node.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 939acd8ea49..2ac4e2d4bdd 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RayCollector + from torchrl.collectors.distributed import RayCollector, RPCDataCollector, DistributedSyncDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -39,9 +39,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Extract other config parameters batch_size = cfg.loss.batch_size # Number of rollouts per batch - num_workers = ( - cfg.collector.num_workers - ) # Number of parallel workers collecting rollouts + num_workers = cfg.collector.num_workers # Number of parallel workers collecting rollouts lr = cfg.optim.lr anneal_lr = cfg.optim.anneal_lr sgd_updates = cfg.loss.sgd_updates @@ -51,7 +49,7 @@ def main(cfg: "DictConfig"): # noqa: F821 total_frames // (frames_per_batch * batch_size) ) * cfg.loss.sgd_updates - # Create models (check utils_atari.py) + # Create models (check utils.py) actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = actor.to(device), critic.to(device) @@ -62,7 +60,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "memory": 2 * 1024**3, } collector = RayCollector( - create_env_fn=[make_env(cfg.env.env_name, device)] * 1, + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, @@ -73,6 +71,23 @@ def main(cfg: "DictConfig"): # noqa: F821 update_after_each_batch=True, ) + # collector = RPCDataCollector( + # create_env_fn=[make_env(cfg.env.env_name, device)] * 1, + # policy=actor, + # collector_class=SyncDataCollector, + # frames_per_batch=frames_per_batch, + # total_frames=total_frames, + # max_frames_per_traj=-1, + # slurm_kwargs={ + # "timeout_min": 10, + # "slurm_partition": "3090", + # "slurm_cpus_per_task": 1, + # "slurm_gpus_per_node": 0, + # }, + # sync=False, + # update_after_each_batch=True, + # ) + # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( @@ -160,6 +175,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for acc_data in accumulator: + acc_data = acc_data.to(device) with torch.no_grad(): acc_data = adv_module(acc_data) acc_data_reshape = acc_data.reshape(-1) From 5c0aec01daf84abc49eff7a097528bbcfb49de00 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 12:43:01 +0200 Subject: [PATCH 048/109] call actor func --- .../collectors/multi_nodes/ray_train.py | 2 +- examples/impala/impala_multi_node.py | 9 ++++++-- test/test_cost.py | 4 +++- torchrl/objectives/value/advantages.py | 22 ++++++++++++++----- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git 
a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index a5265f442b7..360c6daac28 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -117,7 +117,7 @@ "object_store_memory": 1024**3, } collector = RayCollector( - env_makers=[env] * num_collectors, + create_env_fn=[env] * num_collectors, policy=policy_module, collector_class=SyncDataCollector, collector_kwargs={ diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 2ac4e2d4bdd..6eadc2752c5 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -20,7 +20,10 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RayCollector, RPCDataCollector, DistributedSyncDataCollector + from torchrl.collectors.distributed import ( + RayCollector, + RPCDataCollector, + ) from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -39,7 +42,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Extract other config parameters batch_size = cfg.loss.batch_size # Number of rollouts per batch - num_workers = cfg.collector.num_workers # Number of parallel workers collecting rollouts + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts lr = cfg.optim.lr anneal_lr = cfg.optim.anneal_lr sgd_updates = cfg.loss.sgd_updates diff --git a/test/test_cost.py b/test/test_cost.py index ca9015062fd..66202227fb7 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6216,7 +6216,9 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): device=device, ) - actor = self._create_mock_actor(device=device, sample_log_prob_key=tensor_keys["sample_log_prob"]) + actor = self._create_mock_actor( + device=device, sample_log_prob_key=tensor_keys["sample_log_prob"] + ) value = self._create_mock_value(device=device, out_keys=[tensor_keys["value"]]) if advantage == "gae": diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index e86fff25de1..5fae31846e6 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -151,6 +151,17 @@ def _call_value_nets( return value, value_ +def _call_actor_net( + actor_net: TensorDictModuleBase, + data: TensorDictBase, + params: TensorDictBase, + log_prob_key: NestedKey, +): + # TODO: extend to handle time dimension (and vmap?) + log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key) + return log_pi + + class ValueEstimatorBase(TensorDictModuleBase): """An abstract parent class for value function modules. 
@@ -1543,11 +1554,12 @@ def forward( # Compute log prob with current policy with hold_out_net(self.actor_network): - log_pi = ( - self.actor_network(tensordict.select(self.actor_network.in_keys)) - .get(self.tensor_keys.sample_log_prob) - .view_as(value) - ) + log_pi = _call_actor_net( + actor_net=self.actor_network, + data=tensordict, + params=None, + log_prob_key=self.tensor_keys.sample_log_prob, + ).view_as(value) # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) From c8ef2c7b27b925fbc2393fb523043b2f11f1d54b Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 15:32:12 +0200 Subject: [PATCH 049/109] faster scripts --- examples/impala/config_multi_node.yaml | 2 +- examples/impala/impala_multi_node.py | 17 +++++++++-------- examples/impala/impala_single_node.py | 14 ++++++++------ 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/examples/impala/config_multi_node.yaml b/examples/impala/config_multi_node.yaml index 7ac584bcac6..86a11d6b40c 100644 --- a/examples/impala/config_multi_node.yaml +++ b/examples/impala/config_multi_node.yaml @@ -6,7 +6,7 @@ env: collector: frames_per_batch: 80 total_frames: 200_000_000 - num_workers: 1 + num_workers: 12 # logger logger: diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 6eadc2752c5..3ed63ae0e64 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -61,7 +61,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create collector remote_config = { "num_cpus": 1, - "num_gpus": 1.0, + "num_gpus": 1.0 / num_workers, "memory": 2 * 1024**3, } collector = RayCollector( @@ -178,15 +178,16 @@ def main(cfg: "DictConfig"): # noqa: F821 training_start = time.time() for j in range(sgd_updates): - for acc_data in accumulator: + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0) + stacked_data = stacked_data.to(device) - acc_data = acc_data.to(device) - with torch.no_grad(): - acc_data = adv_module(acc_data) - acc_data_reshape = acc_data.reshape(-1) + # Compute advantage + stacked_data = adv_module(stacked_data) - # Update the data buffer - data_buffer.extend(acc_data_reshape) + # Add to replay buffer + stacked_data_reshape = stacked_data.reshape(-1) + data_buffer.extend(stacked_data_reshape) for batch in data_buffer: diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index ed0100925d4..fea69b3fb14 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -151,14 +151,16 @@ def main(cfg: "DictConfig"): # noqa: F821 training_start = time.time() for j in range(sgd_updates): - for acc_data in accumulator: + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0) + stacked_data = stacked_data.to(device) - with torch.no_grad(): - acc_data = adv_module(acc_data) - acc_data_reshape = acc_data.reshape(-1) + # Compute advantage + stacked_data = adv_module(stacked_data) - # Update the data buffer - data_buffer.extend(acc_data_reshape) + # Add to replay buffer + stacked_data_reshape = stacked_data.reshape(-1) + data_buffer.extend(stacked_data_reshape) for batch in data_buffer: From e024c0976081eaa51b45efa80e4d8a6a2c191478 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 16:19:40 +0200 Subject: [PATCH 050/109] multinode script --- examples/impala/impala_multi_node.py | 7 +++---- examples/impala/utils.py | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 6 
deletions(-) diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node.py index 3ed63ae0e64..d843ecb6bfd 100644 --- a/examples/impala/impala_multi_node.py +++ b/examples/impala/impala_multi_node.py @@ -7,6 +7,7 @@ This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018 for the on Atari Environments. """ + import hydra @@ -20,10 +21,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import ( - RayCollector, - RPCDataCollector, - ) + from torchrl.collectors.distributed import RayCollector, RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -73,6 +71,7 @@ def main(cfg: "DictConfig"): # noqa: F821 max_frames_per_traj=-1, remote_configs=remote_config, sync=False, + storing_device=device, update_after_each_batch=True, ) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 967f07d4fbe..ddbf94d37f9 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -3,6 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import os + +# To pickle the environment, in particular the EndOfLifeTransform, we need to +# add the utils path to the PYTHONPATH +utils_path = os.path.abspath(os.path.abspath(os.path.dirname(__file__))) +current_pythonpath = os.environ.get("PYTHONPATH", "") +new_pythonpath = f"{utils_path}:{current_pythonpath}" +os.environ["PYTHONPATH"] = new_pythonpath + + import torch.nn import torch.optim from tensordict.nn import TensorDictModule @@ -41,7 +51,8 @@ class EndOfLifeTransform(Transform): def _step(self, tensordict, next_tensordict): - lives = self.parent.base_env._env.unwrapped.ale.lives() + # lives = self.parent.base_env._env.unwrapped.ale.lives() + lives = 0 end_of_life = torch.tensor( [tensordict["lives"] < lives], device=self.parent.device ) @@ -73,7 +84,7 @@ def make_env(env_name, device, is_test=False): env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: - # env.append_transform(EndOfLifeTransform()) + env.append_transform(EndOfLifeTransform()) env.append_transform(RewardClipping(-1, 1)) env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) From 55b7947f8208e261aa8df1ecf23657dadb153b16 Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 3 Oct 2023 16:30:21 +0200 Subject: [PATCH 051/109] simplify utils --- examples/impala/utils.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index ddbf94d37f9..d0fb4a6a262 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -108,17 +108,9 @@ def make_ppo_modules_pixels(proof_environment): input_shape = proof_environment.observation_spec["pixels"].shape # Define distribution class and kwargs - if isinstance(proof_environment.action_spec.space, DiscreteBox): - num_outputs = proof_environment.action_spec.space.n - distribution_class = OneHotCategorical - distribution_kwargs = {} - else: # is ContinuousBox - num_outputs = proof_environment.action_spec.shape - distribution_class = TanhNormal - distribution_kwargs = { - "min": proof_environment.action_spec.space.minimum, - "max": proof_environment.action_spec.space.maximum, 
- } + num_outputs = proof_environment.action_spec.space.n + distribution_class = OneHotCategorical + distribution_kwargs = {} # Define input keys in_keys = ["pixels"] From 6d6df00d864f7b508db1018c85112a1719931757 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 08:41:37 +0200 Subject: [PATCH 052/109] revert tests --- test/test_cost.py | 1073 ++++++++++++++++++++++++++++++--------------- 1 file changed, 714 insertions(+), 359 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 66202227fb7..6c38e6a8b65 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -129,7 +129,6 @@ GAE, TD1Estimator, TDLambdaEstimator, - VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -140,7 +139,6 @@ vec_generalized_advantage_estimate, vec_td1_advantage_estimate, vec_td_lambda_advantage_estimate, - # vtrace_advantage_estimate, ) from torchrl.objectives.value.utils import ( _custom_conv1d, @@ -353,6 +351,7 @@ def _create_mock_data_dqn( action = torch.argmax(action, -1, keepdim=False) reward = torch.randn(batch, 1) done = torch.zeros(batch, 1, dtype=torch.bool) + terminated = torch.zeros(batch, 1, dtype=torch.bool) td = TensorDict( batch_size=(batch,), source={ @@ -360,6 +359,7 @@ def _create_mock_data_dqn( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, action_key: action, @@ -397,6 +397,7 @@ def _create_seq_mock_data_dqn( # action_value = action_value.unsqueeze(-1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) if action_spec_type == "categorical": action_value = torch.max(action_value, -1, keepdim=True)[0] @@ -411,6 +412,7 @@ def _create_seq_mock_data_dqn( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -434,7 +436,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -557,6 +559,7 @@ def test_dqn_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -567,6 +570,7 @@ def test_dqn_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -673,7 +677,10 @@ def test_distributional_dqn( @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_dqn_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_dqn_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): n_obs = 3 n_action = 4 action_spec = 
OneHotDiscreteTensorSpec(n_action) @@ -685,18 +692,20 @@ def test_dqn_notensordict(self, observation_key, reward_key, done_key): in_keys=[observation_key], ) dqn_loss = DQNLoss(actor) - dqn_loss.set_keys(reward=reward_key, done=done_key) + dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) next_observation = torch.randn(n_obs) action = action_spec.rand() next_reward = torch.randn(1) next_done = torch.zeros(1, dtype=torch.bool) + next_terminated = torch.zeros(1, dtype=torch.bool) kwargs = { observation_key: observation, f"next_{observation_key}": next_observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, "action": action, } td = TensorDict(kwargs, []).unflatten_keys("_") @@ -721,6 +730,7 @@ def test_distributional_dqn_tensordict_keys(self): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", "steps_to_next_obs": "steps_to_next_obs", } @@ -853,6 +863,7 @@ def _create_mock_data_dqn( reward = torch.randn(*batch, 1, device=device) done = torch.zeros(*batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(*batch, 1, dtype=torch.bool, device=device) td = TensorDict( { "agents": TensorDict( @@ -874,6 +885,7 @@ def _create_mock_data_dqn( "state": next_state, "reward": reward, "done": done, + "terminated": terminated, }, batch_size=batch, device=device, @@ -902,7 +914,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1052,6 +1064,7 @@ def test_qmix_tensordict_keys(self, td_est): "action": ("agents", "action"), "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test(loss_fn, default_keys=default_keys) @@ -1062,6 +1075,7 @@ def test_qmix_tensordict_keys(self, td_est): "value_target": ("value_target", ("value_target", "nested")), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1155,6 +1169,7 @@ def test_mixer_keys( "state": torch.zeros(32, 64, 64, 3), "reward": torch.zeros(32, 1), "done": torch.zeros(32, 1, dtype=torch.bool), + "terminated": torch.zeros(32, 1, dtype=torch.bool), }, [32], ), @@ -1208,20 +1223,20 @@ def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): return actor.to(device) def _create_mock_value( - self, batch=2, obs_dim=3, action_dim=4, device="cpu", out_keys=None + self, batch=2, obs_dim=3, action_dim=4, state_dim=8, device="cpu", out_keys=None ): # Actor class ValueClass(nn.Module): def __init__(self): super().__init__() - self.linear = nn.Linear(obs_dim + action_dim, 1) + self.linear = nn.Linear(obs_dim + action_dim + state_dim, 1) - def forward(self, obs, act): - return self.linear(torch.cat([obs, act], -1)) + def forward(self, obs, state, act): + return self.linear(torch.cat([obs, state, act], -1)) module = ValueClass() value = ValueOperator( - module=module, in_keys=["observation", "action"], out_keys=out_keys + module=module, in_keys=["observation", "state", "action"], out_keys=out_keys ) return value.to(device) @@ -1257,10 +1272,12 
@@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1280,10 +1297,12 @@ def _create_mock_data_ddpg( batch=8, obs_dim=3, action_dim=4, + state_dim=8, atoms=None, device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1293,14 +1312,19 @@ def _create_mock_data_ddpg( else: action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) + state = torch.randn(batch, state_dim, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ "observation": obs, + "state": state, "next": { "observation": next_obs, + "state": state, done_key: done, + terminated_key: terminated, reward_key: reward, }, "action": action, @@ -1315,15 +1339,20 @@ def _create_seq_mock_data_ddpg( T=4, obs_dim=3, action_dim=4, + state_dim=8, atoms=None, device="cpu", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + total_state = torch.randn(batch, T + 1, state_dim, device=device) obs = total_obs[:, :T] next_obs = total_obs[:, 1:] + state = total_state[:, :T] + next_state = total_state[:, 1:] if atoms: action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( -1, 1 @@ -1331,15 +1360,20 @@ def _create_seq_mock_data_ddpg( else: action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), source={ "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "state": state.masked_fill_(~mask.unsqueeze(-1), 0.0), "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + "state": next_state.masked_fill_(~mask.unsqueeze(-1), 0.0), done_key: done, + terminated_key: terminated, reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -1365,7 +1399,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1647,6 +1681,7 @@ def test_ddpg_tensordict_keys(self, td_est): default_keys = { "reward": "reward", "done": "done", + "terminated": "terminated", "state_action_value": "state_action_value", "priority": "td_error", } @@ -1667,6 +1702,7 @@ def test_ddpg_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", "reward2"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -1682,12 +1718,15 @@ def 
test_ddpg_tensordict_run(self, td_est): "priority": "td_error_test", "reward": "reward_test", "done": ("done", "test"), + "terminated": ("terminated", "test"), } actor = self._create_mock_actor() value = self._create_mock_value(out_keys=[tensor_keys["state_action_value"]]) td = self._create_mock_data_ddpg( - reward_key="reward_test", done_key=("done", "test") + reward_key="reward_test", + done_key=("done", "test"), + terminated_key=("terminated", "test"), ) loss_fn = DDPGLoss( actor, @@ -1715,8 +1754,11 @@ def test_ddpg_notensordict(self): "observation": td.get("observation"), "next_reward": td.get(("next", "reward")), "next_done": td.get(("next", "done")), + "next_terminated": td.get(("next", "terminated")), "next_observation": td.get(("next", "observation")), "action": td.get("action"), + "state": td.get("state"), + "next_state": td.get(("next", "state")), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -1824,10 +1866,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -1861,6 +1905,7 @@ def _create_mock_data_td3( observation_key="observation", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -1871,6 +1916,7 @@ def _create_mock_data_td3( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -1878,6 +1924,7 @@ def _create_mock_data_td3( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -1901,6 +1948,7 @@ def _create_seq_mock_data_td3( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -1910,6 +1958,7 @@ def _create_seq_mock_data_td3( "observation": next_obs * mask.to(obs.dtype), "reward": reward * mask.to(obs.dtype), "done": done, + "terminated": terminated, }, "collector": {"mask": mask}, "action": action * mask.to(obs.dtype), @@ -1959,7 +2008,7 @@ def test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2282,6 +2331,7 @@ def test_td3_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -2300,6 +2350,7 @@ def test_td3_tensordict_keys(self, td_est): "state_action_value": ("value", "state_action_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, 
key_mapping) @@ -2333,22 +2384,29 @@ def test_constructor(self, spec, bounds): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_td3_notensordict(self, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(in_keys=[observation_key]) qvalue = self._create_mock_value( observation_key=observation_key, out_keys=["state_action_value"] ) td = self._create_mock_data_td3( - observation_key=observation_key, reward_key=reward_key, done_key=done_key + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, ) loss = TD3Loss(actor, qvalue, action_spec=actor.spec) - loss.set_keys(reward=reward_key, done=done_key) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) kwargs = { observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), "action": td.get("action"), } @@ -2359,8 +2417,12 @@ def test_td3_notensordict(self, observation_key, reward_key, done_key): loss_val_td = loss(td) torch.manual_seed(0) loss_val = loss(**kwargs) - for i, key in enumerate(loss_val_td.keys()): + for i in loss_val: + assert i in loss_val_td.values(), f"{i} not in {loss_val_td.values()}" + + for i, key in enumerate(loss.out_keys): torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) + # test select loss.select_out_keys("loss_actor", "loss_qvalue") torch.manual_seed(0) @@ -2481,10 +2543,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -2521,6 +2585,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -2532,6 +2597,7 @@ def _create_mock_data_sac( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -2539,6 +2605,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -2562,6 +2629,7 @@ def _create_seq_mock_data_sac( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -2570,6 +2638,7 @@ def _create_seq_mock_data_sac( "next": { 
"observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -2626,7 +2695,7 @@ def test_sac( **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3079,6 +3148,7 @@ def test_sac_tensordict_keys(self, td_est, version): "log_prob": "_log_prob", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -3098,6 +3168,7 @@ def test_sac_tensordict_keys(self, td_est, version): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -3105,8 +3176,9 @@ def test_sac_tensordict_keys(self, td_est, version): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_sac_notensordict( - self, action_key, observation_key, reward_key, done_key, version + self, action_key, observation_key, reward_key, done_key, terminated_key, version ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3114,6 +3186,7 @@ def test_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -3134,13 +3207,19 @@ def test_sac_notensordict( qvalue_network=qvalue, value_network=value, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -3248,6 +3327,7 @@ def _create_mock_data_sac( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -3265,6 +3345,7 @@ def _create_mock_data_sac( action = (action_value == action_value.max(-1, True)[0]).to(torch.long) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3272,6 +3353,7 @@ def _create_mock_data_sac( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3300,6 +3382,7 @@ def _create_seq_mock_data_sac( reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3308,6 +3391,7 @@ def _create_seq_mock_data_sac( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), 
"done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3353,7 +3437,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3621,6 +3705,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -3640,6 +3725,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -3647,8 +3733,9 @@ def test_discrete_sac_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_discrete_sac_notensordict( - self, action_key, observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) td = self._create_mock_data_sac( @@ -3656,6 +3743,7 @@ def test_discrete_sac_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -3670,13 +3758,19 @@ def test_discrete_sac_notensordict( qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -3790,10 +3884,12 @@ def _create_mock_common_layer_setup( "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -3870,6 +3966,7 @@ def _create_mock_data_redq( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", ): # create a tensordict obs = torch.randn(batch, obs_dim, device=device) @@ -3880,6 +3977,7 @@ def _create_mock_data_redq( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -3887,6 +3985,7 @@ def _create_mock_data_redq( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -3910,6 
+4009,7 @@ def _create_seq_mock_data_redq( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -3918,6 +4018,7 @@ def _create_seq_mock_data_redq( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -3946,7 +4047,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4313,7 +4414,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4330,7 +4431,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4486,6 +4587,7 @@ def test_redq_tensordict_keys(self, td_est): "state_action_value": "state_action_value", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -4504,6 +4606,7 @@ def test_redq_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -4511,9 +4614,10 @@ def test_redq_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) @pytest.mark.parametrize("deprec", [True, False]) def test_redq_notensordict( - self, action_key, observation_key, reward_key, done_key, deprec + self, action_key, observation_key, reward_key, done_key, terminated_key, deprec ): torch.manual_seed(self.seed) td = self._create_mock_data_redq( @@ -4521,6 +4625,7 @@ def test_redq_notensordict( observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor( @@ -4541,13 +4646,19 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", 
observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -4646,6 +4757,7 @@ def _create_mock_data_cql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -4653,6 +4765,7 @@ def _create_mock_data_cql( "next": { "observation": next_obs, "done": done, + "terminated": terminated, "reward": reward, }, "action": action, @@ -4676,6 +4789,7 @@ def _create_seq_mock_data_cql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -4684,6 +4798,7 @@ def _create_seq_mock_data_cql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -4736,7 +4851,7 @@ def test_cql( **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -5118,6 +5233,7 @@ def _create_mock_data_ppo( action_key="action", reward_key="reward", done_key="done", + terminated_key="terminated", sample_log_prob_key="sample_log_prob", ): # create a tensordict @@ -5129,6 +5245,7 @@ def _create_mock_data_ppo( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -5136,6 +5253,7 @@ def _create_mock_data_ppo( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -5168,6 +5286,7 @@ def _create_seq_mock_data_ppo( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 @@ -5178,6 +5297,7 @@ def _create_seq_mock_data_ppo( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -5195,7 +5315,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -5208,13 +5328,6 @@ def test_ppo(self, loss_class, device, gradient_mode, 
advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5281,7 +5394,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -5294,12 +5407,6 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5431,7 +5538,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5445,13 +5552,6 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5514,7 +5614,6 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, - ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -5532,6 +5631,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "action": "action", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -5550,11 +5650,12 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): "value": ("value", value_key), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" @@ -5582,13 +5683,6 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) - elif advantage == "vtrace": - advantage = VTrace( - 
gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5661,6 +5755,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_ppo_notensordict( self, loss_class, @@ -5669,6 +5764,7 @@ def test_ppo_notensordict( observation_key, reward_key, done_key, + terminated_key, ): torch.manual_seed(self.seed) td = self._create_mock_data_ppo( @@ -5677,6 +5773,7 @@ def test_ppo_notensordict( sample_log_prob_key=sample_log_prob_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -5687,6 +5784,7 @@ def test_ppo_notensordict( action=action_key, reward=reward_key, done=done_key, + terminated=terminated_key, sample_log_prob=sample_log_prob_key, ) @@ -5696,6 +5794,7 @@ def test_ppo_notensordict( sample_log_prob_key: td.get(sample_log_prob_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size, names=["time"]).unflatten_keys("_") @@ -5736,7 +5835,6 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", - sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5748,11 +5846,9 @@ def _create_mock_actor( ) actor = ProbabilisticActor( module=module, - distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, - return_log_prob=True, - log_prob_key=sample_log_prob_key, + distribution_class=TanhNormal, ) return actor.to(device) @@ -5773,65 +5869,6 @@ def _create_mock_value( ) return value.to(device) - def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"): - # Actor - action_spec = BoundedTensorSpec( - -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) - ) - base_layer = nn.Linear(obs_dim, 5) - net = NormalParamWrapper( - nn.Sequential(base_layer, nn.Linear(5, 2 * action_dim)) - ) - module = TensorDictModule( - net, in_keys=["observation"], out_keys=["loc", "scale"] - ) - actor = ProbabilisticActor( - module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], - spec=action_spec, - return_log_prob=True, - ) - module = nn.Sequential(base_layer, nn.Linear(5, 1)) - value = ValueOperator( - module=module, - in_keys=["observation"], - ) - return actor.to(device), value.to(device) - - def _create_mock_actor_value_shared( - self, batch=2, obs_dim=3, action_dim=4, device="cpu" - ): - # Actor - action_spec = BoundedTensorSpec( - -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) - ) - base_layer = nn.Linear(obs_dim, 5) - common = TensorDictModule( - base_layer, in_keys=["observation"], out_keys=["hidden"] - ) - net = nn.Sequential(nn.Linear(5, 2 * action_dim), NormalParamExtractor()) - module = TensorDictModule(net, in_keys=["hidden"], out_keys=["loc", "scale"]) - actor_head = ProbabilisticActor( - module=module, - distribution_class=TanhNormal, - in_keys=["loc", "scale"], - spec=action_spec, - return_log_prob=True, - ) - module = nn.Linear(5, 1) - value_head = ValueOperator( - 
module=module, - in_keys=["hidden"], - ) - model = ActorValueOperator(common, actor_head, value_head).to(device) - return model, model.get_policy_operator(), model.get_value_operator() - - def _create_mock_distributional_actor( - self, batch=2, obs_dim=3, action_dim=4, atoms=0, vmin=1, vmax=5 - ): - raise NotImplementedError - def _create_mock_common_layer_setup( self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2, T=10 ): @@ -5860,10 +5897,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -5895,8 +5934,11 @@ def _create_seq_mock_data_a2c( action_dim=4, atoms=None, device="cpu", - sample_log_prob_key="sample_log_prob", action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", + terminated_key="terminated", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -5910,23 +5952,26 @@ def _create_seq_mock_data_a2c( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) - mask = torch.ones(batch, T, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, dtype=torch.bool, device=device) params_mean = torch.randn_like(action) / 10 params_scale = torch.rand_like(action) / 10 td = TensorDict( batch_size=(batch, T), source={ - "observation": obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + observation_key: obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "next": { - "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), - "done": done, - "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), + observation_key: next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), + done_key: done, + terminated_key: terminated, + reward_key: reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), - sample_log_prob_key: ( - torch.randn_like(action[..., 1]) / 10 - ).masked_fill_(~mask, 0.0), + "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( + ~mask, 0.0 + ) + / 10, "loc": params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0), "scale": params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0), }, @@ -5936,7 +5981,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -5949,13 +5994,6 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, 
differentiable=gradient_mode @@ -6080,7 +6118,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -6094,13 +6132,6 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6154,7 +6185,6 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, - ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6171,7 +6201,7 @@ def test_a2c_tensordict_keys(self, td_est): "action": "action", "reward": "reward", "done": "done", - "sample_log_prob": "sample_log_prob", + "terminated": "terminated", } self.tensordict_keys_test( @@ -6190,83 +6220,59 @@ def test_a2c_tensordict_keys(self, td_est): "value": ("value", "value_state_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), - "sample_log_prob": ("sample_log_prob", "sample_log_prob_test"), + "terminated": ("terminated", ("terminated", "test")), } - self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) - @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device, advantage, td_est): + def test_a2c_tensordict_keys_run(self, device): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True - tensor_keys = { - "advantage": "advantage_test", - "value_target": "value_target_test", - "value": "state_value_test", - "sample_log_prob": "sample_log_prob_test", - "action": "action_test", - } + advantage_key = "advantage_test" + value_target_key = "value_target_test" + value_key = "state_value_test" + action_key = "action_test" + reward_key = "reward_test" + done_key = ("done", "test") + terminated_key = ("terminated", "test") td = self._create_seq_mock_data_a2c( - sample_log_prob_key=tensor_keys["sample_log_prob"], - action_key=tensor_keys["action"], device=device, + action_key=action_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, ) - actor = self._create_mock_actor( - device=device, sample_log_prob_key=tensor_keys["sample_log_prob"] + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device, out_keys=[value_key]) + advantage = GAE( + gamma=0.9, + lmbda=0.9, + value_network=value, + differentiable=gradient_mode, + ) + advantage.set_keys( + advantage=advantage_key, + value_target=value_target_key, + value=value_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, ) - value = self._create_mock_value(device=device, out_keys=[tensor_keys["value"]]) - - if advantage == "gae": - advantage = 
GAE( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - value_network=value, - actor_network=actor, - differentiable=gradient_mode, - ) - elif advantage == "td": - advantage = TD1Estimator( - gamma=0.9, - value_network=value, - differentiable=gradient_mode, - ) - elif advantage == "td_lambda": - advantage = TDLambdaEstimator( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) - elif advantage is None: - pass - else: - raise NotImplementedError - loss_fn = A2CLoss(actor, value, loss_critic_type="l2") - loss_fn.set_keys(**tensor_keys) - if advantage is not None: - # collect tensordict key names for the advantage module - adv_keys = { - key: value - for key, value in tensor_keys.items() - if key in asdict(GAE._AcceptedKeys()).keys() - } - advantage.set_keys(**adv_keys) - advantage(td) - else: - if td_est is not None: - loss_fn.make_value_estimator(td_est) + loss_fn.set_keys( + advantage=advantage_key, + value_target=value_target_key, + value=value_key, + action=action_key, + reward=reward_key, + done=done_key, + terminated=done_key, + ) + + advantage(td) loss = loss_fn(td) loss_critic = loss["loss_critic"] @@ -6300,23 +6306,36 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_a2c_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_a2c_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) actor = self._create_mock_actor(observation_key=observation_key) value = self._create_mock_value(observation_key=observation_key) td = self._create_seq_mock_data_a2c( action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, ) loss = A2CLoss(actor, value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { observation_key: td.get(observation_key), f"next_{observation_key}": td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), action_key: td.get(action_key), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -6350,8 +6369,8 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) - @pytest.mark.parametrize("advantage", ["gae", "vtrace", "td", "td_lambda", None]) - @pytest.mark.parametrize("td_est", list(ValueEstimators)) + @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 n_act = 5 @@ -6376,13 +6395,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est value_network=get_functional(value_net), differentiable=gradient_mode, ) - elif advantage == "vtrace": - advantage = VTrace( - gamma=0.9, - 
value_network=get_functional(value_net), - actor_network=actor_net, - differentiable=gradient_mode, - ) elif advantage == "td": advantage = TD1Estimator( gamma=gamma, @@ -6414,9 +6426,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est "observation": torch.randn(batch, n_obs), "reward": torch.randn(batch, 1), "done": torch.zeros(batch, 1, dtype=torch.bool), + "terminated": torch.zeros(batch, 1, dtype=torch.bool), }, "action": torch.randn(batch, n_act), - "sample_log_prob": torch.randn(batch, 1), }, [batch], names=["time"], @@ -6467,7 +6479,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, - ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6499,6 +6510,7 @@ def test_reinforce_tensordict_keys(self, td_est): "sample_log_prob": "sample_log_prob", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -6522,6 +6534,7 @@ def test_reinforce_tensordict_keys(self, td_est): "value": ("value", "state_value_test"), "reward": ("reward", "reward_test"), "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -6554,10 +6567,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -6654,8 +6669,9 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) def test_reinforce_notensordict( - self, action_key, observation_key, reward_key, done_key + self, action_key, observation_key, reward_key, done_key, terminated_key ): torch.manual_seed(self.seed) n_obs = 3 @@ -6674,19 +6690,26 @@ def test_reinforce_notensordict( spec=UnboundedContinuousTensorSpec(n_act), ) loss = ReinforceLoss(actor=actor_net, critic=value_net) - loss.set_keys(reward=reward_key, done=done_key, action=action_key) + loss.set_keys( + reward=reward_key, + done=done_key, + action=action_key, + terminated=terminated_key, + ) observation = torch.randn(batch, n_obs) action = torch.randn(batch, n_act) next_reward = torch.randn(batch, 1) next_observation = torch.randn(batch, n_obs) next_done = torch.zeros(batch, 1, dtype=torch.bool) + next_terminated = torch.zeros(batch, 1, dtype=torch.bool) kwargs = { action_key: action, observation_key: observation, f"next_{reward_key}": next_reward, f"next_{done_key}": next_done, + f"next_{terminated_key}": next_terminated, f"next_{observation_key}": next_observation, } td = TensorDict(kwargs, [batch]).unflatten_keys("_") @@ -6727,6 +6750,9 @@ def _create_world_model_data( ), "reward": torch.randn(batch_size, temporal_length, 1), "done": torch.zeros(batch_size, temporal_length, dtype=torch.bool), + "terminated": torch.zeros( + batch_size, temporal_length, dtype=torch.bool + ), }, "action": torch.randn(batch_size, temporal_length, 64), }, @@ -7062,7 +7088,7 @@ def 
test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -7151,6 +7177,7 @@ def test_dreamer_actor_tensordict_keys(self, td_est, device): "reward": "reward", "value": "state_value", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( loss_fn, @@ -7656,10 +7683,12 @@ def _create_mock_common_layer_setup( "action": torch.randn(*batch, n_act), "sample_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { "obs": torch.randn(*batch, n_obs), "reward": torch.randn(*batch, 1), "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), }, }, batch, @@ -7706,6 +7735,7 @@ def _create_mock_data_iql( observation_key="observation", action_key="action", done_key="done", + terminated_key="terminated", reward_key="reward", ): # create a tensordict @@ -7717,6 +7747,7 @@ def _create_mock_data_iql( action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, 1, device=device) done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch,), source={ @@ -7724,6 +7755,7 @@ def _create_mock_data_iql( "next": { observation_key: next_obs, done_key: done, + terminated_key: terminated, reward_key: reward, }, action_key: action, @@ -7747,6 +7779,7 @@ def _create_seq_mock_data_iql( action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) reward = torch.randn(batch, T, 1, device=device) done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) mask = torch.ones(batch, T, dtype=torch.bool, device=device) td = TensorDict( batch_size=(batch, T), @@ -7755,6 +7788,7 @@ def _create_seq_mock_data_iql( "next": { "observation": next_obs.masked_fill_(~mask.unsqueeze(-1), 0.0), "done": done, + "terminated": terminated, "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0), }, "collector": {"mask": mask}, @@ -7794,7 +7828,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est is ValueEstimators.GAE: with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -8212,6 +8246,7 @@ def test_iql_tensordict_keys(self, td_est): "value": "state_value", "reward": "reward", "done": "done", + "terminated": "terminated", } self.tensordict_keys_test( @@ -8231,6 +8266,7 @@ def test_iql_tensordict_keys(self, td_est): key_mapping = { "value": ("value", "value_test"), "done": ("done", "done_test"), + "terminated": ("terminated", "terminated_test"), "reward": ("reward", ("reward", "test")), } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @@ -8239,13 +8275,17 @@ def test_iql_tensordict_keys(self, td_est): @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) - def test_iql_notensordict(self, action_key, observation_key, reward_key, done_key): + @pytest.mark.parametrize("terminated_key", ["terminated", 
"terminated2"]) + def test_iql_notensordict( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): torch.manual_seed(self.seed) td = self._create_mock_data_iql( action_key=action_key, observation_key=observation_key, reward_key=reward_key, done_key=done_key, + terminated_key=terminated_key, ) actor = self._create_mock_actor(observation_key=observation_key) @@ -8257,13 +8297,19 @@ def test_iql_notensordict(self, action_key, observation_key, reward_key, done_ke value = self._create_mock_value(observation_key=observation_key) loss = IQLLoss(actor_network=actor, qvalue_network=qvalue, value_network=value) - loss.set_keys(action=action_key, reward=reward_key, done=done_key) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) kwargs = { action_key: td.get(action_key), observation_key: td.get(observation_key), f"next_{reward_key}": td.get(("next", reward_key)), f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), f"next_{observation_key}": td.get(("next", observation_key)), } td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") @@ -8581,24 +8627,77 @@ class TestValues: @pytest.mark.parametrize("gamma", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) # @pytest.mark.parametrize("random_gamma,rolling_gamma", [[True, False], [True, True], [False, None]]) @pytest.mark.parametrize("random_gamma,rolling_gamma", [[False, None]]) def test_tdlambda(self, device, gamma, lmbda, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) - next_state_value = torch.randn(*N, T, 1, device=device) if random_gamma: gamma = torch.rand_like(reward) * gamma + next_state_value = torch.cat( + [state_value[..., 1:, :], torch.randn_like(state_value[..., -1:, :])], -2 + ) r1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done, rolling_gamma + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r3, *_ = vec_generalized_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + torch.testing.assert_close(r3, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) + + # test when v' is not v from next step (not working with gae) + next_state_value = torch.randn_like(next_state_value) + r1 = vec_td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + r2 = td_lambda_advantage_estimate( + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + 
terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8615,7 +8714,9 @@ def test_tdlambda_multi( torch.manual_seed(0) D = feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -8628,8 +8729,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td_lambda_advantage_estimate( @@ -8638,8 +8740,9 @@ def test_tdlambda_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -8651,8 +8754,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8668,8 +8772,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8686,8 +8791,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8702,8 +8808,9 @@ def test_tdlambda_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8723,7 +8830,9 @@ def test_tdlambda_multi( def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): torch.manual_seed(0) - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -8731,10 +8840,22 @@ def test_td1(self, device, gamma, N, T, random_gamma, rolling_gamma): gamma = torch.rand_like(reward) * gamma r1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) r2 = td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done, rolling_gamma + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8751,7 +8872,9 @@ def test_td1_multi( D = 
feature_dim time_dim = -1 - len(D) - done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool).bernoulli_(0.1) + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone().bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device) state_value = torch.randn(*N, T, *D, device=device) next_state_value = torch.randn(*N, T, *D, device=device) @@ -8763,8 +8886,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) r2 = td1_advantage_estimate( @@ -8772,8 +8896,9 @@ def test_td1_multi( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, time_dim=time_dim, ) if len(D) == 2: @@ -8784,8 +8909,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8800,8 +8926,9 @@ def test_td1_multi( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], - rolling_gamma, + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8817,8 +8944,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8832,8 +8960,9 @@ def test_td1_multi( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], - rolling_gamma, + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + rolling_gamma=rolling_gamma, time_dim=-2, ) for i in range(D[0]) @@ -8851,22 +8980,36 @@ def test_td1_multi( @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("dtype", [torch.float, torch.double]) - @pytest.mark.parametrize("has_done", [True, False]) + @pytest.mark.parametrize("has_done", [False, True]) def test_gae(self, device, gamma, lmbda, N, T, dtype, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) r1 = vec_generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8891,8 +9034,10 @@ def test_gae_param_as_tensor( T = 200 done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() if 
has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device, dtype=dtype) state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) @@ -8912,10 +9057,22 @@ def test_gae_param_as_tensor( lmbda_vec = lmbda r1 = vec_generalized_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) r2 = generalized_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) @@ -8936,8 +9093,10 @@ def test_gae_multidim( torch.manual_seed(0) done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, *D, device=device, dtype=dtype) state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) @@ -8948,7 +9107,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) r2 = generalized_advantage_estimate( @@ -8957,7 +9117,8 @@ def test_gae_multidim( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, time_dim=time_dim, ) if len(D) == 2: @@ -8968,7 +9129,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + done=done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -8981,7 +9143,8 @@ def test_gae_multidim( state_value[..., i : i + 1, j], next_state_value[..., i : i + 1, j], reward[..., i : i + 1, j], - done[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], time_dim=-2, ) for i in range(D[0]) @@ -8995,7 +9158,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -9007,7 +9171,8 @@ def test_gae_multidim( state_value[..., i : i + 1], next_state_value[..., i : i + 1], reward[..., i : i + 1], - done[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], time_dim=-2, ) for i in range(D[0]) @@ -9028,7 +9193,7 @@ def test_gae_multidim( @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) @pytest.mark.parametrize("N", [(3,), (7, 3)]) - @pytest.mark.parametrize("T", [3, 5, 200]) + @pytest.mark.parametrize("T", [200, 5, 3]) @pytest.mark.parametrize("has_done", [True, False]) def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): """Tests vec_td_lambda_advantage_estimate against itself with @@ -9038,32 +9203,61 @@ def test_tdlambda_tensor_gamma(self, device, gamma, lmbda, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = 
done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - + # if len(N) == 2: + # print(terminated[4, 0, -10:]) + # print(done[4, 0, -10:]) v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9089,8 +9283,10 @@ def test_tdlambda_tensor_gamma_single_element( torch.manual_seed(0) done = torch.zeros(*N, T, F, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, F, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, F, device=device) state_value = torch.randn(*N, T, F, device=device) next_state_value = torch.randn(*N, T, F, device=device) @@ -9108,22 +9304,47 @@ def test_tdlambda_tensor_gamma_single_element( lmbda_vec = lmbda v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory v1 = vec_td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_vec, lmbda_vec, state_value, next_state_value, reward, done + gamma_vec, + lmbda_vec, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9141,8 +9362,10 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = 
done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9150,23 +9373,44 @@ def test_td1_tensor_gamma(self, device, gamma, N, T, has_done): gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = vec_td1_advantage_estimate( - gamma, state_value, next_state_value, reward, done + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9188,8 +9432,10 @@ def test_vectdlambda_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9197,23 +9443,48 @@ def test_vectdlambda_tensor_gamma( gamma_tensor = torch.full((*N, T, 1), gamma, device=device) v1 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 v1 = td_lambda_advantage_estimate( - gamma, lmbda, state_value, next_state_value, reward, done + gamma, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) v2 = vec_td_lambda_advantage_estimate( - gamma_tensor, lmbda, state_value, next_state_value, reward, done + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9234,28 +9505,55 @@ def test_vectd1_tensor_gamma( torch.manual_seed(0) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = 
torch.randn(*N, T, 1, device=device) gamma_tensor = torch.full((*N, T, 1), gamma, device=device) - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) # same with last done being true done[..., -1, :] = True # terminating trajectory + terminated[..., -1, :] = True # terminating trajectory gamma_tensor[..., -1, :] = 0.0 - v1 = td1_advantage_estimate(gamma, state_value, next_state_value, reward, done) + v1 = td1_advantage_estimate( + gamma, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) v2 = vec_td1_advantage_estimate( - gamma_tensor, state_value, next_state_value, reward, done + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9277,8 +9575,10 @@ def test_vectdlambda_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9292,8 +9592,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, :]].all(): # if a not-done follows a done, then rolling_gamma=False cannot be used @@ -9306,8 +9607,24 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match=r"The vectorized version of TD" + ): + vec_td_lambda_advantage_estimate( + gamma_tensor, + lmbda, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -9316,8 +9633,9 @@ def test_vectdlambda_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9337,8 +9655,10 @@ def test_vectd1_rand_gamma( torch.manual_seed(seed) done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) if has_done: - done = done.bernoulli_(0.1) + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) next_state_value = torch.randn(*N, T, 1, device=device) @@ -9351,10 +9671,14 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) - if rolling_gamma is False and not done[..., 1:, :][done[..., :-1, 
:]].all(): + if ( + rolling_gamma is False + and not terminated[..., 1:, :][terminated[..., :-1, :]].all() + ): # if a not-done follows a done, then rolling_gamma=False cannot be used with pytest.raises( NotImplementedError, match="When using rolling_gamma=False" @@ -9364,8 +9688,23 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, + ) + return + elif rolling_gamma is False: + with pytest.raises( + NotImplementedError, match="The vectorized version of TD" + ): + vec_td1_advantage_estimate( + gamma_tensor, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td1_advantage_estimate( @@ -9373,8 +9712,9 @@ def test_vectd1_rand_gamma( state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9426,8 +9766,10 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9442,8 +9784,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v1a = td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -9451,8 +9794,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v1b = td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -9460,8 +9804,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9475,8 +9820,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) return v2 = vec_td_lambda_advantage_estimate( @@ -9485,8 +9831,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value, next_state_value, reward, - done, - rolling_gamma, + done=done, + terminated=terminated, + rolling_gamma=rolling_gamma, ) v2a = vec_td_lambda_advantage_estimate( gamma_tensor[..., : T // 2, :], @@ -9494,8 +9841,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], - rolling_gamma, + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], + rolling_gamma=rolling_gamma, ) v2b = 
vec_td_lambda_advantage_estimate( gamma_tensor[..., T // 2 :, :], @@ -9503,8 +9851,9 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], - rolling_gamma, + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], + rolling_gamma=rolling_gamma, ) torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) @@ -9516,22 +9865,17 @@ def test_successive_traj_tdlambda(self, device, N, T, rolling_gamma): @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("N", [(3,), (3, 7)]) @pytest.mark.parametrize("T", [3, 5, 200]) - def test_successive_traj_tdadv( - self, - device, - N, - T, - ): + def test_successive_traj_tdadv(self, device, N, T): """Tests td_lambda_advantage_estimate against vec_td_lambda_advantage_estimate with gamma being a random tensor """ torch.manual_seed(0) - lmbda = torch.rand([]).item() - + # for td0, a done that is not terminated has no effect done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) done[..., T // 2 - 1, :] = 1 + terminated = done.clone() reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9545,21 +9889,24 @@ def test_successive_traj_tdadv( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, ) v1a = td0_advantage_estimate( gamma_tensor[..., : T // 2, :], state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], ) v1b = td0_advantage_estimate( gamma_tensor[..., T // 2 :, :], state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], ) torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9580,8 +9927,10 @@ def test_successive_traj_gae( lmbda = torch.rand([]).item() - done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) - done[..., T // 2 - 1, :] = 1 + terminated = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated[..., T // 2 - 1, :] = 1 + done = terminated.clone() + done[..., -1, :] = 1 reward = torch.randn(*N, T, 1, device=device) state_value = torch.randn(*N, T, 1, device=device) @@ -9596,7 +9945,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v1a = generalized_advantage_estimate( gamma_tensor, @@ -9604,7 +9954,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v1b = generalized_advantage_estimate( gamma_tensor, @@ -9612,7 +9963,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, torch.cat([v1a, v1b], -2), rtol=1e-4, atol=1e-4) @@ -9622,7 +9974,8 @@ def test_successive_traj_gae( state_value, next_state_value, reward, - done, + done=done, + terminated=terminated, )[0] v2a = vec_generalized_advantage_estimate( gamma_tensor, @@ -9630,7 +9983,8 @@ def test_successive_traj_gae( state_value[..., : T // 2, :], 
next_state_value[..., : T // 2, :], reward[..., : T // 2, :], - done[..., : T // 2, :], + done=done[..., : T // 2, :], + terminated=terminated[..., : T // 2, :], )[0] v2b = vec_generalized_advantage_estimate( gamma_tensor, @@ -9638,7 +9992,8 @@ def test_successive_traj_gae( state_value[..., T // 2 :, :], next_state_value[..., T // 2 :, :], reward[..., T // 2 :, :], - done[..., T // 2 :, :], + done=done[..., T // 2 :, :], + terminated=terminated[..., T // 2 :, :], )[0] torch.testing.assert_close(v1, v2, rtol=1e-4, atol=1e-4) torch.testing.assert_close(v2, torch.cat([v2a, v2b], -2), rtol=1e-4, atol=1e-4) From 9e1d64b4ad0b332ae2369016816c824f6a916a80 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 09:35:11 +0200 Subject: [PATCH 053/109] introduce review feedback --- examples/impala/test_vtrace_examples.py | 52 +++++++++++++++++++++++++ examples/impala/utils.py | 19 +++++---- 2 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 examples/impala/test_vtrace_examples.py diff --git a/examples/impala/test_vtrace_examples.py b/examples/impala/test_vtrace_examples.py new file mode 100644 index 00000000000..59848e3166a --- /dev/null +++ b/examples/impala/test_vtrace_examples.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from tensordict import TensorDict +from tensordict.nn import TensorDictModule +from torchrl.modules.distributions import OneHotCategorical +from torchrl.modules.tensordict_module.actors import ProbabilisticActor +from torchrl.objectives.value.advantages import VTrace, GAE + +value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) +actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) +actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, +) +vtrace_module = VTrace( + gamma=0.98, + value_network=value_net, + actor_network=actor_net, + differentiable=False, +) +gae_module = GAE( + gamma=0.98, + lmbda=0.95, + value_network=value_net, + differentiable=False, +) + +obs, next_obs = torch.randn(2, 1, 10, 3) +reward = torch.randn(1, 10, 1) +done = torch.zeros(1, 10, 1, dtype=torch.bool) +terminated = torch.zeros(1, 10, 1, dtype=torch.bool) +sample_log_prob = torch.randn(1, 10, 1) +tensordict = TensorDict( + { + "obs": obs, + "done": done, + "terminated": terminated, + "sample_log_prob": sample_log_prob, + "next": { + "obs": next_obs, + "reward": reward, + "done": done, + "terminated": terminated, + }, + }, + batch_size=[1, 10], +) +advantage, value_target = gae_module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) +advantage, value_target = vtrace_module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index d0fb4a6a262..39c32b70d7b 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -5,19 +5,10 @@ import os -# To pickle the environment, in particular the EndOfLifeTransform, we need to -# add the utils path to the PYTHONPATH -utils_path = os.path.abspath(os.path.abspath(os.path.dirname(__file__))) -current_pythonpath = os.environ.get("PYTHONPATH", "") -new_pythonpath = f"{utils_path}:{current_pythonpath}" -os.environ["PYTHONPATH"] = new_pythonpath - - import torch.nn import torch.optim from tensordict.nn import TensorDictModule from torchrl.data import CompositeSpec, 
UnboundedDiscreteTensorSpec -from torchrl.data.tensor_specs import DiscreteBox from torchrl.envs import ( CatFrames, DoubleToFloat, @@ -40,10 +31,18 @@ MLP, OneHotCategorical, ProbabilisticActor, - TanhNormal, ValueOperator, ) + +# To pickle the environment, in particular the EndOfLifeTransform, we need to +# add the utils path to the PYTHONPATH +utils_path = os.path.abspath(os.path.abspath(os.path.dirname(__file__))) +current_pythonpath = os.environ.get("PYTHONPATH", "") +new_pythonpath = f"{utils_path}:{current_pythonpath}" +os.environ["PYTHONPATH"] = new_pythonpath + + # ==================================================================== # Environment utils # -------------------------------------------------------------------- From 224ae9188af56ff84b2ff3baa90663039069bc98 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 10:02:29 +0200 Subject: [PATCH 054/109] torch compile --- examples/impala/test_vtrace_examples.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/examples/impala/test_vtrace_examples.py b/examples/impala/test_vtrace_examples.py index 59848e3166a..cd74a5860f1 100644 --- a/examples/impala/test_vtrace_examples.py +++ b/examples/impala/test_vtrace_examples.py @@ -4,7 +4,7 @@ from tensordict.nn import TensorDictModule from torchrl.modules.distributions import OneHotCategorical from torchrl.modules.tensordict_module.actors import ProbabilisticActor -from torchrl.objectives.value.advantages import VTrace, GAE +from torchrl.objectives.value.advantages import GAE, VTrace value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) @@ -48,5 +48,19 @@ }, batch_size=[1, 10], ) -advantage, value_target = gae_module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) -advantage, value_target = vtrace_module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) +advantage, value_target = gae_module( + obs=obs, + reward=reward, + done=done, + next_obs=next_obs, + terminated=terminated, + sample_log_prob=sample_log_prob, +) +advantage, value_target = vtrace_module( + obs=obs, + reward=reward, + done=done, + next_obs=next_obs, + terminated=terminated, + sample_log_prob=sample_log_prob, +) From e5438889aaab783e35116b321da36861b5132431 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 10:03:41 +0200 Subject: [PATCH 055/109] torch compile --- torchrl/objectives/value/advantages.py | 44 ++++++++++++++++++++++---- torchrl/objectives/value/functional.py | 17 +++++++--- 2 files changed, 50 insertions(+), 11 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 5fae31846e6..a2182ebe700 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1478,37 +1478,67 @@ def forward( An updated TensorDict with an advantage and a value_error keys as defined in the constructor. Examples: - >>> from tensordict import TensorDict - >>> value_net = TensorDictModule( - ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... 
out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, ... ) >>> module = VTrace( ... gamma=0.98, ... value_network=value_net, + ... actor_network=actor_net, ... differentiable=False, ... ) >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs}, "done": done, "reward": reward}, [1, 10]) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) >>> _ = module(tensordict) >>> assert "advantage" in tensordict.keys() The module supports non-tensordict (i.e. unpacked tensordict) inputs too: Examples: - >>> value_net = TensorDictModule( - ... nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] + >>> value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) + >>> actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + >>> actor_net = ProbabilisticActor( + ... module=actor_net, + ... in_keys=["logits"], + ... out_keys=["action"], + ... distribution_class=OneHotCategorical, + ... return_log_prob=True, ... ) >>> module = VTrace( ... gamma=0.98, ... value_network=value_net, + ... actor_network=actor_net, ... differentiable=False, ... ) >>> obs, next_obs = torch.randn(2, 1, 10, 3) >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs) + >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) + >>> sample_log_prob = torch.randn(1, 10, 1) + >>> tensordict = TensorDict({ + ... "obs": obs, + ... "done": done, + ... "terminated": terminated, + ... "sample_log_prob": sample_log_prob, + ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, + ... }, batch_size=[1, 10]) + >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) """ if tensordict.batch_dims < 1: diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 1ad977e9d63..f164feaf74f 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1217,6 +1217,7 @@ def vec_td_lambda_advantage_estimate( # ----- +@torch.compile @_transpose_time def vtrace_advantage_estimate( gamma: float, @@ -1226,6 +1227,7 @@ def vtrace_advantage_estimate( next_state_value: torch.Tensor, reward: torch.Tensor, done: torch.Tensor, + terminated: torch.Tensor | None = None, rho_thresh: Union[float, torch.Tensor] = 1.0, c_thresh: Union[float, torch.Tensor] = 1.0, time_dim: int = -2, @@ -1243,6 +1245,7 @@ def vtrace_advantage_estimate( next_state_value (Tensor): value function result with next_state input. reward (Tensor): reward of taking actions in the environment. done (Tensor): boolean flag for end of episode. + terminated (torch.Tensor): a [B, T] boolean tensor containing the terminated states. rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. 
time_dim (int): dimension where the time is unrolled. Defaults to -2. @@ -1256,19 +1259,23 @@ def vtrace_advantage_estimate( device = state_value.device not_done = (~done).int() + not_terminated = (~terminated).int() *batch_size, time_steps, lastdim = not_done.shape - discounts = gamma * not_done + done_discounts = gamma * not_done + terminated_discounts = gamma * not_terminated rho = (log_pi - log_mu).exp() clipped_rho = rho.clamp_max(rho_thresh) - deltas = clipped_rho * (reward + discounts * next_state_value - state_value) + deltas = clipped_rho * ( + reward + terminated_discounts * next_state_value - state_value + ) c_thresh = c_thresh.to(device) clipped_c = rho.clamp_max(c_thresh) vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] for i in reversed(range(time_steps)): discount_t, c_t, delta_t = ( - discounts[..., i, :], + done_discounts[..., i, :], clipped_c[..., i, :], deltas[..., i, :], ) @@ -1279,7 +1286,9 @@ def vtrace_advantage_estimate( vs_t_plus_1 = torch.cat( [vs[..., 1:, :], next_state_value[..., -1:, :]], dim=time_dim ) - advantages = clipped_rho * (reward + discounts * vs_t_plus_1 - state_value) + advantages = clipped_rho * ( + reward + terminated_discounts * vs_t_plus_1 - state_value + ) return advantages, vs From db541c09aa788991622a1dac5792444c6a362c36 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 10:09:33 +0200 Subject: [PATCH 056/109] fix --- torchrl/objectives/value/advantages.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index a2182ebe700..12e2128ac1c 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1593,6 +1593,7 @@ def forward( # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) + terminated = tensordict.get(("next", self.tensor_keys.terminated)) adv, value_target = vtrace_advantage_estimate( gamma, @@ -1602,6 +1603,7 @@ def forward( next_value, reward, done, + terminated, rho_thresh=self.rho_thresh, c_thresh=self.c_thresh, time_dim=tensordict.ndim - 1, From 937b8197c3b7a43d16e98074c5ab8cc276af1146 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 12:02:09 +0200 Subject: [PATCH 057/109] fix --- examples/impala/test_vtrace_examples.py | 66 ------------------------- 1 file changed, 66 deletions(-) delete mode 100644 examples/impala/test_vtrace_examples.py diff --git a/examples/impala/test_vtrace_examples.py b/examples/impala/test_vtrace_examples.py deleted file mode 100644 index cd74a5860f1..00000000000 --- a/examples/impala/test_vtrace_examples.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -import torch.nn as nn -from tensordict import TensorDict -from tensordict.nn import TensorDictModule -from torchrl.modules.distributions import OneHotCategorical -from torchrl.modules.tensordict_module.actors import ProbabilisticActor -from torchrl.objectives.value.advantages import GAE, VTrace - -value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]) -actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) -actor_net = ProbabilisticActor( - module=actor_net, - in_keys=["logits"], - out_keys=["action"], - distribution_class=OneHotCategorical, - return_log_prob=True, -) -vtrace_module = VTrace( - gamma=0.98, - value_network=value_net, - actor_network=actor_net, - differentiable=False, -) -gae_module = GAE( - gamma=0.98, - lmbda=0.95, - value_network=value_net, - differentiable=False, -) - -obs, 
next_obs = torch.randn(2, 1, 10, 3) -reward = torch.randn(1, 10, 1) -done = torch.zeros(1, 10, 1, dtype=torch.bool) -terminated = torch.zeros(1, 10, 1, dtype=torch.bool) -sample_log_prob = torch.randn(1, 10, 1) -tensordict = TensorDict( - { - "obs": obs, - "done": done, - "terminated": terminated, - "sample_log_prob": sample_log_prob, - "next": { - "obs": next_obs, - "reward": reward, - "done": done, - "terminated": terminated, - }, - }, - batch_size=[1, 10], -) -advantage, value_target = gae_module( - obs=obs, - reward=reward, - done=done, - next_obs=next_obs, - terminated=terminated, - sample_log_prob=sample_log_prob, -) -advantage, value_target = vtrace_module( - obs=obs, - reward=reward, - done=done, - next_obs=next_obs, - terminated=terminated, - sample_log_prob=sample_log_prob, -) From 9e33035ca49bc9bcee49edf164ee9addec5a2d8f Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 4 Oct 2023 12:11:28 +0200 Subject: [PATCH 058/109] tests --- test/test_cost.py | 26 ++++++++++++++------------ torchrl/objectives/value/advantages.py | 12 ++++++++++-- torchrl/objectives/value/functional.py | 1 - 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6c38e6a8b65..d0ce5332ab4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -129,6 +129,7 @@ GAE, TD1Estimator, TDLambdaEstimator, + VTrace, ) from torchrl.objectives.value.functional import ( _transpose_time, @@ -139,6 +140,7 @@ vec_generalized_advantage_estimate, vec_td1_advantage_estimate, vec_td_lambda_advantage_estimate, + vtrace_advantage_estimate, ) from torchrl.objectives.value.utils import ( _custom_conv1d, @@ -436,7 +438,7 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -914,7 +916,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1399,7 +1401,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2008,7 +2010,7 @@ def test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2695,7 +2697,7 @@ def test_sac( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -3437,7 +3439,7 @@ def test_discrete_sac( loss_function="l2", **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4047,7 +4049,7 @@ def 
test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4414,7 +4416,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -4431,7 +4433,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_deprec.make_value_estimator(td_est) return @@ -4851,7 +4853,7 @@ def test_cql( **kwargs, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -7088,7 +7090,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return @@ -7828,7 +7830,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 12e2128ac1c..f75b07cd035 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1220,7 +1220,7 @@ def forward( >>> reward = torch.randn(1, 10, 1) >>> done = torch.zeros(1, 10, 1, dtype=torch.bool) >>> terminated = torch.zeros(1, 10, 1, dtype=torch.bool) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated) + >>> advantage, value_target = module(obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated) """ if tensordict.batch_dims < 1: @@ -1448,6 +1448,12 @@ def __init__( "Per-value gamma is not supported yet. Gamma must be a scalar." ) + @property + def in_keys(self): + parent_in_keys = super().in_keys + extended_in_keys = parent_in_keys + [self.tensor_keys.sample_log_prob] + return extended_in_keys + @_self_set_skip_existing @_self_set_grad_enabled @dispatch @@ -1538,7 +1544,9 @@ def forward( ... "sample_log_prob": sample_log_prob, ... "next": {"obs": next_obs, "reward": reward, "done": done, "terminated": terminated}, ... }, batch_size=[1, 10]) - >>> advantage, value_target = module(obs=obs, reward=reward, done=done, next_obs=next_obs, terminated=terminated, sample_log_prob=sample_log_prob) + >>> advantage, value_target = module( + ... obs=obs, next_reward=reward, next_done=done, next_obs=next_obs, next_terminated=terminated, sample_log_prob=sample_log_prob + ... 
) """ if tensordict.batch_dims < 1: diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index f164feaf74f..3492409c6cc 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1217,7 +1217,6 @@ def vec_td_lambda_advantage_estimate( # ----- -@torch.compile @_transpose_time def vtrace_advantage_estimate( gamma: float, From 199bc3b50ae3cc0fc662bca8c206d8541a2f4125 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 09:27:16 +0200 Subject: [PATCH 059/109] adapt ppo tests --- test/test_cost.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index d0ce5332ab4..81b84258833 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5147,6 +5147,7 @@ def _create_mock_actor( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) return actor.to(device) @@ -5184,6 +5185,7 @@ def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Sequential(base_layer, nn.Linear(5, 1)) value = ValueOperator( @@ -5210,6 +5212,7 @@ def _create_mock_actor_value_shared( distribution_class=TanhNormal, in_keys=["loc", "scale"], spec=action_spec, + return_log_prob=True, ) module = nn.Linear(5, 1) value_head = ValueOperator( @@ -5317,7 +5320,7 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): @@ -5330,6 +5333,10 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5396,7 +5403,7 @@ def test_ppo_state_dict(self, loss_class, device, gradient_mode): loss_fn2.load_state_dict(sd) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_shared(self, loss_class, device, advantage): torch.manual_seed(self.seed) @@ -5409,6 +5416,12 @@ def test_ppo_shared(self, loss_class, device, advantage): lmbda=0.9, value_network=value, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5470,6 +5483,7 @@ def test_ppo_shared(self, loss_class, device, advantage): "advantage", ( "gae", + "vtrace", "td", "td_lambda", ), @@ -5489,6 +5503,12 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): lmbda=0.9, value_network=value, ) + elif advantage == 
"vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, @@ -5540,7 +5560,7 @@ def test_ppo_shared_seq(self, loss_class, device, advantage, separate_losses): ) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -5554,6 +5574,10 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): advantage = GAE( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -5616,6 +5640,7 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -5657,7 +5682,7 @@ def test_ppo_tensordict_keys(self, loss_class, td_est): self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) @pytest.mark.parametrize("loss_class", (PPOLoss, ClipPPOLoss, KLPENPPOLoss)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): """Test PPO loss module with non-default tensordict keys.""" @@ -5685,6 +5710,13 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): value_network=value, differentiable=gradient_mode, ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, From f1b11dd541b57ce6069aa62a1c076dc1516ffad0 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 09:51:44 +0200 Subject: [PATCH 060/109] adapt ppo tests --- test/test_cost.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 81b84258833..1ea165f1751 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5133,6 +5133,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5148,6 +5149,7 @@ def _create_mock_actor( in_keys=["loc", "scale"], spec=action_spec, return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -5700,7 +5702,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], ) - actor = self._create_mock_actor() + actor = self._create_mock_actor(sample_log_prob_key=tensor_keys["sample_log_prob"]) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) if advantage == "gae": @@ -5810,7 +5812,7 @@ def test_ppo_notensordict( 
terminated_key=terminated_key, ) - actor = self._create_mock_actor(observation_key=observation_key) + actor = self._create_mock_actor(observation_key=observation_key, sample_log_prob_key=sample_log_prob_key) value = self._create_mock_value(observation_key=observation_key) loss = loss_class(actor=actor, critic=value) From 1d8d1efadf9306cc775d67ae04c20a69efa44600 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 10:02:41 +0200 Subject: [PATCH 061/109] adapt ppo tests --- test/test_cost.py | 18 ++++++++++++++---- torchrl/objectives/a2c.py | 14 ++++++++++++-- torchrl/objectives/ppo.py | 14 ++++++++++++-- torchrl/objectives/reinforce.py | 14 ++++++++++++-- 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 1ea165f1751..a5e61dc7972 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5337,7 +5337,10 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): ) elif advantage == "vtrace": advantage = VTrace( - gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, ) elif advantage == "td": advantage = TD1Estimator( @@ -5578,7 +5581,10 @@ def test_ppo_diff(self, loss_class, device, gradient_mode, advantage): ) elif advantage == "vtrace": advantage = VTrace( - gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, ) elif advantage == "td": advantage = TD1Estimator( @@ -5702,7 +5708,9 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): sample_log_prob_key=tensor_keys["sample_log_prob"], action_key=tensor_keys["action"], ) - actor = self._create_mock_actor(sample_log_prob_key=tensor_keys["sample_log_prob"]) + actor = self._create_mock_actor( + sample_log_prob_key=tensor_keys["sample_log_prob"] + ) value = self._create_mock_value(out_keys=[tensor_keys["value"]]) if advantage == "gae": @@ -5812,7 +5820,9 @@ def test_ppo_notensordict( terminated_key=terminated_key, ) - actor = self._create_mock_actor(observation_key=observation_key, sample_log_prob_key=sample_log_prob_key) + actor = self._create_mock_actor( + observation_key=observation_key, sample_log_prob_key=sample_log_prob_key + ) value = self._create_mock_value(observation_key=observation_key) loss = loss_class(actor=actor, critic=value) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index f9ef9d521f7..92955d4cab3 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,11 +3,17 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -397,8 +403,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) self._value_estimator = VTrace( - value_network=self.critic, actor_network=self.actor, **hp + value_network=self.critic, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index d54472764f1..285ab63ecbb 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -4,11 +4,17 @@ # LICENSE file in the root directory of this source tree. import math import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Tuple import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -470,8 +476,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) self._value_estimator = VTrace( - value_network=self.critic, actor_network=self.actor, **hp + value_network=self.critic, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 04389eed89d..1ae9c1e8252 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -3,12 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import warnings +from copy import deepcopy from dataclasses import dataclass from typing import Optional import torch -from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule +from tensordict.nn import ( + dispatch, + ProbabilisticTensorDictSequential, + repopulate_module, + TensorDictModule, +) from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule @@ -347,8 +353,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.TDLambda: self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: + # VTrace currently does not support functional call on the actor + actor_with_params = repopulate_module( + deepcopy(self.actor), self.actor_params + ) self._value_estimator = VTrace( - value_network=self.critic, actor_network=self.actor_network, **hp + value_network=self.critic, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") From ebf74b86f50c12df2bbcf958ca18a030623d442d Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 10:26:21 +0200 Subject: [PATCH 062/109] fix tests ppo --- torchrl/objectives/value/functional.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 3492409c6cc..9b80688441b 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1256,6 +1256,8 @@ def vtrace_advantage_estimate( raise RuntimeError(SHAPE_ERR) device = state_value.device + c_thresh = c_thresh.to(device) + rho_thresh = rho_thresh.to(device) not_done = (~done).int() not_terminated = (~terminated).int() @@ -1268,7 +1270,6 @@ def vtrace_advantage_estimate( deltas = clipped_rho * ( reward + terminated_discounts * next_state_value - state_value ) - c_thresh = c_thresh.to(device) clipped_c = rho.clamp_max(c_thresh) vs_minus_v_xs = [torch.zeros_like(next_state_value[..., -1, :])] From 1180993264cb11970386fb934fedf6cb0615fad9 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 11:41:01 +0200 Subject: [PATCH 063/109] fix tests a2c --- test/test_cost.py | 45 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index a5e61dc7972..9b40f4df56c 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5881,6 +5881,7 @@ def _create_mock_actor( action_dim=4, device="cpu", observation_key="observation", + sample_log_prob_key="sample_log_prob", ): # Actor action_spec = BoundedTensorSpec( @@ -5895,6 +5896,8 @@ def _create_mock_actor( in_keys=["loc", "scale"], spec=action_spec, distribution_class=TanhNormal, + return_log_prob=True, + log_prob_key=sample_log_prob_key, ) return actor.to(device) @@ -6027,7 +6030,7 @@ def _create_seq_mock_data_a2c( return td @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) def test_a2c(self, device, gradient_mode, advantage, td_est): @@ -6040,6 +6043,10 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): advantage = GAE( gamma=0.9, lmbda=0.9, 
value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode, + ) elif advantage == "td": advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode @@ -6164,7 +6171,7 @@ def test_a2c_separate_losses(self, separate_losses): not _has_functorch, reason=f"functorch not found, {FUNCTORCH_ERR}" ) @pytest.mark.parametrize("gradient_mode", (True, False)) - @pytest.mark.parametrize("advantage", ("gae", "td", "td_lambda", None)) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) def test_a2c_diff(self, device, gradient_mode, advantage): if pack_version.parse(torch.__version__) > pack_version.parse("1.14"): @@ -6182,6 +6189,13 @@ def test_a2c_diff(self, device, gradient_mode, advantage): advantage = TD1Estimator( gamma=0.9, value_network=value, differentiable=gradient_mode ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) elif advantage == "td_lambda": advantage = TDLambdaEstimator( gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode @@ -6231,6 +6245,7 @@ def test_a2c_diff(self, device, gradient_mode, advantage): ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.GAE, + ValueEstimators.VTrace, ValueEstimators.TDLambda, ], ) @@ -6248,6 +6263,7 @@ def test_a2c_tensordict_keys(self, td_est): "reward": "reward", "done": "done", "terminated": "terminated", + "sample_log_prob": "sample_log_prob", } self.tensordict_keys_test( @@ -6270,8 +6286,9 @@ def test_a2c_tensordict_keys(self, td_est): } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device): + def test_a2c_tensordict_keys_run(self, device, advantage): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -6280,6 +6297,7 @@ def test_a2c_tensordict_keys_run(self, device): value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" + sample_log_prob_key = "sample_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -6291,14 +6309,19 @@ def test_a2c_tensordict_keys_run(self, device): terminated_key=terminated_key, ) - actor = self._create_mock_actor(device=device) + actor = self._create_mock_actor(device=device, sample_log_prob_key=sample_log_prob_key) value = self._create_mock_value(device=device, out_keys=[value_key]) - advantage = GAE( - gamma=0.9, - lmbda=0.9, - value_network=value, - differentiable=gradient_mode, - ) + if advantage == "gae": + advantage = GAE( + gamma=0.9, lmbda=0.9, value_network=value, differentiable=gradient_mode + ) + elif advantage == "vtrace": + advantage = VTrace( + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, + ) advantage.set_keys( advantage=advantage_key, value_target=value_target_key, @@ -6306,6 +6329,7 @@ def test_a2c_tensordict_keys_run(self, device): reward=reward_key, done=done_key, terminated=terminated_key, + sample_log_prob=sample_log_prob_key, ) loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( @@ -6316,6 +6340,7 @@ def test_a2c_tensordict_keys_run(self, device): 
reward=reward_key, done=done_key, terminated=done_key, + sample_log_prob=sample_log_prob_key, ) advantage(td) From 6e73acd2ef4272617ac497a192208f54ecbd0c5b Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 11:53:41 +0200 Subject: [PATCH 064/109] fix tests a2c --- test/test_cost.py | 49 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 9b40f4df56c..11ca9834bf3 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6045,7 +6045,10 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): ) elif advantage == "vtrace": advantage = VTrace( - gamma=0.9, value_network=value, actor_network=actor, differentiable=gradient_mode, + gamma=0.9, + value_network=value, + actor_network=actor, + differentiable=gradient_mode, ) elif advantage == "td": advantage = TD1Estimator( @@ -6286,9 +6289,16 @@ def test_a2c_tensordict_keys(self, td_est): } self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) - @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.GAE, + ValueEstimators.VTrace, + ], + ) + @pytest.mark.parametrize("advantage", ("gae", "vtrace", None)) @pytest.mark.parametrize("device", get_default_devices()) - def test_a2c_tensordict_keys_run(self, device, advantage): + def test_a2c_tensordict_keys_run(self, device, advantage, td_est): """Test A2C loss module with non-default tensordict keys.""" torch.manual_seed(self.seed) gradient_mode = True @@ -6309,7 +6319,9 @@ def test_a2c_tensordict_keys_run(self, device, advantage): terminated_key=terminated_key, ) - actor = self._create_mock_actor(device=device, sample_log_prob_key=sample_log_prob_key) + actor = self._create_mock_actor( + device=device, sample_log_prob_key=sample_log_prob_key + ) value = self._create_mock_value(device=device, out_keys=[value_key]) if advantage == "gae": advantage = GAE( @@ -6322,15 +6334,11 @@ def test_a2c_tensordict_keys_run(self, device, advantage): actor_network=actor, differentiable=gradient_mode, ) - advantage.set_keys( - advantage=advantage_key, - value_target=value_target_key, - value=value_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, - sample_log_prob=sample_log_prob_key, - ) + elif advantage is None: + pass + else: + raise NotImplementedError + loss_fn = A2CLoss(actor, value, loss_critic_type="l2") loss_fn.set_keys( advantage=advantage_key, @@ -6343,7 +6351,20 @@ def test_a2c_tensordict_keys_run(self, device, advantage): sample_log_prob=sample_log_prob_key, ) - advantage(td) + if advantage is not None: + advantage.set_keys( + advantage=advantage_key, + value_target=value_target_key, + value=value_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=sample_log_prob_key, + ) + advantage(td) + else: + if td_est is not None: + loss_fn.make_value_estimator(td_est) loss = loss_fn(td) loss_critic = loss["loss_critic"] From 2ecb10302afee8882dd71ac9c0cf5c77a95d04f8 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 11:55:36 +0200 Subject: [PATCH 065/109] fix tests a2c --- test/test_cost.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 11ca9834bf3..0b61e9b4870 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5988,6 +5988,7 @@ def _create_seq_mock_data_a2c( reward_key="reward", done_key="done", terminated_key="terminated", + 
sample_log_prob_key="sample_log_prob", ): # create a tensordict total_obs = torch.randn(batch, T + 1, obs_dim, device=device) @@ -6017,7 +6018,7 @@ def _create_seq_mock_data_a2c( }, "collector": {"mask": mask}, action_key: action.masked_fill_(~mask.unsqueeze(-1), 0.0), - "sample_log_prob": torch.randn_like(action[..., 1]).masked_fill_( + sample_log_prob_key: torch.randn_like(action[..., 1]).masked_fill_( ~mask, 0.0 ) / 10, @@ -6317,6 +6318,7 @@ def test_a2c_tensordict_keys_run(self, device, advantage, td_est): reward_key=reward_key, done_key=done_key, terminated_key=terminated_key, + sample_log_prob_key=sample_log_prob_key, ) actor = self._create_mock_actor( From cd077192878c9d000a06e1023db350131d700bab Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 12:01:03 +0200 Subject: [PATCH 066/109] fix tests reinforce --- test/test_cost.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 0b61e9b4870..1c7f2dec173 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -6464,7 +6464,16 @@ class TestReinforce(LossModuleTestBase): @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) - @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize( + "td_est", + [ + ValueEstimators.TD1, + ValueEstimators.TD0, + ValueEstimators.GAE, + ValueEstimators.TDLambda, + None, + ], + ) def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): n_obs = 3 n_act = 5 From f1d2770beb87c08f32161f306d99ca00878a7ee1 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 15:49:33 +0200 Subject: [PATCH 067/109] fix tests values --- test/test_cost.py | 106 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/test/test_cost.py b/test/test_cost.py index 1c7f2dec173..4f2bcdc586f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -9292,6 +9292,112 @@ def test_gae_multidim( torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) + @pytest.mark.parametrize("T", [200, 5, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("has_done", [False, True]) + def test_vtrace(self, device, gamma, N, T, dtype, has_done): + torch.manual_seed(0) + + done = torch.zeros(*N, T, 1, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, 1, device=device, dtype=dtype) + state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, 1, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, 1, device=device, dtype=dtype)) + + _, value_target = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + ) + + assert not torch.isnan(value_target).any() + assert not torch.isinf(value_target).any() + + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) + @pytest.mark.parametrize("N", [(3,), (7, 3)]) + 
@pytest.mark.parametrize("T", [100, 3]) + @pytest.mark.parametrize("dtype", [torch.float, torch.double]) + @pytest.mark.parametrize("feature_dim", [[5], [2, 5]]) + @pytest.mark.parametrize("has_done", [True, False]) + def test_vtrace_multidim(self, device, gamma, N, T, dtype, has_done, feature_dim): + D = feature_dim + time_dim = -1 - len(D) + + torch.manual_seed(0) + + done = torch.zeros(*N, T, *D, device=device, dtype=torch.bool) + terminated = done.clone() + if has_done: + terminated = terminated.bernoulli_(0.1) + done = done.bernoulli_(0.1) | terminated + reward = torch.randn(*N, T, *D, device=device, dtype=dtype) + state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + next_state_value = torch.randn(*N, T, *D, device=device, dtype=dtype) + log_pi = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + log_mu = torch.log(torch.rand(*N, T, *D, device=device, dtype=dtype)) + + r1 = vtrace_advantage_estimate( + gamma, + log_pi, + log_mu, + state_value, + next_state_value, + reward, + done=done, + terminated=terminated, + time_dim=time_dim, + ) + if len(D) == 2: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1, j], + log_mu[..., i : i + 1, j], + state_value[..., i : i + 1, j], + next_state_value[..., i : i + 1, j], + reward[..., i : i + 1, j], + terminated=terminated[..., i : i + 1, j], + done=done[..., i : i + 1, j], + time_dim=-2, + ) + for i in range(D[0]) + for j in range(D[1]) + ] + else: + r2 = [ + vtrace_advantage_estimate( + gamma, + log_pi[..., i : i + 1], + log_mu[..., i : i + 1], + state_value[..., i : i + 1], + next_state_value[..., i : i + 1], + reward[..., i : i + 1], + done=done[..., i : i + 1], + terminated=terminated[..., i : i + 1], + time_dim=-2, + ) + for i in range(D[0]) + ] + + list2 = list(zip(*r2)) + r2 = [torch.cat(list2[0], -1), torch.cat(list2[1], -1)] + if len(D) == 2: + r2 = [r2[0].unflatten(-1, D), r2[1].unflatten(-1, D)] + torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.5, 0.99, 0.1]) @pytest.mark.parametrize("lmbda", [0.1, 0.5, 0.99]) From 676d8f51c9792b2265b4c9324cd601b8a7f68a9f Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 15:55:09 +0200 Subject: [PATCH 068/109] fix tests values --- test/test_cost.py | 1 + torchrl/objectives/value/functional.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 4f2bcdc586f..4c644b704b0 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -9292,6 +9292,7 @@ def test_gae_multidim( torch.testing.assert_close(r1, r3, rtol=1e-4, atol=1e-4) torch.testing.assert_close(r1, r2, rtol=1e-4, atol=1e-4) + @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("gamma", [0.99, 0.5, 0.1]) @pytest.mark.parametrize("N", [(1,), (3,), (7, 3)]) @pytest.mark.parametrize("T", [200, 5, 3]) diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 9b80688441b..99f91c5663b 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1227,8 +1227,8 @@ def vtrace_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, terminated: torch.Tensor | None = None, - rho_thresh: Union[float, torch.Tensor] = 1.0, - c_thresh: Union[float, torch.Tensor] = 1.0, + rho_thresh: torch.Tensor = torch.tensor(1.0), + c_thresh: torch.Tensor = torch.tensor(1.0), time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: 
"""Computes V-Trace off-policy actor critic targets. From f491e7d8245824ba40b7653a123d4a840d578847 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:05:07 +0200 Subject: [PATCH 069/109] fix tests adv --- test/test_cost.py | 42 ++++++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 4c644b704b0..5ef79a52276 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10333,6 +10333,7 @@ class TestAdv: [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}] ], ) def test_dispatch( @@ -10343,18 +10344,35 @@ def test_dispatch( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - **kwargs, - ) - kwargs = { - "obs": torch.randn(1, 10, 3), - "next_reward": torch.randn(1, 10, 1, requires_grad=True), - "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), - "next_obs": torch.randn(1, 10, 3), - } + if adv == VTrace: + actor_net = TensorDictModule(nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"]) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 4)), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + **kwargs, + ) + kwargs = { + "obs": torch.randn(1, 10, 3), + "next_reward": torch.randn(1, 10, 1, requires_grad=True), + "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_obs": torch.randn(1, 10, 3), + } advantage, value_target = module(**kwargs) assert advantage.shape == torch.Size([1, 10, 1]) assert value_target.shape == torch.Size([1, 10, 1]) From 7a63dd6f6519f913f0ba43bb9923f4b63e0d2df6 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:06:14 +0200 Subject: [PATCH 070/109] fix tests adv --- test/test_cost.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_cost.py b/test/test_cost.py index 5ef79a52276..6f52e3de40f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10353,6 +10353,13 @@ def test_dispatch( distribution_class=OneHotCategorical, return_log_prob=True, ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + **kwargs, + ) kwargs = { "obs": torch.randn(1, 10, 3), "sample_log_prob": torch.log(torch.rand(1, 10, 4)), From d30bb9d45975178019642ac358c269eff5aaa213 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:26:54 +0200 Subject: [PATCH 071/109] fix tests adv --- test/test_cost.py | 286 +++++++++++++++++++------ torchrl/objectives/value/functional.py | 12 +- 2 files changed, 228 insertions(+), 70 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6f52e3de40f..8ea5951d260 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10333,7 +10333,7 @@ class TestAdv: [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], - [VTrace, {}] + [VTrace, {}], ], ) def test_dispatch( @@ -10345,7 +10345,9 @@ def test_dispatch( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) if adv == VTrace: - actor_net = TensorDictModule(nn.Linear(3, 4), 
in_keys=["obs"], out_keys=["logits"]) + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) actor_net = ProbabilisticActor( module=actor_net, in_keys=["logits"], @@ -10362,9 +10364,10 @@ def test_dispatch( ) kwargs = { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 4)), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), "next_reward": torch.randn(1, 10, 1, requires_grad=True), "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), "next_obs": torch.randn(1, 10, 3), } else: @@ -10378,6 +10381,7 @@ def test_dispatch( "obs": torch.randn(1, 10, 3), "next_reward": torch.randn(1, 10, 1, requires_grad=True), "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), "next_obs": torch.randn(1, 10, 3), } advantage, value_target = module(**kwargs) @@ -10390,6 +10394,7 @@ def test_dispatch( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_diff_reward( @@ -10400,23 +10405,55 @@ def test_diff_reward( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "next": { + if adv == VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - ) + [1, 10], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + ) td = module(td.clone(False)) # check that the advantage can't backprop to the value params td["advantage"].sum().backward() @@ -10431,6 +10468,7 @@ def test_diff_reward( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("shifted", [True, False]) @@ -10438,25 +10476,60 @@ def test_non_differentiable(self, adv, shifted, kwargs): value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=False, - shifted=shifted, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "next": { + + if adv == VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + 
return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "next": { + "obs": torch.randn(1, 10, 3), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) assert td["advantage"].is_leaf @@ -10466,6 +10539,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) @pytest.mark.parametrize("has_value_net", [True, False]) @@ -10488,28 +10562,64 @@ def test_skip_existing( else: value_net = None - module = adv( - gamma=0.98, - value_network=value_net, - differentiable=True, - shifted=shifted, - skip_existing=skip_existing, - **kwargs, - ) - td = TensorDict( - { - "obs": torch.randn(1, 10, 3), - "state_value": torch.ones(1, 10, 1), - "next": { + if adv == VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + differentiable=False, + shifted=shifted, + **kwargs, + ) + td = TensorDict( + { "obs": torch.randn(1, 10, 3), + "sample_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), - "reward": torch.randn(1, 10, 1, requires_grad=True), - "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + "terminated": torch.zeros(1, 10, 1, dtype=torch.bool), + }, }, - }, - [1, 10], - names=[None, "time"], - ) + [1, 10], + names=[None, "time"], + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + differentiable=True, + shifted=shifted, + skip_existing=skip_existing, + **kwargs, + ) + td = TensorDict( + { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "next": { + "obs": torch.randn(1, 10, 3), + "state_value": torch.ones(1, 10, 1), + "reward": torch.randn(1, 10, 1, requires_grad=True), + "done": torch.zeros(1, 10, 1, dtype=torch.bool), + }, + }, + [1, 10], + names=[None, "time"], + ) td = module(td.clone(False)) if has_value_net and not skip_existing: exp_val = 0 @@ -10527,15 +10637,34 @@ def test_skip_existing( [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_keys(self, value, adv, kwargs): value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=[value]) - module = adv( 
- gamma=0.98, - value_network=value_net, - **kwargs, - ) + if adv == VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + **kwargs, + ) module.set_keys(value=value) assert module.tensor_keys.value == value @@ -10549,6 +10678,7 @@ def test_set_keys(self, value, adv, kwargs): [GAE, {"lmbda": 0.95}], [TD1Estimator, {}], [TDLambdaEstimator, {"lmbda": 0.95}], + [VTrace, {}], ], ) def test_set_deprecated_keys(self, adv, kwargs): @@ -10557,14 +10687,36 @@ def test_set_deprecated_keys(self, adv, kwargs): ) with pytest.warns(DeprecationWarning): - module = adv( - gamma=0.98, - value_network=value_net, - value_key="test_value", - advantage_key="advantage_test", - value_target_key="value_target_test", - **kwargs, - ) + + if adv == VTrace: + actor_net = TensorDictModule( + nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] + ) + actor_net = ProbabilisticActor( + module=actor_net, + in_keys=["logits"], + out_keys=["action"], + distribution_class=OneHotCategorical, + return_log_prob=True, + ) + module = adv( + gamma=0.98, + actor_network=actor_net, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) + else: + module = adv( + gamma=0.98, + value_network=value_net, + value_key="test_value", + advantage_key="advantage_test", + value_target_key="value_target_test", + **kwargs, + ) assert module.tensor_keys.value == "test_value" assert module.tensor_keys.advantage == "advantage_test" assert module.tensor_keys.value_target == "value_target_test" diff --git a/torchrl/objectives/value/functional.py b/torchrl/objectives/value/functional.py index 99f91c5663b..b7af99dd855 100644 --- a/torchrl/objectives/value/functional.py +++ b/torchrl/objectives/value/functional.py @@ -1227,8 +1227,8 @@ def vtrace_advantage_estimate( reward: torch.Tensor, done: torch.Tensor, terminated: torch.Tensor | None = None, - rho_thresh: torch.Tensor = torch.tensor(1.0), - c_thresh: torch.Tensor = torch.tensor(1.0), + rho_thresh: Union[float, torch.Tensor] = 1.0, + c_thresh: Union[float, torch.Tensor] = 1.0, time_dim: int = -2, ) -> Tuple[torch.Tensor, torch.Tensor]: """Computes V-Trace off-policy actor critic targets. 
@@ -1256,11 +1256,17 @@ def vtrace_advantage_estimate(
         raise RuntimeError(SHAPE_ERR)
 
     device = state_value.device
+
+    if not isinstance(rho_thresh, torch.Tensor):
+        rho_thresh = torch.tensor(rho_thresh, device=device)
+    if not isinstance(c_thresh, torch.Tensor):
+        c_thresh = torch.tensor(c_thresh, device=device)
+
     c_thresh = c_thresh.to(device)
     rho_thresh = rho_thresh.to(device)
 
     not_done = (~done).int()
-    not_terminated = (~terminated).int()
+    not_terminated = not_done if terminated is None else (~terminated).int()
     *batch_size, time_steps, lastdim = not_done.shape
     done_discounts = gamma * not_done
     terminated_discounts = gamma * not_terminated

From a9e1db3d442d9f72d68d89210789ca1a1080f4a4 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Thu, 5 Oct 2023 16:52:36 +0200
Subject: [PATCH 072/109] code examples

---
 examples/impala/README.md | 28 +++++++++
 examples/impala/config_multi_node.yaml | 35 -----------
 examples/impala/config_multi_node_ray.yaml | 62 +++++++++++++++++++
 ...multi_node.py => impala_multi_node_ray.py} | 47 +++++++-------
 examples/impala/utils.py | 42 +------------
 5 files changed, 117 insertions(+), 97 deletions(-)
 delete mode 100644 examples/impala/config_multi_node.yaml
 create mode 100644 examples/impala/config_multi_node_ray.yaml
 rename examples/impala/{impala_multi_node.py => impala_multi_node_ray.py} (86%)

diff --git a/examples/impala/README.md b/examples/impala/README.md
index e69de29bb2d..da18a0d952f 100644
--- a/examples/impala/README.md
+++ b/examples/impala/README.md
@@ -0,0 +1,28 @@
+## Reproducing Importance Weighted Actor-Learner Architecture (IMPALA) Algorithm Results
+
+This repository contains scripts that enable training agents using the IMPALA Algorithm on Atari environments. We follow the original paper [IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561) by Espeholt et al. 2018.
+
+## Examples Structure
+
+Please note that we provide 2 examples, one for single-node training and one for distributed (multi-node) training. Both examples rely on the same utils file, but are otherwise independent. Each example contains the following files:
+
+1. **Main Script:** The definition of the algorithm components and the training loop can be found in the main script (e.g. impala_single_node.py).
+
+2. **Utils File:** A utility file provides various helper functions, mainly to create the environment and the models (e.g. utils.py).
+
+3. **Configuration File:** This file includes the default hyperparameters specified in the original paper. In the multi-node case, it also includes the Ray cluster configuration. Users can modify these hyperparameters to customize their experiments, either by editing the file or via the command-line overrides sketched below (e.g. config_single_node.yaml).
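+
+As a minimal sketch of such a customization (the main scripts accept Hydra-style command-line overrides; the script name, keys and values below are only illustrative, and the available keys are the ones defined in the corresponding configuration file), hyperparameters can be changed without editing the yaml files:
+
+```bash
+# Override a few hyperparameters from the command line (illustrative values)
+python impala_single_node.py \
+  collector.frames_per_batch=4096 \
+  optim.lr=1.0e-4 \
+  logger.exp_name=my_impala_run
+```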
+
+
+## Running the Examples
+
+You can execute the single node IMPALA algorithm on Atari environments by running the following command:
+
+```bash
+python impala_single_node_ray.py
+```
+
+You can execute the multi-node IMPALA algorithm on Atari environments by running the following command:
+
+```bash
+python impala_multi_node_ray.py
+```
diff --git a/examples/impala/config_multi_node.yaml b/examples/impala/config_multi_node.yaml
deleted file mode 100644
index 86a11d6b40c..00000000000
--- a/examples/impala/config_multi_node.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-# Environment
-env:
-  env_name: PongNoFrameskip-v4
-
-# collector
-collector:
-  frames_per_batch: 80
-  total_frames: 200_000_000
-  num_workers: 12
-
-# logger
-logger:
-  backend: wandb
-  exp_name: Atari_Schulman17
-  test_interval: 200_000_000
-  num_test_episodes: 3
-
-# Optim
-optim:
-  lr: 0.0006
-  eps: 1e-8
-  weight_decay: 0.0
-  momentum: 0.0
-  alpha: 0.99
-  max_grad_norm: 40.0
-  anneal_lr: True
-
-# loss
-loss:
-  gamma: 0.99
-  batch_size: 32
-  sgd_updates: 1
-  critic_coef: 0.5
-  entropy_coef: 0.01
-  loss_critic_type: l2
diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml
new file mode 100644
index 00000000000..89e69c3336d
--- /dev/null
+++ b/examples/impala/config_multi_node_ray.yaml
@@ -0,0 +1,62 @@
+# Environment
+env:
+  env_name: PongNoFrameskip-v4
+
+# Ray Config - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html
+ray_init_config:
+  address: None
+  num_cpus: None
+  num_gpus: None
+  resources: None
+  object_store_memory: None
+  local_mode: False
+  ignore_reinit_error: False
+  include_dashboard: None
+  dashboard_host: "127.0.0.1"
+  dashboard_port: None
+  job_config: None
+  configure_logging: True
+  logging_level: "info"
+  logging_format: None
+  log_to_driver: True
+  namespace: None
+  runtime_env: None
+  storage: None
+
+# Resources assigned to each IMPALA rollout collection worker
+remote_worker_resources:
+  num_cpus: 1
+  num_gpus: 0.25
+  memory: 2 * 1024**3
+
+# collector
+collector:
+  frames_per_batch: 80
+  total_frames: 200_000_000
+  num_workers: 12
+
+# logger
+logger:
+  backend: wandb
+  exp_name: Atari_Schulman17
+  test_interval: 200_000_000
+  num_test_episodes: 3
+
+# Optim
+optim:
+  lr: 0.0006
+  eps: 1e-8
+  weight_decay: 0.0
+  momentum: 0.0
+  alpha: 0.99
+  max_grad_norm: 40.0
+  anneal_lr: True
+
+# loss
+loss:
+  gamma: 0.99
+  batch_size: 32
+  sgd_updates: 1
+  critic_coef: 0.5
+  entropy_coef: 0.01
+  loss_critic_type: l2
diff --git a/examples/impala/impala_multi_node.py b/examples/impala/impala_multi_node_ray.py
similarity index 86%
rename from examples/impala/impala_multi_node.py
rename to examples/impala/impala_multi_node_ray.py
index d843ecb6bfd..637a0ce7fb6 100644
--- a/examples/impala/impala_multi_node.py
+++ b/examples/impala/impala_multi_node_ray.py
@@ -7,7 +7,6 @@
 This script reproduces the IMPALA Algorithm results from Espeholt et al. 2018
 for the on Atari Environments.
""" - import hydra @@ -21,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RayCollector, RPCDataCollector + from torchrl.collectors.distributed import RayCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -57,10 +56,30 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = actor.to(device), critic.to(device) # Create collector + ray_init_config = { + "address": cfg.ray_init_config.address, + "num_cpus": cfg.ray_init_config.num_cpus, + "num_gpus": cfg.ray_init_config.num_gpus, + "resources": cfg.ray_init_config.resources, + "object_store_memory": cfg.ray_init_config.object_store_memory, + "local_mode": cfg.ray_init_config.local_mode, + "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, + "include_dashboard": cfg.ray_init_config.include_dashboard, + "dashboard_host": cfg.ray_init_config.dashboard_host, + "dashboard_port": cfg.ray_init_config.dashboard_port, + "job_config": cfg.ray_init_config.job_config, + "configure_logging": cfg.ray_init_config.configure_logging, + "logging_level": cfg.ray_init_config.logging_level, + "logging_format": cfg.ray_init_config.logging_format, + "log_to_driver": cfg.ray_init_config.log_to_driver, + "namespace": cfg.ray_init_config.namespace, + "runtime_env": cfg.ray_init_config.runtime_env, + "storage": cfg.ray_init_config.storage, + } remote_config = { - "num_cpus": 1, - "num_gpus": 1.0 / num_workers, - "memory": 2 * 1024**3, + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus, + "memory": cfg.remote_worker_resources.memory, } collector = RayCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, @@ -69,29 +88,13 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, + ray_init_config=ray_init_config, remote_configs=remote_config, sync=False, storing_device=device, update_after_each_batch=True, ) - # collector = RPCDataCollector( - # create_env_fn=[make_env(cfg.env.env_name, device)] * 1, - # policy=actor, - # collector_class=SyncDataCollector, - # frames_per_batch=frames_per_batch, - # total_frames=total_frames, - # max_frames_per_traj=-1, - # slurm_kwargs={ - # "timeout_min": 10, - # "slurm_partition": "3090", - # "slurm_cpus_per_task": 1, - # "slurm_gpus_per_node": 0, - # }, - # sync=False, - # update_after_each_batch=True, - # ) - # Create data buffer sampler = SamplerWithoutReplacement() data_buffer = TensorDictReplayBuffer( diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 39c32b70d7b..2983f8a0193 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -3,15 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import os - import torch.nn import torch.optim from tensordict.nn import TensorDictModule -from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec +from torchrl.data import CompositeSpec from torchrl.envs import ( CatFrames, DoubleToFloat, + EndOfLifeTransform, ExplorationType, GrayScale, GymEnv, @@ -21,7 +20,6 @@ RewardSum, StepCounter, ToTensorImage, - Transform, TransformedEnv, VecNorm, ) @@ -35,47 +33,11 @@ ) -# To pickle the environment, in particular the EndOfLifeTransform, we need to -# add the utils path to the PYTHONPATH -utils_path = os.path.abspath(os.path.abspath(os.path.dirname(__file__))) -current_pythonpath = os.environ.get("PYTHONPATH", "") -new_pythonpath = f"{utils_path}:{current_pythonpath}" -os.environ["PYTHONPATH"] = new_pythonpath - - # ==================================================================== # Environment utils # -------------------------------------------------------------------- -class EndOfLifeTransform(Transform): - def _step(self, tensordict, next_tensordict): - # lives = self.parent.base_env._env.unwrapped.ale.lives() - lives = 0 - end_of_life = torch.tensor( - [tensordict["lives"] < lives], device=self.parent.device - ) - end_of_life = end_of_life | next_tensordict.get("done") - next_tensordict.set("eol", end_of_life) - next_tensordict.set("lives", lives) - return next_tensordict - - def reset(self, tensordict): - lives = self.parent.base_env._env.unwrapped.ale.lives() - end_of_life = False - tensordict.set("eol", [end_of_life]) - tensordict.set("lives", lives) - return tensordict - - def transform_observation_spec(self, observation_spec): - full_done_spec = self.parent.output_spec["full_done_spec"] - observation_spec["eol"] = full_done_spec["done"].clone() - observation_spec["lives"] = UnboundedDiscreteTensorSpec( - self.parent.batch_size, device=self.parent.device - ) - return observation_spec - - def make_env(env_name, device, is_test=False): env = GymEnv( env_name, frame_skip=4, from_pixels=True, pixels_only=False, device=device From 32cd518e5f21527f64a1f1e42359707de8bda925 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:53:08 +0200 Subject: [PATCH 073/109] code examples --- examples/impala/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/impala/README.md b/examples/impala/README.md index da18a0d952f..b7bf7a1fe16 100644 --- a/examples/impala/README.md +++ b/examples/impala/README.md @@ -18,7 +18,7 @@ Please note that we provide 2 examples, one for single node training and one for You can execute the single node IMPALA algorithm on Atari environments by running the following command: ```bash -python impala_single_node_ray.py +python impala_single_node.py ``` You can execute the multi-node IMPALA algorithm on Atari environments by running the following command: From 53617bb0fd942671c992923951188be67b8f8aa8 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:56:12 +0200 Subject: [PATCH 074/109] fix tests adv --- test/test_cost.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 47a8563ff72..466bf956092 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10420,7 +10420,7 @@ def test_diff_reward( gamma=0.98, actor_network=actor_net, value_network=value_net, - differentiable=False, + differentiable=True, **kwargs, ) td = TensorDict( @@ -10577,7 +10577,7 @@ def test_skip_existing( gamma=0.98, actor_network=actor_net, value_network=value_net, - differentiable=False, + differentiable=True, 
shifted=shifted, **kwargs, ) From 0bc7b8c68069e4e87a11ee1a5513f1d75323a6b9 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 16:57:09 +0200 Subject: [PATCH 075/109] fix tests adv --- test/test_cost.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_cost.py b/test/test_cost.py index 466bf956092..883da7cf4d4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10579,6 +10579,7 @@ def test_skip_existing( value_network=value_net, differentiable=True, shifted=shifted, + skip_existing=skip_existing, **kwargs, ) td = TensorDict( From 40cc02f685b9263187fd02a170eb9971b53be28a Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:02:15 +0200 Subject: [PATCH 076/109] code examples tests --- .github/unittest/linux_examples/scripts/run_test.sh | 6 ++++++ examples/impala/impala_multi_node_ray.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index a6e09a51a43..4dfa2c81f10 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -48,6 +48,12 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/decision_trans # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py examples/impala/impala_single_node.py \ + collector.total_frames=80 \ + collector.frames_per_batch=20 \ + collector.num_workers=1 \ + logger.backend= \ + logger.test_interval=10 python .github/unittest/helpers/coverage_run_parallel.py examples/ppo/ppo_mujoco.py \ env.env_name=HalfCheetah-v4 \ collector.total_frames=40 \ diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 637a0ce7fb6..4945f6624c2 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -78,7 +78,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } remote_config = { "num_cpus": cfg.remote_worker_resources.num_cpus, - "num_gpus": cfg.remote_worker_resources.num_gpus, + "num_gpus": cfg.remote_worker_resources.num_gpus if torch.cuda.device_count() else 0, "memory": cfg.remote_worker_resources.memory, } collector = RayCollector( From cbd923e99f2780438689cb649c0de9075b9ad17c Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:05:03 +0200 Subject: [PATCH 077/109] code examples tests --- examples/impala/config_multi_node_ray.yaml | 2 +- examples/impala/impala_multi_node_ray.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index 89e69c3336d..c88966fb395 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -2,7 +2,7 @@ env: env_name: PongNoFrameskip-v4 -# Ray Config - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html +# Ray init kwargs - https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html ray_init_config: address: None num_cpus: None diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 4945f6624c2..dc3542e759d 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -78,7 +78,9 @@ def main(cfg: "DictConfig"): # noqa: F821 } remote_config = { "num_cpus": 
cfg.remote_worker_resources.num_cpus, - "num_gpus": cfg.remote_worker_resources.num_gpus if torch.cuda.device_count() else 0, + "num_gpus": cfg.remote_worker_resources.num_gpus + if torch.cuda.device_count() + else 0, "memory": cfg.remote_worker_resources.memory, } collector = RayCollector( From 2235c02146bed0c16eefeed0d0cd36249643ec58 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:13:19 +0200 Subject: [PATCH 078/109] code example with submitit --- .../impala/config_multi_node_submitit.yaml | 42 +++ examples/impala/impala_multi_node_ray.py | 44 +-- examples/impala/impala_multi_node_submitit.py | 275 ++++++++++++++++++ 3 files changed, 329 insertions(+), 32 deletions(-) create mode 100644 examples/impala/config_multi_node_submitit.yaml create mode 100644 examples/impala/impala_multi_node_submitit.py diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml new file mode 100644 index 00000000000..e156ede92f6 --- /dev/null +++ b/examples/impala/config_multi_node_submitit.yaml @@ -0,0 +1,42 @@ +# Environment +env: + env_name: PongNoFrameskip-v4 + +# SLURM config +slurm_config: + timeout_min: 10, + slurm_partition: "train", + slurm_cpus_per_task: 32, + slurm_gpus_per_node: 0, + +# collector +collector: + frames_per_batch: 80 + total_frames: 200_000_000 + num_workers: 12 + +# logger +logger: + backend: wandb + exp_name: Atari_Schulman17 + test_interval: 200_000_000 + num_test_episodes: 3 + +# Optim +optim: + lr: 0.0006 + eps: 1e-8 + weight_decay: 0.0 + momentum: 0.0 + alpha: 0.99 + max_grad_norm: 40.0 + anneal_lr: True + +# loss +loss: + gamma: 0.99 + batch_size: 32 + sgd_updates: 1 + critic_coef: 0.5 + entropy_coef: 0.01 + loss_critic_type: l2 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index dc3542e759d..67309940381 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -10,7 +10,7 @@ import hydra -@hydra.main(config_path=".", config_name="config_multi_node", version_base="1.1") +@hydra.main(config_path=".", config_name="config_multi_node_submitit", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RayCollector + from torchrl.collectors.distributed import RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -55,44 +55,24 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = actor.to(device), critic.to(device) - # Create collector - ray_init_config = { - "address": cfg.ray_init_config.address, - "num_cpus": cfg.ray_init_config.num_cpus, - "num_gpus": cfg.ray_init_config.num_gpus, - "resources": cfg.ray_init_config.resources, - "object_store_memory": cfg.ray_init_config.object_store_memory, - "local_mode": cfg.ray_init_config.local_mode, - "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, - "include_dashboard": cfg.ray_init_config.include_dashboard, - "dashboard_host": cfg.ray_init_config.dashboard_host, - "dashboard_port": cfg.ray_init_config.dashboard_port, - "job_config": cfg.ray_init_config.job_config, - "configure_logging": cfg.ray_init_config.configure_logging, - 
"logging_level": cfg.ray_init_config.logging_level, - "logging_format": cfg.ray_init_config.logging_format, - "log_to_driver": cfg.ray_init_config.log_to_driver, - "namespace": cfg.ray_init_config.namespace, - "runtime_env": cfg.ray_init_config.runtime_env, - "storage": cfg.ray_init_config.storage, - } - remote_config = { - "num_cpus": cfg.remote_worker_resources.num_cpus, - "num_gpus": cfg.remote_worker_resources.num_gpus - if torch.cuda.device_count() - else 0, - "memory": cfg.remote_worker_resources.memory, + slurm_kwargs = { + "timeout_min": cfg.slurm.config.timeout_min, + "slurm_partition": cfg.slurm.config.slurm_partition, + "slurm_cpus_per_task": cfg.slurm.config.slurm_cpus_per_task, + "slurm_gpus_per_node": cfg.slurm.config.slurm_gpus_per_node, } - collector = RayCollector( + + # Create collector + collector = RPCDataCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, - ray_init_config=ray_init_config, - remote_configs=remote_config, sync=False, + slurm_kwargs=slurm_kwargs, + launcher="submitit", storing_device=device, update_after_each_batch=True, ) diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py new file mode 100644 index 00000000000..f88ca114b94 --- /dev/null +++ b/examples/impala/impala_multi_node_submitit.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script reproduces the IMPALA Algorithm +results from Espeholt et al. 2018 for the on Atari Environments. 
+""" +import hydra + + +@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") +def main(cfg: "DictConfig"): # noqa: F821 + + import time + + import torch.optim + import tqdm + + from tensordict import TensorDict + from torchrl.collectors import SyncDataCollector + from torchrl.collectors.distributed import RayCollector + from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer + from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement + from torchrl.envs import ExplorationType, set_exploration_type + from torchrl.objectives import A2CLoss + from torchrl.objectives.value import VTrace + from torchrl.record.loggers import generate_exp_name, get_logger + from utils import eval_model, make_env, make_ppo_models + + device = "cpu" if not torch.cuda.device_count() else "cuda" + + # Correct for frame_skip + frame_skip = 4 + total_frames = cfg.collector.total_frames // frame_skip + frames_per_batch = cfg.collector.frames_per_batch // frame_skip + test_interval = cfg.logger.test_interval // frame_skip + + # Extract other config parameters + batch_size = cfg.loss.batch_size # Number of rollouts per batch + num_workers = ( + cfg.collector.num_workers + ) # Number of parallel workers collecting rollouts + lr = cfg.optim.lr + anneal_lr = cfg.optim.anneal_lr + sgd_updates = cfg.loss.sgd_updates + max_grad_norm = cfg.optim.max_grad_norm + num_test_episodes = cfg.logger.num_test_episodes + total_network_updates = ( + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates + + # Create models (check utils.py) + actor, critic = make_ppo_models(cfg.env.env_name) + actor, critic = actor.to(device), critic.to(device) + + # Create collector + ray_init_config = { + "address": cfg.ray_init_config.address, + "num_cpus": cfg.ray_init_config.num_cpus, + "num_gpus": cfg.ray_init_config.num_gpus, + "resources": cfg.ray_init_config.resources, + "object_store_memory": cfg.ray_init_config.object_store_memory, + "local_mode": cfg.ray_init_config.local_mode, + "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, + "include_dashboard": cfg.ray_init_config.include_dashboard, + "dashboard_host": cfg.ray_init_config.dashboard_host, + "dashboard_port": cfg.ray_init_config.dashboard_port, + "job_config": cfg.ray_init_config.job_config, + "configure_logging": cfg.ray_init_config.configure_logging, + "logging_level": cfg.ray_init_config.logging_level, + "logging_format": cfg.ray_init_config.logging_format, + "log_to_driver": cfg.ray_init_config.log_to_driver, + "namespace": cfg.ray_init_config.namespace, + "runtime_env": cfg.ray_init_config.runtime_env, + "storage": cfg.ray_init_config.storage, + } + remote_config = { + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus + if torch.cuda.device_count() + else 0, + "memory": cfg.remote_worker_resources.memory, + } + collector = RayCollector( + create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, + policy=actor, + collector_class=SyncDataCollector, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + max_frames_per_traj=-1, + ray_init_config=ray_init_config, + remote_configs=remote_config, + sync=False, + storing_device=device, + update_after_each_batch=True, + ) + + # Create data buffer + sampler = SamplerWithoutReplacement() + data_buffer = TensorDictReplayBuffer( + storage=LazyMemmapStorage(frames_per_batch * batch_size), + sampler=sampler, + batch_size=frames_per_batch * batch_size, + ) + + # Create loss and adv 
modules + adv_module = VTrace( + gamma=cfg.loss.gamma, + value_network=critic, + actor_network=actor, + average_adv=False, + ) + loss_module = A2CLoss( + actor=actor, + critic=critic, + loss_critic_type=cfg.loss.loss_critic_type, + entropy_coef=cfg.loss.entropy_coef, + critic_coef=cfg.loss.critic_coef, + ) + # loss_module.set_keys(done="eol", terminated="eol") + + # Create optimizer + optim = torch.optim.RMSprop( + loss_module.parameters(), + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.eps, + alpha=cfg.optim.alpha, + ) + + # Create logger + logger = None + if cfg.logger.backend: + exp_name = generate_exp_name( + "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" + ) + logger = get_logger( + cfg.logger.backend, logger_name="impala", experiment_name=exp_name + ) + + # Create test environment + test_env = make_env(cfg.env.env_name, device, is_test=True) + test_env.eval() + + # Main loop + collected_frames = 0 + num_network_updates = 0 + start_time = time.time() + pbar = tqdm.tqdm(total=total_frames) + accumulator = [] + sampling_start = time.time() + for i, data in enumerate(collector): + + log_info = {} + sampling_time = time.time() - sampling_start + frames_in_batch = data.numel() + collected_frames += frames_in_batch * frame_skip + pbar.update(data.numel()) + + # Get training rewards and episode lengths + episode_rewards = data["next", "episode_reward"][data["next", "done"]] + if len(episode_rewards) > 0: + episode_length = data["next", "step_count"][data["next", "done"]] + log_info.update( + { + "train/reward": episode_rewards.mean().item(), + "train/episode_length": episode_length.sum().item() + / len(episode_length), + } + ) + + if len(accumulator) < batch_size: + accumulator.append(data) + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + continue + + losses = TensorDict({}, batch_size=[sgd_updates]) + training_start = time.time() + for j in range(sgd_updates): + + # Create a single batch of trajectories + stacked_data = torch.stack(accumulator, dim=0) + stacked_data = stacked_data.to(device) + + # Compute advantage + stacked_data = adv_module(stacked_data) + + # Add to replay buffer + stacked_data_reshape = stacked_data.reshape(-1) + data_buffer.extend(stacked_data_reshape) + + for batch in data_buffer: + + # Linearly decrease the learning rate and clip epsilon + alpha = 1.0 + if anneal_lr: + alpha = 1 - (num_network_updates / total_network_updates) + for group in optim.param_groups: + group["lr"] = lr * alpha + num_network_updates += 1 + + # Get a data batch + batch = batch.to(device) + + # Forward pass loss + loss = loss_module(batch) + losses[j] = loss.select( + "loss_critic", "loss_entropy", "loss_objective" + ).detach() + loss_sum = ( + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + ) + + # Backward pass + loss_sum.backward() + torch.nn.utils.clip_grad_norm_( + list(loss_module.parameters()), max_norm=max_grad_norm + ) + + # Update the networks + optim.step() + optim.zero_grad() + + # Get training losses and times + training_time = time.time() - training_start + losses = losses.apply(lambda x: x.float().mean(), batch_size=[]) + for key, value in losses.items(): + log_info.update({f"train/{key}": value.item()}) + log_info.update( + { + "train/lr": alpha * lr, + "train/sampling_time": sampling_time, + "train/training_time": training_time, + } + ) + + # Get test rewards + with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + if ((i - 1) * frames_in_batch * 
frame_skip) // test_interval < ( + i * frames_in_batch * frame_skip + ) // test_interval: + actor.eval() + eval_start = time.time() + test_reward = eval_model( + actor, test_env, num_episodes=num_test_episodes + ) + eval_time = time.time() - eval_start + log_info.update( + { + "eval/reward": test_reward, + "eval/time": eval_time, + } + ) + actor.train() + + if logger: + for key, value in log_info.items(): + logger.log_scalar(key, value, collected_frames) + + collector.update_policy_weights_() + sampling_start = time.time() + accumulator = [] + + collector.shutdown() + end_time = time.time() + execution_time = end_time - start_time + print(f"Training took {execution_time:.2f} seconds to finish") + + +if __name__ == "__main__": + main() From 1a8efd1c6a87f7ec2ae564690a5b26920b1198d1 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:15:14 +0200 Subject: [PATCH 079/109] code example with submitit --- examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_multi_node_submitit.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 67309940381..26a83e66cef 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -10,7 +10,7 @@ import hydra -@hydra.main(config_path=".", config_name="config_multi_node_submitit", version_base="1.1") +@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index f88ca114b94..07b8981e044 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -10,7 +10,7 @@ import hydra -@hydra.main(config_path=".", config_name="config_multi_node_ray", version_base="1.1") +@hydra.main(config_path=".", config_name="config_multi_node_submitit", version_base="1.1") def main(cfg: "DictConfig"): # noqa: F821 import time From 3ef4001bea25e90888f4b62b351d626fc3dd2abc Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:16:47 +0200 Subject: [PATCH 080/109] code example with submitit --- examples/impala/impala_multi_node_ray.py | 42 ++++++++++++++----- examples/impala/impala_multi_node_submitit.py | 42 +++++-------------- 2 files changed, 42 insertions(+), 42 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 26a83e66cef..f88ca114b94 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RPCDataCollector + from torchrl.collectors.distributed import RayCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -55,24 +55,44 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = actor.to(device), critic.to(device) - slurm_kwargs = { - "timeout_min": cfg.slurm.config.timeout_min, - "slurm_partition": cfg.slurm.config.slurm_partition, - "slurm_cpus_per_task": cfg.slurm.config.slurm_cpus_per_task, - "slurm_gpus_per_node": 
cfg.slurm.config.slurm_gpus_per_node, - } - # Create collector - collector = RPCDataCollector( + ray_init_config = { + "address": cfg.ray_init_config.address, + "num_cpus": cfg.ray_init_config.num_cpus, + "num_gpus": cfg.ray_init_config.num_gpus, + "resources": cfg.ray_init_config.resources, + "object_store_memory": cfg.ray_init_config.object_store_memory, + "local_mode": cfg.ray_init_config.local_mode, + "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, + "include_dashboard": cfg.ray_init_config.include_dashboard, + "dashboard_host": cfg.ray_init_config.dashboard_host, + "dashboard_port": cfg.ray_init_config.dashboard_port, + "job_config": cfg.ray_init_config.job_config, + "configure_logging": cfg.ray_init_config.configure_logging, + "logging_level": cfg.ray_init_config.logging_level, + "logging_format": cfg.ray_init_config.logging_format, + "log_to_driver": cfg.ray_init_config.log_to_driver, + "namespace": cfg.ray_init_config.namespace, + "runtime_env": cfg.ray_init_config.runtime_env, + "storage": cfg.ray_init_config.storage, + } + remote_config = { + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus + if torch.cuda.device_count() + else 0, + "memory": cfg.remote_worker_resources.memory, + } + collector = RayCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, + ray_init_config=ray_init_config, + remote_configs=remote_config, sync=False, - slurm_kwargs=slurm_kwargs, - launcher="submitit", storing_device=device, update_after_each_batch=True, ) diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 07b8981e044..67309940381 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -20,7 +20,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RayCollector + from torchrl.collectors.distributed import RPCDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -55,44 +55,24 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = actor.to(device), critic.to(device) - # Create collector - ray_init_config = { - "address": cfg.ray_init_config.address, - "num_cpus": cfg.ray_init_config.num_cpus, - "num_gpus": cfg.ray_init_config.num_gpus, - "resources": cfg.ray_init_config.resources, - "object_store_memory": cfg.ray_init_config.object_store_memory, - "local_mode": cfg.ray_init_config.local_mode, - "ignore_reinit_error": cfg.ray_init_config.ignore_reinit_error, - "include_dashboard": cfg.ray_init_config.include_dashboard, - "dashboard_host": cfg.ray_init_config.dashboard_host, - "dashboard_port": cfg.ray_init_config.dashboard_port, - "job_config": cfg.ray_init_config.job_config, - "configure_logging": cfg.ray_init_config.configure_logging, - "logging_level": cfg.ray_init_config.logging_level, - "logging_format": cfg.ray_init_config.logging_format, - "log_to_driver": cfg.ray_init_config.log_to_driver, - "namespace": cfg.ray_init_config.namespace, - "runtime_env": cfg.ray_init_config.runtime_env, - "storage": 
cfg.ray_init_config.storage, - } - remote_config = { - "num_cpus": cfg.remote_worker_resources.num_cpus, - "num_gpus": cfg.remote_worker_resources.num_gpus - if torch.cuda.device_count() - else 0, - "memory": cfg.remote_worker_resources.memory, + slurm_kwargs = { + "timeout_min": cfg.slurm.config.timeout_min, + "slurm_partition": cfg.slurm.config.slurm_partition, + "slurm_cpus_per_task": cfg.slurm.config.slurm_cpus_per_task, + "slurm_gpus_per_node": cfg.slurm.config.slurm_gpus_per_node, } - collector = RayCollector( + + # Create collector + collector = RPCDataCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, collector_class=SyncDataCollector, frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, - ray_init_config=ray_init_config, - remote_configs=remote_config, sync=False, + slurm_kwargs=slurm_kwargs, + launcher="submitit", storing_device=device, update_after_each_batch=True, ) From fcc1121685ac5b88d30bee2fd4a1b896530d1002 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:22:16 +0200 Subject: [PATCH 081/109] code example with submitit --- examples/impala/config_multi_node_ray.yaml | 2 +- examples/impala/impala_multi_node_ray.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index c88966fb395..978aa50915b 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -27,7 +27,7 @@ ray_init_config: remote_worker_resources: num_cpus: 1 num_gpus: 0.25 - memory: 2 * 1024**3 + memory: 1073741824 # 1*1024**3 - 1GB # collector collector: diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index f88ca114b94..48320d39603 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -77,11 +77,11 @@ def main(cfg: "DictConfig"): # noqa: F821 "storage": cfg.ray_init_config.storage, } remote_config = { - "num_cpus": cfg.remote_worker_resources.num_cpus, - "num_gpus": cfg.remote_worker_resources.num_gpus + "num_cpus": float(cfg.remote_worker_resources.num_cpus), + "num_gpus": float(cfg.remote_worker_resources.num_gpus) if torch.cuda.device_count() else 0, - "memory": cfg.remote_worker_resources.memory, + "memory": float(cfg.remote_worker_resources.memory), } collector = RayCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, @@ -90,7 +90,7 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, - ray_init_config=ray_init_config, + # ray_init_config=ray_init_config, remote_configs=remote_config, sync=False, storing_device=device, From dd2a7f3668dc5d7b575bfdb52aeda258e6ef53f0 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:31:52 +0200 Subject: [PATCH 082/109] code example with submitit --- examples/impala/config_multi_node_ray.yaml | 28 +++++++++---------- .../impala/config_multi_node_submitit.yaml | 8 +++--- examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_multi_node_submitit.py | 4 ++- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index 978aa50915b..7117578ded1 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -4,24 +4,24 @@ env: # Ray init kwargs - 
https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html ray_init_config: - address: None - num_cpus: None - num_gpus: None - resources: None - object_store_memory: None + address: null + num_cpus: null + num_gpus: null + resources: null + object_store_memory: null local_mode: False ignore_reinit_error: False - include_dashboard: None - dashboard_host: "127.0.0.1" - dashboard_port: None - job_config: None + include_dashboard: null + dashboard_host: 127.0.0.1 + dashboard_port: null + job_config: null configure_logging: True - logging_level: "info" - logging_format: None + logging_level: info + logging_format: null log_to_driver: True - namespace: None - runtime_env: None - storage: None + namespace: null + runtime_env: null + storage: null # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index e156ede92f6..d53c2fdf43c 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -4,10 +4,10 @@ env: # SLURM config slurm_config: - timeout_min: 10, - slurm_partition: "train", - slurm_cpus_per_task: 32, - slurm_gpus_per_node: 0, + timeout_min: 10 + slurm_partition: train + slurm_cpus_per_task: 32 + slurm_gpus_per_node: 0 # collector collector: diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 48320d39603..a0a520025d7 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -90,7 +90,7 @@ def main(cfg: "DictConfig"): # noqa: F821 frames_per_batch=frames_per_batch, total_frames=total_frames, max_frames_per_traj=-1, - # ray_init_config=ray_init_config, + ray_init_config=ray_init_config, remote_configs=remote_config, sync=False, storing_device=device, diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 67309940381..1c09944d440 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -10,7 +10,9 @@ import hydra -@hydra.main(config_path=".", config_name="config_multi_node_submitit", version_base="1.1") +@hydra.main( + config_path=".", config_name="config_multi_node_submitit", version_base="1.1" +) def main(cfg: "DictConfig"): # noqa: F821 import time From 624b2d69ea0cb479ee1616e26e33c040acf6b494 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:33:54 +0200 Subject: [PATCH 083/109] code example with submitit --- examples/impala/impala_multi_node_ray.py | 1 - examples/impala/impala_multi_node_submitit.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index a0a520025d7..856893a165f 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -93,7 +93,6 @@ def main(cfg: "DictConfig"): # noqa: F821 ray_init_config=ray_init_config, remote_configs=remote_config, sync=False, - storing_device=device, update_after_each_batch=True, ) diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 1c09944d440..7b97b07bffb 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -75,7 +75,6 @@ def main(cfg: "DictConfig"): # noqa: F821 sync=False, slurm_kwargs=slurm_kwargs, launcher="submitit", - storing_device=device, 
update_after_each_batch=True, ) From 7e300694c94d8af908d784b3bd9e4842d8cb5dba Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 5 Oct 2023 17:40:20 +0200 Subject: [PATCH 084/109] code example with submitit --- examples/impala/config_multi_node_submitit.yaml | 6 +++--- examples/impala/impala_multi_node_submitit.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index d53c2fdf43c..8ad08292c7e 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -6,14 +6,14 @@ env: slurm_config: timeout_min: 10 slurm_partition: train - slurm_cpus_per_task: 32 - slurm_gpus_per_node: 0 + slurm_cpus_per_task: 1 + slurm_gpus_per_node: 1 # collector collector: frames_per_batch: 80 total_frames: 200_000_000 - num_workers: 12 + num_workers: 4 # logger logger: diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 7b97b07bffb..e72c8faff1d 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -58,10 +58,10 @@ def main(cfg: "DictConfig"): # noqa: F821 actor, critic = actor.to(device), critic.to(device) slurm_kwargs = { - "timeout_min": cfg.slurm.config.timeout_min, - "slurm_partition": cfg.slurm.config.slurm_partition, - "slurm_cpus_per_task": cfg.slurm.config.slurm_cpus_per_task, - "slurm_gpus_per_node": cfg.slurm.config.slurm_gpus_per_node, + "timeout_min": cfg.slurm_config.timeout_min, + "slurm_partition": cfg.slurm_config.slurm_partition, + "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task, + "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node, } # Create collector From 597623bc08365a38a63b4ab7704e40a87b6486ba Mon Sep 17 00:00:00 2001 From: albert bou Date: Tue, 14 Nov 2023 16:49:31 +0100 Subject: [PATCH 085/109] fix logging --- examples/impala/impala_single_node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index fea69b3fb14..d6523cae51a 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -129,9 +129,9 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(data.numel()) # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] + episode_length = data["next", "step_count"][data["next", "terminated"]] log_info.update( { "train/reward": episode_rewards.mean().item(), From 5c21c1e8395b597b565c183b04bcb08d04eeda2f Mon Sep 17 00:00:00 2001 From: albert bou Date: Sun, 19 Nov 2023 11:02:39 +0100 Subject: [PATCH 086/109] fix example --- examples/impala/impala_multi_node_ray.py | 5 +++-- examples/impala/impala_multi_node_submitit.py | 5 +++-- examples/impala/impala_single_node.py | 19 ++++++++++--------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 856893a165f..4d33760f4dc 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -189,8 +189,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stacked_data = adv_module(stacked_data) # Add to replay buffer - 
stacked_data_reshape = stacked_data.reshape(-1) - data_buffer.extend(stacked_data_reshape) + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) for batch in data_buffer: diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index e72c8faff1d..c5a6b80444d 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -171,8 +171,9 @@ def main(cfg: "DictConfig"): # noqa: F821 stacked_data = adv_module(stacked_data) # Add to replay buffer - stacked_data_reshape = stacked_data.reshape(-1) - data_buffer.extend(stacked_data_reshape) + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) for batch in data_buffer: diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index d6523cae51a..64b1f7704d5 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -47,8 +47,8 @@ def main(cfg: "DictConfig"): # noqa: F821 max_grad_norm = cfg.optim.max_grad_norm num_test_episodes = cfg.logger.num_test_episodes total_network_updates = ( - total_frames // (frames_per_batch * batch_size) - ) * cfg.loss.sgd_updates + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates # Create models (check utils_atari.py) actor, critic = make_ppo_models(cfg.env.env_name) @@ -88,7 +88,6 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) - loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.RMSprop( @@ -136,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -156,11 +155,13 @@ def main(cfg: "DictConfig"): # noqa: F821 stacked_data = stacked_data.to(device) # Compute advantage - stacked_data = adv_module(stacked_data) + with torch.no_grad(): + stacked_data = adv_module(stacked_data) # Add to replay buffer - stacked_data_reshape = stacked_data.reshape(-1) - data_buffer.extend(stacked_data_reshape) + for stacked_d in stacked_data: + stacked_data_reshape = stacked_d.reshape(-1) + data_buffer.extend(stacked_data_reshape) for batch in data_buffer: @@ -181,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) # Backward pass @@ -210,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip + i * frames_in_batch * frame_skip ) // test_interval: actor.eval() eval_start = time.time() From e47dbc32a920b5958d8afdd8d0f439cddd27b35e Mon Sep 17 00:00:00 2001 From: albert bou Date: Sun, 19 Nov 2023 11:50:22 +0100 Subject: [PATCH 087/109] fix example --- examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_multi_node_submitit.py | 2 +- examples/impala/impala_single_node.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 
4d33760f4dc..112c5158073 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -118,7 +118,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) - # loss_module.set_keys(done="eol", terminated="eol") + loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.RMSprop( diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index c5a6b80444d..c354dd22364 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -100,7 +100,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) - # loss_module.set_keys(done="eol", terminated="eol") + loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.RMSprop( diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 64b1f7704d5..516ca444492 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -88,6 +88,7 @@ def main(cfg: "DictConfig"): # noqa: F821 entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, ) + loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.RMSprop( From c23401ad1a7ec2962e4b782487dc5f4e9ad870b2 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:28:01 +0100 Subject: [PATCH 088/109] Update examples/impala/impala_multi_node_ray.py Co-authored-by: Vincent Moens --- examples/impala/impala_multi_node_ray.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 112c5158073..153ba14d9a4 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -146,10 +146,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=total_frames) accumulator = [] - sampling_start = time.time() + start_time = sampling_start = time.time() for i, data in enumerate(collector): log_info = {} From 886b4e0d8734aafddbc8230d5ca009b5d86041da Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:28:18 +0100 Subject: [PATCH 089/109] Update torchrl/objectives/value/advantages.py Co-authored-by: Vincent Moens --- torchrl/objectives/value/advantages.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index c8d2a1ab71a..404984c0f8f 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1361,6 +1361,7 @@ class VTrace(ValueEstimatorBase): value_network (TensorDictModule): value operator used to retrieve the value estimates. actor_network (TensorDictModule): actor operator used to retrieve the log prob. rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. + Defaults to ``1.0``. c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. average_adv (bool): if ``True``, the resulting advantage values will be standardized. Default is ``False``. 
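Note (illustrative, not part of the patch series): the sketch below shows how the `VTrace` estimator documented in the patches above is wired together with the actor/critic pair and the `A2CLoss`, mirroring the training scripts in this series. It is only a sketch under stated assumptions: it presumes the example's `make_ppo_models`/`make_env` helpers are importable (i.e. it runs from `examples/impala/`) and that the actor records `sample_log_prob` during collection, which V-trace needs to form the importance weights.

```python
import torch

from torchrl.objectives import A2CLoss
from torchrl.objectives.value import VTrace
from utils import make_env, make_ppo_models  # helpers from examples/impala/ (assumed importable)

device = "cuda" if torch.cuda.is_available() else "cpu"
actor, critic = make_ppo_models("PongNoFrameskip-v4")
actor, critic = actor.to(device), critic.to(device)

# V-trace corrects for the lag between the behaviour policy (the weights used
# to collect the data) and the current policy being optimized.
adv_module = VTrace(
    gamma=0.99,
    value_network=critic,
    actor_network=actor,
    average_adv=False,
)
loss_module = A2CLoss(actor=actor, critic=critic, critic_coef=0.5, entropy_coef=0.01)

# Collect a short trajectory; the actor is assumed to write "sample_log_prob" (mu) into the data.
env = make_env("PongNoFrameskip-v4", device)
data = env.rollout(max_steps=128, policy=actor)

# The estimator writes "advantage" and "value_target" entries consumed by the loss.
with torch.no_grad():
    data = adv_module(data)
loss_vals = loss_module(data)
loss = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
loss.backward()
```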
From 803fc4f812f8178dacec2a48c3f7792bc7271c8e Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:28:27 +0100 Subject: [PATCH 090/109] Update torchrl/objectives/value/advantages.py Co-authored-by: Vincent Moens --- torchrl/objectives/value/advantages.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 404984c0f8f..f278128fc6c 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1412,9 +1412,9 @@ def __init__( average_adv: bool = False, differentiable: bool = False, skip_existing: Optional[bool] = None, - advantage_key: NestedKey = None, - value_target_key: NestedKey = None, - value_key: NestedKey = None, + advantage_key: Optional[NestedKey] = None, + value_target_key: Optional[NestedKey] = None, + value_key: Optional[NestedKey] = None, shifted: bool = False, ): super().__init__( From 72a3c6edc90d85d0c1730910db3d0db74b45f6e4 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:29:00 +0100 Subject: [PATCH 091/109] Update torchrl/objectives/value/advantages.py Co-authored-by: Vincent Moens --- torchrl/objectives/value/advantages.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index f278128fc6c..8e641dfb8c1 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1363,6 +1363,7 @@ class VTrace(ValueEstimatorBase): rho_thresh (Union[float, Tensor]): rho clipping parameter for importance weights. Defaults to ``1.0``. c_thresh (Union[float, Tensor]): c clipping parameter for importance weights. + Defaults to ``1.0``. average_adv (bool): if ``True``, the resulting advantage values will be standardized. Default is ``False``. 
differentiable (bool, optional): if ``True``, gradients are propagated through From 4a061b5e8786e055d80cbfc5945e9ba66749ac4d Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:29:20 +0100 Subject: [PATCH 092/109] Update examples/impala/impala_multi_node_ray.py Co-authored-by: Vincent Moens --- examples/impala/impala_multi_node_ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 153ba14d9a4..07cd1f9c6f8 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -182,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create a single batch of trajectories stacked_data = torch.stack(accumulator, dim=0) - stacked_data = stacked_data.to(device) + stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage stacked_data = adv_module(stacked_data) From e7069e4ee16c7c5f4dee06aa8f2e5a41f6005cbc Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Wed, 22 Nov 2023 15:29:34 +0100 Subject: [PATCH 093/109] Update examples/impala/impala_multi_node_ray.py Co-authored-by: Vincent Moens --- examples/impala/impala_multi_node_ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 07cd1f9c6f8..9b95c96b39a 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -203,7 +203,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates += 1 # Get a data batch - batch = batch.to(device) + batch = batch.to(device, non_blocking=True) # Forward pass loss loss = loss_module(batch) From 5399cf1bdebd88e1eee4cdf99f4d4c4bd9630e7a Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 15:33:34 +0100 Subject: [PATCH 094/109] merge main --- test/test_cost.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 42ad536b5e0..6d53c136c5f 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -10751,7 +10751,7 @@ def test_dispatch( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - if adv == VTrace: + if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) @@ -10812,7 +10812,7 @@ def test_diff_reward( value_net = TensorDictModule( nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - if adv == VTrace: + if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) @@ -10884,7 +10884,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"] ) - if adv == VTrace: + if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) @@ -10969,7 +10969,7 @@ def test_skip_existing( else: value_net = None - if adv == VTrace: + if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) @@ -11050,7 +11050,7 @@ def test_skip_existing( ) def test_set_keys(self, value, adv, kwargs): value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=[value]) - if adv == VTrace: + if adv is VTrace: actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) @@ -11096,7 +11096,7 @@ def test_set_deprecated_keys(self, adv, kwargs): with pytest.warns(DeprecationWarning): - if adv == VTrace: + if adv is VTrace: 
actor_net = TensorDictModule( nn.Linear(3, 4), in_keys=["obs"], out_keys=["logits"] ) From 39584ab60ac3e7d5567273a973bba29823c917fc Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 15:37:27 +0100 Subject: [PATCH 095/109] fixes --- torchrl/objectives/value/advantages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 8e641dfb8c1..8f4c9c09ae8 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -821,7 +821,7 @@ def value_estimate( if self.average_rewards: reward = reward - reward.mean() - reward = reward / reward.std().clamp_min(1e-4) + reward = reward / reward.std().clamp_min(1e-5) tensordict.set( ("next", self.tensor_keys.reward), reward ) # we must update the rewards if they are used later in the code From c68fd40b2a81b68ccaf7a7a2b7cb0b552f530b42 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 15:48:49 +0100 Subject: [PATCH 096/109] fixes --- examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_single_node.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 9b95c96b39a..be7a2ea81ec 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -181,7 +181,7 @@ def main(cfg: "DictConfig"): # noqa: F821 for j in range(sgd_updates): # Create a single batch of trajectories - stacked_data = torch.stack(accumulator, dim=0) + stacked_data = torch.stack(accumulator, dim=0).contiguous() stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 516ca444492..47cf31cc6df 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -152,8 +152,8 @@ def main(cfg: "DictConfig"): # noqa: F821 for j in range(sgd_updates): # Create a single batch of trajectories - stacked_data = torch.stack(accumulator, dim=0) - stacked_data = stacked_data.to(device) + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage with torch.no_grad(): From 638c0d647433da8e4e9af03386047adacf2a2770 Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 16:01:50 +0100 Subject: [PATCH 097/109] format --- examples/impala/impala_multi_node_submitit.py | 7 +++---- examples/impala/impala_single_node.py | 13 ++++++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index c354dd22364..118913699f9 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -128,10 +128,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=total_frames) accumulator = [] - sampling_start = time.time() + start_time = sampling_start = time.time() for i, data in enumerate(collector): log_info = {} @@ -164,8 +163,8 @@ def main(cfg: "DictConfig"): # noqa: F821 for j in range(sgd_updates): # Create a single batch of trajectories - stacked_data = torch.stack(accumulator, dim=0) - stacked_data = stacked_data.to(device) + stacked_data = torch.stack(accumulator, dim=0).contiguous() + stacked_data = 
stacked_data.to(device, non_blocking=True) # Compute advantage stacked_data = adv_module(stacked_data) diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 47cf31cc6df..2cd1043f46f 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -47,8 +47,8 @@ def main(cfg: "DictConfig"): # noqa: F821 max_grad_norm = cfg.optim.max_grad_norm num_test_episodes = cfg.logger.num_test_episodes total_network_updates = ( - total_frames // (frames_per_batch * batch_size) - ) * cfg.loss.sgd_updates + total_frames // (frames_per_batch * batch_size) + ) * cfg.loss.sgd_updates # Create models (check utils_atari.py) actor, critic = make_ppo_models(cfg.env.env_name) @@ -116,10 +116,9 @@ def main(cfg: "DictConfig"): # noqa: F821 # Main loop collected_frames = 0 num_network_updates = 0 - start_time = time.time() pbar = tqdm.tqdm(total=total_frames) accumulator = [] - sampling_start = time.time() + start_time = sampling_start = time.time() for i, data in enumerate(collector): log_info = {} @@ -136,7 +135,7 @@ def main(cfg: "DictConfig"): # noqa: F821 { "train/reward": episode_rewards.mean().item(), "train/episode_length": episode_length.sum().item() - / len(episode_length), + / len(episode_length), } ) @@ -183,7 +182,7 @@ def main(cfg: "DictConfig"): # noqa: F821 "loss_critic", "loss_entropy", "loss_objective" ).detach() loss_sum = ( - loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] + loss["loss_critic"] + loss["loss_objective"] + loss["loss_entropy"] ) # Backward pass @@ -212,7 +211,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Get test rewards with torch.no_grad(), set_exploration_type(ExplorationType.MODE): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( - i * frames_in_batch * frame_skip + i * frames_in_batch * frame_skip ) // test_interval: actor.eval() eval_start = time.time() From 2f8b545926daf52f4ba33086be711e60e7dc006d Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 16:09:21 +0100 Subject: [PATCH 098/109] fixes --- examples/impala/config_multi_node_ray.yaml | 3 +++ examples/impala/config_multi_node_submitit.yaml | 3 +++ examples/impala/config_single_node.yaml | 3 +++ examples/impala/impala_multi_node_ray.py | 2 +- examples/impala/impala_multi_node_submitit.py | 2 +- examples/impala/impala_single_node.py | 2 +- 6 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index 7117578ded1..925a655e9c2 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -23,6 +23,9 @@ ray_init_config: runtime_env: null storage: null +# Device for the forward and backward passes +device: "cuda:0" + # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: num_cpus: 1 diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index 8ad08292c7e..f924e34fc27 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -2,6 +2,9 @@ env: env_name: PongNoFrameskip-v4 +# Device for the forward and backward passes +local_device: "cuda:0" + # SLURM config slurm_config: timeout_min: 10 diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index 86a11d6b40c..de6fc718552 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -2,6 
+2,9 @@ env: env_name: PongNoFrameskip-v4 +# Device for the forward and backward passes +local_device: "cuda:0" + # collector collector: frames_per_batch: 80 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index be7a2ea81ec..592bd839821 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -29,7 +29,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.local_device) # Correct for frame_skip frame_skip = 4 diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 118913699f9..8d80e200030 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -31,7 +31,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.local_device) # Correct for frame_skip frame_skip = 4 diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 2cd1043f46f..8d587064f26 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -28,7 +28,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from torchrl.record.loggers import generate_exp_name, get_logger from utils import eval_model, make_env, make_ppo_models - device = "cpu" if not torch.cuda.device_count() else "cuda" + device = torch.device(cfg.device) # Correct for frame_skip frame_skip = 4 From 6ddcb3ab86a07abaebcb9154451b46b12ce409fa Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 16:30:35 +0100 Subject: [PATCH 099/109] fixes --- examples/impala/impala_multi_node_ray.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 592bd839821..3af972669b0 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -77,11 +77,11 @@ def main(cfg: "DictConfig"): # noqa: F821 "storage": cfg.ray_init_config.storage, } remote_config = { - "num_cpus": float(cfg.remote_worker_resources.num_cpus), - "num_gpus": float(cfg.remote_worker_resources.num_gpus) + "num_cpus": int(cfg.remote_worker_resources.num_cpus), + "num_gpus": int(cfg.remote_worker_resources.num_gpus) if torch.cuda.device_count() else 0, - "memory": float(cfg.remote_worker_resources.memory), + "memory": int(cfg.remote_worker_resources.memory), } collector = RayCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, From 94306bf129067be9163892620aa776c892eaf77e Mon Sep 17 00:00:00 2001 From: albert bou Date: Wed, 22 Nov 2023 16:32:21 +0100 Subject: [PATCH 100/109] fixes --- examples/impala/config_multi_node_ray.yaml | 2 +- examples/impala/config_single_node.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index 925a655e9c2..a101c8ab262 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -24,7 +24,7 @@ ray_init_config: storage: null # Device for the forward and backward passes -device: 
"cuda:0" +local_device: "cuda:0" # Resources assigned to each IMPALA rollout collection worker remote_worker_resources: diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index de6fc718552..51801dfacb7 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -3,7 +3,7 @@ env: env_name: PongNoFrameskip-v4 # Device for the forward and backward passes -local_device: "cuda:0" +device: "cuda:0" # collector collector: From 9cc0284b783d782d40dcff79f807910699458054 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 09:22:11 +0100 Subject: [PATCH 101/109] fixes --- examples/impala/impala_multi_node_ray.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 3af972669b0..b969a9d85c8 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -77,11 +77,11 @@ def main(cfg: "DictConfig"): # noqa: F821 "storage": cfg.ray_init_config.storage, } remote_config = { - "num_cpus": int(cfg.remote_worker_resources.num_cpus), - "num_gpus": int(cfg.remote_worker_resources.num_gpus) + "num_cpus": cfg.remote_worker_resources.num_cpus, + "num_gpus": cfg.remote_worker_resources.num_gpus if torch.cuda.device_count() else 0, - "memory": int(cfg.remote_worker_resources.memory), + "memory": cfg.remote_worker_resources.memory, } collector = RayCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, From 63392f0692431634ef1114aa0d8d0b3abc3fed85 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 09:29:04 +0100 Subject: [PATCH 102/109] fixes --- torchrl/objectives/value/advantages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 8f4c9c09ae8..d23d5b3e978 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1354,7 +1354,7 @@ class VTrace(ValueEstimatorBase): """A class wrapper around V-Trace estimate functional. Refer to "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures" - https://arxiv.org/abs/1802.01561 for more context. + :ref:`here `_ for more context. Args: gamma (scalar): exponential mean discount. 
From e61f3425ee5f5e5e4ebac1dae4243261c6e419ad Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 10:09:08 +0100 Subject: [PATCH 103/109] fixes --- examples/impala/config_multi_node_ray.yaml | 2 +- examples/impala/config_multi_node_submitit.yaml | 2 +- examples/impala/config_single_node.yaml | 2 +- examples/impala/impala_multi_node_ray.py | 8 ++++++-- examples/impala/impala_multi_node_submitit.py | 8 ++++++-- examples/impala/impala_single_node.py | 5 ++++- 6 files changed, 19 insertions(+), 8 deletions(-) diff --git a/examples/impala/config_multi_node_ray.yaml b/examples/impala/config_multi_node_ray.yaml index a101c8ab262..e312b336651 100644 --- a/examples/impala/config_multi_node_ray.yaml +++ b/examples/impala/config_multi_node_ray.yaml @@ -41,7 +41,7 @@ collector: # logger logger: backend: wandb - exp_name: Atari_Schulman17 + exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index f924e34fc27..cb07c6e8bf2 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -21,7 +21,7 @@ collector: # logger logger: backend: wandb - exp_name: Atari_Schulman17 + exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/config_single_node.yaml b/examples/impala/config_single_node.yaml index 51801dfacb7..d39407c1a69 100644 --- a/examples/impala/config_single_node.yaml +++ b/examples/impala/config_single_node.yaml @@ -14,7 +14,7 @@ collector: # logger logger: backend: wandb - exp_name: Atari_Schulman17 + exp_name: Atari_IMPALA test_interval: 200_000_000 num_test_episodes: 3 diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index b969a9d85c8..6d08f8bd277 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -136,7 +136,10 @@ def main(cfg: "DictConfig"): # noqa: F821 "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" ) logger = get_logger( - cfg.logger.backend, logger_name="impala", experiment_name=exp_name + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", ) # Create test environment @@ -185,7 +188,8 @@ def main(cfg: "DictConfig"): # noqa: F821 stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage - stacked_data = adv_module(stacked_data) + with torch.no_grad(): + stacked_data = adv_module(stacked_data) # Add to replay buffer for stacked_d in stacked_data: diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 8d80e200030..cb1a811891c 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -118,7 +118,10 @@ def main(cfg: "DictConfig"): # noqa: F821 "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" ) logger = get_logger( - cfg.logger.backend, logger_name="impala", experiment_name=exp_name + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", ) # Create test environment @@ -167,7 +170,8 @@ def main(cfg: "DictConfig"): # noqa: F821 stacked_data = stacked_data.to(device, non_blocking=True) # Compute advantage - stacked_data = adv_module(stacked_data) + with torch.no_grad(): + stacked_data = adv_module(stacked_data) # Add to replay buffer for stacked_d in stacked_data: diff --git a/examples/impala/impala_single_node.py 
b/examples/impala/impala_single_node.py index 8d587064f26..334600653e7 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -106,7 +106,10 @@ def main(cfg: "DictConfig"): # noqa: F821 "IMPALA", f"{cfg.logger.exp_name}_{cfg.env.env_name}" ) logger = get_logger( - cfg.logger.backend, logger_name="impala", experiment_name=exp_name + cfg.logger.backend, + logger_name="impala", + experiment_name=exp_name, + project="impala", ) # Create test environment From 6d384d51ecd6fb9293fcff506122651b3d8954d8 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 10:25:42 +0100 Subject: [PATCH 104/109] fixes --- examples/impala/impala_multi_node_ray.py | 4 ++-- examples/impala/impala_single_node.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 6d08f8bd277..a0d2d88c5a2 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -161,9 +161,9 @@ def main(cfg: "DictConfig"): # noqa: F821 pbar.update(data.numel()) # Get training rewards and episode lengths - episode_rewards = data["next", "episode_reward"][data["next", "done"]] + episode_rewards = data["next", "episode_reward"][data["next", "terminated"]] if len(episode_rewards) > 0: - episode_length = data["next", "step_count"][data["next", "done"]] + episode_length = data["next", "step_count"][data["next", "terminated"]] log_info.update( { "train/reward": episode_rewards.mean().item(), diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 334600653e7..cd270f4c9e9 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -50,7 +50,7 @@ def main(cfg: "DictConfig"): # noqa: F821 total_frames // (frames_per_batch * batch_size) ) * cfg.loss.sgd_updates - # Create models (check utils_atari.py) + # Create models (check utils.py) actor, critic = make_ppo_models(cfg.env.env_name) actor, critic = actor.to(device), critic.to(device) @@ -177,7 +177,7 @@ def main(cfg: "DictConfig"): # noqa: F821 num_network_updates += 1 # Get a data batch - batch = batch.to(device) + batch = batch.to(device, non_blocking=True) # Forward pass loss loss = loss_module(batch) From 89770e452f6ed66fcaeb9e07ad2779ba507a61e1 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 12:30:16 +0100 Subject: [PATCH 105/109] submitit example --- .../impala/config_multi_node_submitit.yaml | 1 + examples/impala/impala_multi_node_submitit.py | 27 ++++++++++++++----- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index cb07c6e8bf2..24fa8c6a762 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -14,6 +14,7 @@ slurm_config: # collector collector: + backend: gloo frames_per_batch: 80 total_frames: 200_000_000 num_workers: 4 diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index cb1a811891c..864269ecf29 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -22,7 +22,7 @@ def main(cfg: "DictConfig"): # noqa: F821 from tensordict import TensorDict from torchrl.collectors import SyncDataCollector - from torchrl.collectors.distributed import RPCDataCollector + from torchrl.collectors.distributed import 
DistributedDataCollector from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement from torchrl.envs import ExplorationType, set_exploration_type @@ -63,20 +63,33 @@ def main(cfg: "DictConfig"): # noqa: F821 "slurm_cpus_per_task": cfg.slurm_config.slurm_cpus_per_task, "slurm_gpus_per_node": cfg.slurm_config.slurm_gpus_per_node, } - # Create collector - collector = RPCDataCollector( + device_str = "device" if num_workers <= 1 else "devices" + if cfg.collector.backend == "nccl": + collector_kwargs = {device_str: "cuda:0", f"storing_{device_str}": "cuda:0"} + elif cfg.collector.backend == "gloo": + collector_kwargs = {device_str: "cpu", f"storing_{device_str}": "cpu"} + else: + raise NotImplementedError( + f"device assignment not implemented for backend {cfg.collector.backend}" + ) + import ipdb; ipdb.set_trace() + collector = DistributedDataCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, - collector_class=SyncDataCollector, + num_workers_per_collector=1, frames_per_batch=frames_per_batch, total_frames=total_frames, - max_frames_per_traj=-1, - sync=False, + collector_class=SyncDataCollector, + collector_kwargs=collector_kwargs, slurm_kwargs=slurm_kwargs, + storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu", launcher="submitit", update_after_each_batch=True, - ) + backend=cfg.collector.backend, + ) + + import ipdb; ipdb.set_trace() # Create data buffer sampler = SamplerWithoutReplacement() From 9132a60603aa639233044e2dc7221d9303476f66 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 12:45:26 +0100 Subject: [PATCH 106/109] submitit example --- examples/impala/config_multi_node_submitit.yaml | 2 +- examples/impala/impala_multi_node_submitit.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/impala/config_multi_node_submitit.yaml b/examples/impala/config_multi_node_submitit.yaml index 24fa8c6a762..f632ba15dc2 100644 --- a/examples/impala/config_multi_node_submitit.yaml +++ b/examples/impala/config_multi_node_submitit.yaml @@ -17,7 +17,7 @@ collector: backend: gloo frames_per_batch: 80 total_frames: 200_000_000 - num_workers: 4 + num_workers: 1 # logger logger: diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 864269ecf29..3355febbfaf 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -73,7 +73,6 @@ def main(cfg: "DictConfig"): # noqa: F821 raise NotImplementedError( f"device assignment not implemented for backend {cfg.collector.backend}" ) - import ipdb; ipdb.set_trace() collector = DistributedDataCollector( create_env_fn=[make_env(cfg.env.env_name, device)] * num_workers, policy=actor, @@ -85,11 +84,9 @@ def main(cfg: "DictConfig"): # noqa: F821 slurm_kwargs=slurm_kwargs, storing_device="cuda:0" if cfg.collector.backend == "nccl" else "cpu", launcher="submitit", - update_after_each_batch=True, + # update_after_each_batch=True, backend=cfg.collector.backend, - ) - - import ipdb; ipdb.set_trace() + ) # Create data buffer sampler = SamplerWithoutReplacement() From 89a803b75e4f5b31068348b1622a0fb8878d2c97 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 12:52:08 +0100 Subject: [PATCH 107/109] README --- examples/impala/README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/impala/README.md b/examples/impala/README.md index b7bf7a1fe16..00e0d010b82 
100644 --- a/examples/impala/README.md +++ b/examples/impala/README.md @@ -26,3 +26,8 @@ You can execute the multi-node IMPALA algorithm on Atari environments by running ```bash python impala_single_node_ray.py ``` +or + +```bash +python impala_single_node_submitit.py +``` From bd02b30626abe7c662e5cc38ecd0c441c8e279a6 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 14:31:05 +0100 Subject: [PATCH 108/109] fix tests --- test/test_cost.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_cost.py b/test/test_cost.py index 6d53c136c5f..35297c3a1e6 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5307,7 +5307,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est is ValueEstimators.GAE: + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return From 0a382bb1f543881f4b041b1d957705bcee775045 Mon Sep 17 00:00:00 2001 From: albert bou Date: Thu, 23 Nov 2023 21:29:10 +0100 Subject: [PATCH 109/109] fix unused_args --- torchrl/objectives/value/advantages.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index d23d5b3e978..42ba404c05d 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -1159,7 +1159,7 @@ def __init__( def forward( self, tensordict: TensorDictBase, - *unused_args, + *, params: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, ) -> TensorDictBase: @@ -1462,7 +1462,7 @@ def in_keys(self): def forward( self, tensordict: TensorDictBase, - *unused_args, + *, params: Optional[List[Tensor]] = None, target_params: Optional[List[Tensor]] = None, ) -> TensorDictBase:
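The final patch turns the leftover `*unused_args` in the estimators' `forward` into a bare `*`, making `params` and `target_params` keyword-only. A short sketch of the resulting calling convention; `adv_module`, `data`, `params` and `target_params` are assumed to exist (e.g. from functional usage of the value estimators) and are illustrative only.

```python
# After this change, functional parameters must be passed by keyword:
data = adv_module(data, params=params, target_params=target_params)

# A positional call such as adv_module(data, params) used to be swallowed by
# *unused_args, so the parameters were silently dropped; it now raises a TypeError,
# which is the behaviour the "fix unused_args" patch is after.
```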