
Commit

PPO (#272)
* Adding initial PPO Code

* Added buffer sampling and solved some bugs

* ppo agent: device and type errors fixed

* ppo updater fixed

* ppo config updated

* Updated Hive to use gym spaces instead of raw tuples to represent action and observation spaces

* Updated tests to reflect the API change of 1106ec2

* Adding initial PPO Code

* Added buffer sampling and solved some bugs

* ppo agent: device and type errors fixed

* ppo updater fixed

* ppo config updated

* ppo replay added

* ppo replay fixed

* ppo agent updated

* ppo agent and config updated

* ppo code running but buggy

* cartpole working

* ppo configs

* ppo net fixed

* atari configs added

* ppo_nets done

* ppo_replay done

* ppo env wrappers added

* ppo agent done

* configs done

* stack size > 1 handled temporarily

* linting fixed

* last batch drop fix

* config changes

* shared network added

* reward wrapper added

* linting fixed

* docs fixed

* replay changed

* update loop

* type specification

* env wrappers registered

* linting fixed

* Removed one off transition, cleaned up replay buffer

* Fixed linter issues

* wrapper error fixed

* added vars to dict; fixed long lines and var names; moved wrapper registry

* config fixed

* added normalisation and fixed log

* norm file added

* norm bug fixed

* rew norm updated

* fixes

* fixing norm bug; config

* config fixes

* obs norm

* hardcoded wrappers added

* normaliser shape fixed

* rew shape fixed; norm structure updated

* rew norm

* configs and wrapper fixed

* Fixed formatting and naming

* Added env wrapper logic

* Renamed PPO Replay Buffer to On Policy Replay buffer

* Made PPO Stateless Agent

* Fixed linting issues

* Minor modifications

* Fixed changes

* Formatting and minor changes

* Refactored Advantage Computation

* Reformatting with black

* Renaming

* Refactored Normalization code

* Added saving and loading of state dict for normalizers

* Fixed multiplayer replay buffer for PPO

* Fixed minor bug

* Renamed file

* Added lr annealing

---------

Co-authored-by: Sriyash <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: Darshan Patil <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: sriyash.poddar <[email protected]>
Co-authored-by: Sriyash Poddar <[email protected]>
9 people authored Mar 28, 2023
1 parent 805ac4a commit 06c485c
Showing 14 changed files with 1,258 additions and 3 deletions.
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
@@ -4,6 +4,7 @@
from hive.agents.dqn import DQNAgent
from hive.agents.drqn import DRQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.ppo import PPOAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
from hive.agents.td3 import TD3
@@ -16,6 +17,7 @@
"DQNAgent": DQNAgent,
"DRQNAgent": DRQNAgent,
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"PPOAgent": PPOAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"TD3": TD3,
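With this registration the agent can be imported from the package root and referenced by its registered name in YAML configs (the Atari config later in this commit uses name: 'PPOAgent'). A minimal import check, assuming a checkout with this commit applied:

    from hive.agents import PPOAgent  # re-exported by the updated __init__.py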
449 changes: 449 additions & 0 deletions hive/agents/ppo.py

Large diffs are not rendered by default.
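Since hive/agents/ppo.py is not rendered here, the sketch below illustrates the kind of clipped-surrogate minibatch update a PPO agent of this shape performs. The hyperparameter names (clip_coefficient, value_fn_coefficient, entropy_coefficient, grad_clip, normalize_advantages) mirror the Atari config later in this commit, and the network is assumed to follow the ActorCriticNetwork interface added below; this is an illustrative approximation, not the file's actual contents.

    import torch

    def ppo_minibatch_update(network, optimizer, obs, actions, old_logprob,
                             advantages, returns, clip_coefficient=0.1,
                             value_fn_coefficient=0.5, entropy_coefficient=0.0,
                             grad_clip=0.5, normalize_advantages=True):
        # Sketch only: names follow the config below, not Hive's exact code.
        if normalize_advantages:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # ActorCriticNetwork.forward returns (action, logprob, entropy, value).
        _, logprob, entropy, value = network(obs, actions)
        ratio = torch.exp(logprob - old_logprob)

        # Clipped surrogate policy loss.
        policy_loss = -torch.min(
            ratio * advantages,
            torch.clamp(ratio, 1 - clip_coefficient, 1 + clip_coefficient) * advantages,
        ).mean()

        value_loss = 0.5 * ((value.squeeze(-1) - returns) ** 2).mean()
        loss = (policy_loss
                + value_fn_coefficient * value_loss
                - entropy_coefficient * entropy.mean())

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), grad_clip)
        optimizer.step()
        return loss.item()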

119 changes: 119 additions & 0 deletions hive/agents/qnets/ac_nets.py
@@ -0,0 +1,119 @@
from typing import Tuple, Union

import gymnasium as gym
import numpy as np
import torch
from gymnasium.spaces import Box, Discrete

from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.utils import calculate_output_dim


class CategoricalHead(torch.nn.Module):
"""A module that implements a discrete actor head. It uses the ouput from
the :obj:`actor_net`, and adds creates a
:py:class:`~torch.distributions.categorical.Categorical` object to compute
the action distribution."""

def __init__(
self, feature_dim: Tuple[int], action_space: gym.spaces.Discrete
) -> None:
"""
Args:
feature_dim: Expected output shape of the actor network.
action_space: The discrete action space of the environment.
"""
super().__init__()
self.network = torch.nn.Linear(feature_dim, action_space.n)
self.distribution = torch.distributions.categorical.Categorical

def forward(self, x):
logits = self.network(x)
return self.distribution(logits=logits)


class GaussianPolicyHead(torch.nn.Module):
"""A module that implements a continuous actor head. It uses the output from the
:obj:`actor_net` and state independent learnable parameter :obj:`policy_logstd` to
create a :py:class:`~torch.distributions.normal.Normal` object to compute
the action distribution."""

def __init__(self, feature_dim: Tuple[int], action_space: gym.spaces.Box) -> None:
"""
Args:
feature_dim: Expected output shape of the actor network.
action_space: The continuous action space of the environment.
"""
super().__init__()
self._action_shape = action_space.shape
self.policy_mean = torch.nn.Sequential(
torch.nn.Linear(feature_dim, np.prod(self._action_shape))
)
self.policy_logstd = torch.nn.Parameter(
torch.zeros(1, np.prod(action_space.shape))
)
self.distribution = torch.distributions.normal.Normal

def forward(self, x):
_mean = self.policy_mean(x)
_std = self.policy_logstd.repeat(x.shape[0], 1).exp()
distribution = self.distribution(
torch.reshape(_mean, (x.size(0), *self._action_shape)),
torch.reshape(_std, (x.size(0), *self._action_shape)),
)
return distribution


class ActorCriticNetwork(torch.nn.Module):
"""A module that implements the actor and critic computation. It puts together
the :obj:`representation_network`, :obj:`actor_net` and :obj:`critic_net`, then
adds two final :py:class:`~torch.nn.Linear` layers to compute the action and state
value."""

def __init__(
self,
representation_network: torch.nn.Module,
actor_net: FunctionApproximator,
critic_net: FunctionApproximator,
network_output_dim: Union[int, Tuple[int]],
action_space: Union[Box, Discrete],
continuous_action: bool,
) -> None:
super().__init__()
self._network = representation_network
self._continuous_action = continuous_action
if actor_net is None:
actor_network = torch.nn.Identity()
else:
actor_network = actor_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(actor_network, network_output_dim))
actor_head = GaussianPolicyHead if self._continuous_action else CategoricalHead

self.actor = torch.nn.Sequential(
actor_network,
torch.nn.Flatten(),
actor_head(feature_dim, action_space),
)

if critic_net is None:
critic_network = torch.nn.Identity()
else:
critic_network = critic_net(network_output_dim)
feature_dim = np.prod(calculate_output_dim(critic_network, network_output_dim))
self.critic = torch.nn.Sequential(
critic_network,
torch.nn.Flatten(),
torch.nn.Linear(feature_dim, 1),
)

def forward(self, x, action=None):
hidden_state = self._network(x)
distribution = self.actor(hidden_state)
value = self.critic(hidden_state)
if action is None:
action = distribution.sample()

logprob, entropy = distribution.log_prob(action), distribution.entropy()
if self._continuous_action:
logprob, entropy = logprob.sum(dim=-1), entropy.sum(dim=-1)
return action, logprob, entropy, value
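A short usage sketch for the discrete head, assuming a 64-dimensional feature vector and a 4-action environment (both values chosen purely for illustration):

    import torch
    from gymnasium.spaces import Discrete
    from hive.agents.qnets.ac_nets import CategoricalHead

    head = CategoricalHead(feature_dim=64, action_space=Discrete(4))
    features = torch.randn(8, 64)        # batch of 8 feature vectors
    dist = head(features)                # torch.distributions.Categorical
    actions = dist.sample()              # shape (8,)
    log_probs = dist.log_prob(actions)   # shape (8,)

ActorCriticNetwork wires the chosen head together with the shared representation network and the critic's value layer, so a single forward pass returns the sampled action, its log-probability, the entropy of the distribution, and the state value.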
177 changes: 177 additions & 0 deletions hive/agents/qnets/normalizer.py
@@ -0,0 +1,177 @@
import abc
from typing import Tuple

import numpy as np

from hive.utils.registry import Registrable, registry


# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class MeanStd:
"""Tracks the mean, variance and count of values."""

# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
def __init__(self, epsilon=1e-4, shape=()):
"""Tracks the mean, variance and count of values."""
self.mean = np.zeros(shape, "float64")
self.var = np.ones(shape, "float64")
self.count = epsilon

def update(self, x):
"""Updates the mean, var and count from a batch of samples."""
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)

def update_from_moments(self, batch_mean, batch_var, batch_count):
"""Updates from batch mean, variance and count moments."""
self.mean, self.var, self.count = self.update_mean_var_count_from_moments(
self.mean, self.var, self.count, batch_mean, batch_var, batch_count
)

def update_mean_var_count_from_moments(
self, mean, var, count, batch_mean, batch_var, batch_count
):
"""Updates the mean, var and count using the previous mean, var, count
and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count

new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count

return new_mean, new_var, new_count

def state_dict(self):
"""Returns the state as a dictionary."""
return {"mean": self.mean, "var": self.var, "count": self.count}

def load_state_dict(self, state_dict):
"""Loads the state from a dictionary."""
self.mean = state_dict["mean"]
self.var = state_dict["var"]
self.count = state_dict["count"]


class Normalizer(Registrable):
"""A wrapper for callables that produce normalization functions.
These wrapped callables can be partially initialized through configuration
files or command line arguments.
"""

@classmethod
def type_name(cls):
"""
Returns:
"norm_fn"
"""
return "norm_fn"

@abc.abstractmethod
def state_dict(self):
"""Returns the state of the normalizer as a dictionary."""

@abc.abstractmethod
def load_state_dict(self, state_dict):
"""Loads the normalizer state from a dictionary."""


class MovingAvgNormalizer(Normalizer):
"""Implements a moving average normalization and clipping function. Normalizes
input data with the running mean and std. The normalized data is then clipped
within the specified range.
"""

def __init__(
self, shape: Tuple[int, ...], epsilon: float = 1e-4, clip: np.float32 = np.inf
):
"""
Args:
shape (tuple[int]): The shape of input data.
epsilon (float): minimum value of variance to avoid division by 0.
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self._rms = MeanStd(epsilon, shape)
self._shape = shape
self._epsilon = epsilon
self._clip = clip

def __call__(self, input_data):
input_data = np.array([input_data])
input_data = (
(input_data - self._rms.mean) / np.sqrt(self._rms.var + self._epsilon)
)[0]
if self._clip is not None:
input_data = np.clip(input_data, -self._clip, self._clip)
return input_data

def update(self, input_data):
self._rms.update(input_data)

def state_dict(self):
return self._rms.state_dict()

def load_state_dict(self, state_dict):
self._rms.load_state_dict(state_dict)


class RewardNormalizer(Normalizer):
"""Normalizes and clips rewards from the environment. Applies a discount-based
scaling scheme, where the rewards are divided by the standard deviation of a
rolling discounted sum of the rewards. The scaled rewards are then clipped within
the specified range.
"""

def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf):
"""
Args:
gamma (float): discount factor for the agent.
epsilon (float): minimum value of variance to avoid division by 0.
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self._return_rms = MeanStd(epsilon, ())
self._epsilon = epsilon
self._clip = clip
self._gamma = gamma
self._returns = np.zeros(1)

def __call__(self, rew):
rew = np.array([rew])
rew = (rew / np.sqrt(self._return_rms.var + self._epsilon))[0]
if self._clip is not None:
rew = np.clip(rew, -self._clip, self._clip)
return rew

def update(self, rew, done):
self._returns = self._returns * self._gamma + rew
self._return_rms.update(self._returns)
self._returns *= 1 - done

def state_dict(self):
state_dict = self._return_rms.state_dict()
state_dict["returns"] = self._returns
return state_dict

def load_state_dict(self, state_dict):
self._returns = state_dict["returns"]
state_dict.pop("returns")
self._return_rms.load_state_dict(state_dict)


registry.register_all(
Normalizer,
{
"RewardNormalizer": RewardNormalizer,
"MovingAvgNormalizer": MovingAvgNormalizer,
},
)

get_norm_fn = getattr(registry, f"get_{Normalizer.type_name()}")
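A usage sketch for the observation normalizer, assuming 4-dimensional observations and a clip value of 10 (both illustrative, not taken from the configs):

    import numpy as np
    from hive.agents.qnets.normalizer import MovingAvgNormalizer

    normalizer = MovingAvgNormalizer(shape=(4,), clip=10.0)
    for _ in range(100):
        batch = np.random.randn(32, 4)    # a batch of raw observations
        normalizer.update(batch)          # update the running mean/variance
    obs = np.random.randn(4)
    normalized = normalizer(obs)          # (obs - mean) / sqrt(var + eps), then clipped

    state = normalizer.state_dict()       # checkpoint the running statistics ...
    normalizer.load_state_dict(state)     # ... and restore them later

RewardNormalizer is used the same way, except that update(rew, done) also takes the episode-termination flag so the rolling discounted return can be reset at episode boundaries.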
73 changes: 73 additions & 0 deletions hive/configs/atari/ppo.yml
@@ -0,0 +1,73 @@
name: 'SingleAgentRunner'
kwargs:
experiment_manager:
name: 'Experiment'
kwargs:
name: &run_name 'atari-ppo'
save_dir: 'experiment'
saving_schedule:
name: 'PeriodicSchedule'
kwargs:
off_value: False
on_value: True
period: 1000000

train_steps: 10000000
test_frequency: 250000
test_episodes: 10
max_steps_per_episode: 27000
stack_size: &stack_size 4
environment:
name: 'AtariEnv'
kwargs:
env_name: 'Breakout'

agent:
name: 'PPOAgent'
kwargs:
representation_net:
name: 'ConvNetwork'
kwargs:
channels: [32, 64, 64]
kernel_sizes: [8, 4, 3]
strides: [4, 2, 1]
paddings: [2, 2, 1]
mlp_layers: [512]
optimizer_fn:
name: 'Adam'
kwargs:
lr: .00025
init_fn:
name: 'orthogonal'
replay_buffer:
name: 'OnPolicyReplayBuffer'
kwargs:
stack_size: *stack_size
compute_advantage_fn:
name: "gae_advantages"
kwargs:
gae_lambda: 0.95

discount_rate: .99
grad_clip: .5
clip_coefficient: .1
entropy_coefficient: .0
clip_value_loss: True
value_fn_coefficient: .5
transitions_per_update: 4096
num_epochs_per_update: 4
normalize_advantages: True
batch_size: 256
device: 'cuda'
id: 'agent'
# List of logger configs used.
loggers:
-
name: ChompLogger
-
name: WandbLogger
kwargs:
project: Hive
name: *run_name
resume: "allow"
start_method: "fork"
