Commit
* Adding initial PPO Code
* Added buffer sampling and solved some bugs
* ppo agent: device and type errors fixed
* ppo updater fixed
* ppo config updated
* Updated Hive to use gym spaces instead of raw tuples to represent action and observation spaces
* Updated tests to reflect api change of 1106ec2
* Adding initial PPO Code
* Added buffer sampling and solved some bugs
* ppo agent: device and type errors fixed
* ppo updater fixed
* ppo config updated
* ppo replay added
* ppo replay fixed
* ppo agent updated
* ppo agent and config updated
* ppo code running but buggy
* cartpole working
* ppo configs
* ppo net fixed
* atari configs added
* ppo_nets done
* ppo_replay done
* ppo env wrappers added
* ppo agent done
* configs done
* stack size > 1 handled temporarily
* linting fixed
* last batch drop fix
* config changes
* shared network added
* reward wrapper added
* linting fixed
* docs fixed
* replay changed
* update loop
* type specification
* env wrappers registered
* linting fixed
* Removed one off transition, cleaned up replay buffer
* Fixed linter issues
* wrapper error fixed
* added vars to dict; fixed long lines and var names; moved wrapper registry
* config fixed
* added normalisation and fixed log
* norm file added
* norm bug fixed
* rew norm updated
* fixes
* fixing norm bug; config
* config fixes
* obs norm
* hardcoded wrappers added
* normaliser shape fixed
* rew shape fixed; norm structure updated
* rew norm
* configs and wrapper fixed
* Fixed formatting and naming
* Added env wrapper logic
* Renamed PPO Replay Buffer to On Policy Replay buffer
* Made PPO Stateless Agent
* Fixed linting issues
* Minor modifications
* Fixed changed
* Formatting and minor changes
* Refactored Advantage Computation
* Reformatting with black
* Renaming
* Refactored Normalization code
* Added saving and loading of state dict for normalizers
* Fixed multiplayer replay buffer for PPO
* Fixed minor bug
* Renamed file
* Added lr annealing

---------

Co-authored-by: Sriyash <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: Darshan Patil <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: Sriyash Poddar <[email protected]>
1 parent 805ac4a · commit 06c485c
Showing 14 changed files with 1,258 additions and 3 deletions.
Large diffs are not rendered by default.
@@ -0,0 +1,119 @@
from typing import Tuple, Union

import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box, Discrete

from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import calculate_output_dim


class CategoricalHead(torch.nn.Module):
    """A module that implements a discrete actor head. It uses the output from
    the :obj:`actor_net` and creates a
    :py:class:`~torch.distributions.categorical.Categorical` object to compute
    the action distribution."""

    def __init__(
        self, feature_dim: Tuple[int], action_space: gym.spaces.Discrete
    ) -> None:
        """
        Args:
            feature_dim: Expected output shape of the actor network.
            action_space: The discrete action space of the environment.
        """
        super().__init__()
        self.network = torch.nn.Linear(feature_dim, action_space.n)
        self.distribution = torch.distributions.categorical.Categorical

    def forward(self, x):
        logits = self.network(x)
        return self.distribution(logits=logits)


class GaussianPolicyHead(torch.nn.Module):
    """A module that implements a continuous actor head. It uses the output from the
    :obj:`actor_net` and a state-independent learnable parameter :obj:`policy_logstd`
    to create a :py:class:`~torch.distributions.normal.Normal` object to compute
    the action distribution."""

    def __init__(self, feature_dim: Tuple[int], action_space: gym.spaces.Box) -> None:
        """
        Args:
            feature_dim: Expected output shape of the actor network.
            action_space: The continuous action space of the environment.
        """
        super().__init__()
        self._action_shape = action_space.shape
        self.policy_mean = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, np.prod(self._action_shape))
        )
        self.policy_logstd = torch.nn.Parameter(
            torch.zeros(1, np.prod(action_space.shape))
        )
        self.distribution = torch.distributions.normal.Normal

    def forward(self, x):
        _mean = self.policy_mean(x)
        _std = self.policy_logstd.repeat(x.shape[0], 1).exp()
        distribution = self.distribution(
            torch.reshape(_mean, (x.size(0), *self._action_shape)),
            torch.reshape(_std, (x.size(0), *self._action_shape)),
        )
        return distribution


class ActorCriticNetwork(torch.nn.Module):
    """A module that implements the actor and critic computation. It puts together
    the :obj:`representation_network`, :obj:`actor_net` and :obj:`critic_net`, then
    adds two final :py:class:`~torch.nn.Linear` layers to compute the action and state
    value."""

    def __init__(
        self,
        representation_network: torch.nn.Module,
        actor_net: FunctionApproximator,
        critic_net: FunctionApproximator,
        network_output_dim: Union[int, Tuple[int]],
        action_space: Union[Box, Discrete],
        continuous_action: bool,
    ) -> None:
        super().__init__()
        self._network = representation_network
        self._continuous_action = continuous_action
        if actor_net is None:
            actor_network = torch.nn.Identity()
        else:
            actor_network = actor_net(network_output_dim)
        feature_dim = np.prod(calculate_output_dim(actor_network, network_output_dim))
        actor_head = GaussianPolicyHead if self._continuous_action else CategoricalHead

        self.actor = torch.nn.Sequential(
            actor_network,
            torch.nn.Flatten(),
            actor_head(feature_dim, action_space),
        )

        if critic_net is None:
            critic_network = torch.nn.Identity()
        else:
            critic_network = critic_net(network_output_dim)
        feature_dim = np.prod(calculate_output_dim(critic_network, network_output_dim))
        self.critic = torch.nn.Sequential(
            critic_network,
            torch.nn.Flatten(),
            torch.nn.Linear(feature_dim, 1),
        )

    def forward(self, x, action=None):
        hidden_state = self._network(x)
        distribution = self.actor(hidden_state)
        value = self.critic(hidden_state)
        if action is None:
            action = distribution.sample()

        logprob, entropy = distribution.log_prob(action), distribution.entropy()
        if self._continuous_action:
            logprob, entropy = logprob.sum(dim=-1), entropy.sum(dim=-1)
        return action, logprob, entropy, value
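As an illustrative sketch of how ActorCriticNetwork could be wired up and queried for a small discrete-action task: the spaces, the identity representation network, and the mlp_body helper below are placeholder assumptions for this example, not values from this commit; ActorCriticNetwork is the class defined above.

import gymnasium as gym
import numpy as np
import torch

# Placeholder spaces for the sketch (not from this commit).
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))
act_space = gym.spaces.Discrete(2)


def mlp_body(in_dim):
    # A plain callable standing in for a FunctionApproximator: it receives the
    # representation output shape and returns the actor/critic body network.
    return torch.nn.Sequential(
        torch.nn.Linear(int(np.prod(in_dim)), 64),
        torch.nn.Tanh(),
    )


net = ActorCriticNetwork(
    representation_network=torch.nn.Identity(),
    actor_net=mlp_body,
    critic_net=mlp_body,
    network_output_dim=obs_space.shape,
    action_space=act_space,
    continuous_action=False,
)

obs = torch.as_tensor(obs_space.sample()).unsqueeze(0)  # batch of one observation
action, logprob, entropy, value = net(obs)              # sample an action and evaluate it

Passing an explicit action instead of None re-evaluates stored actions under the current policy, which is what a PPO update needs for the importance ratio.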
@@ -0,0 +1,177 @@
import abc
from typing import Tuple

import numpy as np

from hive.utils.registry import Registrable, registry


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class MeanStd:
    """Tracks the mean, variance and count of values."""

    # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    def __init__(self, epsilon=1e-4, shape=()):
        """Tracks the mean, variance and count of values."""
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.count = epsilon

    def update(self, x):
        """Updates the mean, var and count from a batch of samples."""
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Updates from batch mean, variance and count moments."""
        self.mean, self.var, self.count = self.update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
        )

    def update_mean_var_count_from_moments(
        self, mean, var, count, batch_mean, batch_var, batch_count
    ):
        """Updates the mean, var and count using the previous mean, var, count
        and batch values."""
        delta = batch_mean - mean
        tot_count = count + batch_count

        new_mean = mean + delta * batch_count / tot_count
        m_a = var * count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
        new_var = M2 / tot_count
        new_count = tot_count

        return new_mean, new_var, new_count

    def state_dict(self):
        """Returns the state as a dictionary."""
        return {"mean": self.mean, "var": self.var, "count": self.count}

    def load_state_dict(self, state_dict):
        """Loads the state from a dictionary."""
        self.mean = state_dict["mean"]
        self.var = state_dict["var"]
        self.count = state_dict["count"]


class Normalizer(Registrable):
    """A wrapper for callables that produce normalization functions.
    These wrapped callables can be partially initialized through configuration
    files or command line arguments.
    """

    @classmethod
    def type_name(cls):
        """
        Returns:
            "norm_fn"
        """
        return "norm_fn"

    @abc.abstractmethod
    def state_dict(self):
        """Returns the state of the normalizer as a dictionary."""

    @abc.abstractmethod
    def load_state_dict(self, state_dict):
        """Loads the normalizer state from a dictionary."""


class MovingAvgNormalizer(Normalizer):
    """Implements a moving average normalization and clipping function. Normalizes
    input data with the running mean and std. The normalized data is then clipped
    within the specified range.
    """

    def __init__(
        self, shape: Tuple[int, ...], epsilon: float = 1e-4, clip: np.float32 = np.inf
    ):
        """
        Args:
            shape (tuple[int]): The shape of the input data.
            epsilon (float): Minimum value of variance to avoid division by 0.
            clip (np.float32): The clip value for the normalized data.
        """
        super().__init__()
        self._rms = MeanStd(epsilon, shape)
        self._shape = shape
        self._epsilon = epsilon
        self._clip = clip

    def __call__(self, input_data):
        input_data = np.array([input_data])
        input_data = (
            (input_data - self._rms.mean) / np.sqrt(self._rms.var + self._epsilon)
        )[0]
        if self._clip is not None:
            input_data = np.clip(input_data, -self._clip, self._clip)
        return input_data

    def update(self, input_data):
        self._rms.update(input_data)

    def state_dict(self):
        return self._rms.state_dict()

    def load_state_dict(self, state_dict):
        self._rms.load_state_dict(state_dict)


class RewardNormalizer(Normalizer):
    """Normalizes and clips rewards from the environment. Applies a discount-based
    scaling scheme, where the rewards are divided by the standard deviation of a
    rolling discounted sum of the rewards. The scaled rewards are then clipped
    within the specified range.
    """

    def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf):
        """
        Args:
            gamma (float): Discount factor for the agent.
            epsilon (float): Minimum value of variance to avoid division by 0.
            clip (np.float32): The clip value for the normalized data.
        """
        super().__init__()
        self._return_rms = MeanStd(epsilon, ())
        self._epsilon = epsilon
        self._clip = clip
        self._gamma = gamma
        self._returns = np.zeros(1)

    def __call__(self, rew):
        rew = np.array([rew])
        rew = (rew / np.sqrt(self._return_rms.var + self._epsilon))[0]
        if self._clip is not None:
            rew = np.clip(rew, -self._clip, self._clip)
        return rew

    def update(self, rew, done):
        self._returns = self._returns * self._gamma + rew
        self._return_rms.update(self._returns)
        self._returns *= 1 - done

    def state_dict(self):
        state_dict = self._return_rms.state_dict()
        state_dict["returns"] = self._returns
        return state_dict

    def load_state_dict(self, state_dict):
        self._returns = state_dict["returns"]
        state_dict.pop("returns")
        self._return_rms.load_state_dict(state_dict)


registry.register_all(
    Normalizer,
    {
        "RewardNormalizer": RewardNormalizer,
        "MovingAvgNormalizer": MovingAvgNormalizer,
    },
)

get_norm_fn = getattr(registry, f"get_{Normalizer.type_name()}")
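As a rough usage sketch, assuming the classes above are in scope: the random observations, rewards, and the (4,) shape below are placeholders rather than RLHive's actual training loop. The idea is that both normalizers are updated and applied every step, and their state dicts can be checkpointed alongside the agent.

import numpy as np

# Hypothetical shapes and hyperparameters for the sketch.
obs_norm = MovingAvgNormalizer(shape=(4,), clip=10.0)
rew_norm = RewardNormalizer(gamma=0.99, clip=10.0)

for _ in range(1000):
    obs = np.random.randn(4).astype(np.float32)   # placeholder for an env observation
    rew = float(np.random.randn())                # placeholder for an env reward
    done = float(np.random.rand() < 0.01)         # placeholder for episode termination

    obs_norm.update(np.expand_dims(obs, 0))       # MeanStd.update expects a batch axis
    normalized_obs = obs_norm(obs)                # (obs - running mean) / running std, clipped
    rew_norm.update(rew, done)                    # track the rolling discounted return
    normalized_rew = rew_norm(rew)                # rew / std(discounted return), clipped

# Checkpoint and restore the normalizer state.
saved = {"obs_norm": obs_norm.state_dict(), "rew_norm": rew_norm.state_dict()}
obs_norm.load_state_dict(saved["obs_norm"])
rew_norm.load_state_dict(saved["rew_norm"])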
@@ -0,0 +1,73 @@
name: 'SingleAgentRunner'
kwargs:
  experiment_manager:
    name: 'Experiment'
    kwargs:
      name: &run_name 'atari-ppo'
      save_dir: 'experiment'
      saving_schedule:
        name: 'PeriodicSchedule'
        kwargs:
          off_value: False
          on_value: True
          period: 1000000

  train_steps: 10000000
  test_frequency: 250000
  test_episodes: 10
  max_steps_per_episode: 27000
  stack_size: &stack_size 4
  environment:
    name: 'AtariEnv'
    kwargs:
      env_name: 'Breakout'

  agent:
    name: 'PPOAgent'
    kwargs:
      representation_net:
        name: 'ConvNetwork'
        kwargs:
          channels: [32, 64, 64]
          kernel_sizes: [8, 4, 3]
          strides: [4, 2, 1]
          paddings: [2, 2, 1]
          mlp_layers: [512]
      optimizer_fn:
        name: 'Adam'
        kwargs:
          lr: .00025
      init_fn:
        name: 'orthogonal'
      replay_buffer:
        name: 'OnPolicyReplayBuffer'
        kwargs:
          stack_size: *stack_size
          compute_advantage_fn:
            name: "gae_advantages"
            kwargs:
              gae_lambda: 0.95

      discount_rate: .99
      grad_clip: .5
      clip_coefficient: .1
      entropy_coefficient: .0
      clip_value_loss: True
      value_fn_coefficient: .5
      transitions_per_update: 4096
      num_epochs_per_update: 4
      normalize_advantages: True
      batch_size: 256
      device: 'cuda'
      id: 'agent'
  # List of logger configs used.
  loggers:
    -
      name: ChompLogger
    -
      name: WandbLogger
      kwargs:
        project: Hive
        name: *run_name
        resume: "allow"
        start_method: "fork"
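The replay buffer's compute_advantage_fn entry points at a registered gae_advantages function whose implementation is not part of the rendered diff. As a hedged sketch, generalized advantage estimation with the configured discount_rate (0.99) and gae_lambda (0.95) typically follows the standard GAE recursion below; this is an illustration, not necessarily RLHive's exact code.

import numpy as np


def gae_advantages_sketch(rewards, values, dones, last_value, gamma=0.99, gae_lambda=0.95):
    """Standard GAE recursion:
    delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
    A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    last_advantage = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * nonterminal * next_value - values[t]
        last_advantage = delta + gamma * gae_lambda * nonterminal * last_advantage
        advantages[t] = last_advantage
    returns = advantages + values  # targets for the value function
    return advantages, returns

With normalize_advantages: True, a PPO update would usually also standardize the advantages per batch, i.e. subtract their mean and divide by their standard deviation plus a small epsilon, before computing the clipped surrogate loss with clip_coefficient.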