PPO #272

Merged · 86 commits · Mar 28, 2023

Changes shown below are from 2 of the 86 commits.

Commits
dae80b6
Adding initial PPO Code
kshitijkg May 3, 2022
f817ce9
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
4de8417
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
bc5817e
ppo updater fixed
sriyash421 May 4, 2022
f422b36
ppo config updated
May 4, 2022
d1b7ee2
Updated Hive to use gym spaces instead of raw tuples to represent act…
dapatil211 Apr 20, 2022
81ac508
Updated tests to reflect api change of 1106ec2
dapatil211 Apr 20, 2022
e2c1133
Adding initial PPO Code
kshitijkg May 3, 2022
6eb1e4e
Added buffer sampling and solved some bugs
kshitijkg May 3, 2022
2277abd
ppo agent: device and type errors fixed
sriyash421 May 3, 2022
d780683
ppo updater fixed
sriyash421 May 4, 2022
e0cc263
ppo config updated
May 4, 2022
ab5ef03
ppo replay added
sriyash421 May 6, 2022
301e77c
ppo replay conflict
sriyash421 May 6, 2022
6459606
ppo replay fixed
sriyash421 May 6, 2022
babfcce
ppo agent updated
sriyash421 May 6, 2022
c0f9039
ppo agent and config updated
sriyash421 May 6, 2022
a973cb3
ppo code running but buggy
May 9, 2022
0c78e2d
cartpole working
May 12, 2022
f5271bf
ppo configs
May 18, 2022
ad3ed9b
ppo net fixed
May 24, 2022
0f63802
merge dev
May 24, 2022
f598433
atari configs added
May 24, 2022
a3c3c1c
ppo_nets done
May 24, 2022
47e106f
ppo_replay done
May 24, 2022
12a4b73
ppo env wrappers added
May 24, 2022
277b9bc
ppo agent done
May 24, 2022
9417e05
configs done
May 24, 2022
d3616a7
stack size > 1 handled temporarily
May 27, 2022
9de6e3a
linting fixed
sriyash421 May 27, 2022
c151b7b
Merge branch 'dev' into ppo_spaces
sriyash421 Jun 27, 2022
7201e97
last batch drop fix
sriyash421 Jun 27, 2022
45a9da4
config changes
sriyash421 Jun 29, 2022
4ea7527
Merge branch 'ppo_spaces' of github.com:chandar-lab/RLHive into ppo_s…
sriyash421 Jun 29, 2022
a9848e1
shared network added
sriyash421 Jul 7, 2022
3adcb73
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 7, 2022
54996f3
reward wrapper added
sriyash421 Jul 13, 2022
fa9297b
linting fixed
sriyash421 Jul 13, 2022
2ac73ba
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 13, 2022
3d11136
Merge branch 'dev' into ppo_spaces
sriyash421 Jul 29, 2022
4f5b4b8
docs fixed
sriyash421 Aug 28, 2022
cc4aab8
replay changed
sriyash421 Aug 28, 2022
2d55afa
update loop
sriyash421 Aug 28, 2022
dcf2aa7
type specification
sriyash421 Aug 28, 2022
936c3b1
env wrappers registered
sriyash421 Aug 28, 2022
04d8692
linting fixed
sriyash421 Aug 29, 2022
360dc00
Merge branch 'dev' into ppo_spaces
kshitijkg Sep 25, 2022
b1613cd
Removed one off transition, cleaned up replay buffer
kshitijkg Sep 25, 2022
bdcd11e
Fixed linter issues
kshitijkg Sep 25, 2022
5a8e2da
wrapper error fixed
sriyash421 Sep 29, 2022
a54377a
added vars to dict; fixed long lines and var names; moved wrapper reg…
sriyash421 Oct 11, 2022
9680185
config fixed
sriyash421 Oct 13, 2022
2c9295f
added normalisation and fixed log
sriyash421 Oct 13, 2022
767f96c
norm file added
sriyash421 Oct 14, 2022
b4f2ea1
norm bug fixed
sriyash421 Nov 3, 2022
58f5ec2
rew norm updated
sriyash421 Nov 11, 2022
306faea
fixes
sriyash421 Nov 11, 2022
35d6aeb
fixing norm bug; config
sriyash421 Nov 23, 2022
7d31faf
config fixes
sriyash421 Nov 23, 2022
b84722e
obs norm
sriyash421 Nov 24, 2022
a4c1692
hardcoded wrappers added
sriyash421 Nov 24, 2022
11ccb21
normaliser shape fixed
sriyash421 Dec 6, 2022
0991e84
rew shape fixed; norm structure updated
sriyash421 Dec 6, 2022
c7f42a1
rew norm
sriyash421 Dec 6, 2022
84d933e
configs and wrapper fixed
sriyash421 Dec 7, 2022
3f01532
merge dev
sriyash421 Dec 19, 2022
54799c2
Merge branch 'dev' into ppo_spaces
sriyash421 Dec 19, 2022
8fb9902
Fixed formatting and naming
kshitijkg Jan 30, 2023
bd5c587
Added env wrapper logic
kshitijkg Jan 30, 2023
697a78c
Merging dev
kshitijkg Jan 30, 2023
a1e77fa
Renamed PPO Replay Buffer to On Policy Replay buffer
kshitijkg Jan 30, 2023
031f462
Made PPO Stateless Agent
kshitijkg Jan 30, 2023
28733ec
Fixed linting issues
kshitijkg Jan 30, 2023
8885a89
Minor modifications
kshitijkg Feb 7, 2023
0e42146
Fixed changed
kshitijkg Feb 8, 2023
d785c85
Formatting and minor changes
kshitijkg Mar 2, 2023
4946874
Merge branch 'dev' into ppo_spaces
dapatil211 Mar 20, 2023
308f111
Refactored Advantage Computation
kshitijkg Mar 21, 2023
543fc74
Reformatting with black
kshitijkg Mar 21, 2023
43c3fb1
Renaming
kshitijkg Mar 21, 2023
4d82f99
Refactored Normalization code
kshitijkg Mar 21, 2023
e7d08d5
Added saving and loading of state dict for normalizers
kshitijkg Mar 21, 2023
aba7c49
Fixed multiplayer replay buffer for PPO
kshitijkg Mar 21, 2023
000c4e4
Fixed minor bug
kshitijkg Mar 22, 2023
3d6d076
Renamed file
kshitijkg Mar 22, 2023
aabeed0
Added lr annealing
dapatil211 Mar 28, 2023
94 changes: 53 additions & 41 deletions hive/agents/ppo.py
@@ -7,8 +7,11 @@

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.normalizer import NormalizationFn
from hive.agents.qnets.ppo_nets import PPOActorCriticNetwork
from hive.agents.qnets.normalizer import (
MovingAvgNormalizer,
RewardNormalizer,
)
from hive.agents.qnets.ac_nets import ActorCriticNetwork
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
@@ -33,8 +36,8 @@ def __init__(
init_fn: InitializationFn = None,
optimizer_fn: OptimizerFn = None,
critic_loss_fn: LossFn = None,
observation_normalization_fn: NormalizationFn = None,
reward_normalization_fn: NormalizationFn = None,
observation_normalizer: MovingAvgNormalizer = None,
reward_normalizer: RewardNormalizer = None,
stack_size: int = 1,
replay_buffer: OnPolicyReplayBuffer = None,
discount_rate: float = 0.99,
@@ -75,9 +78,10 @@
If None, defaults to :py:class:`~torch.optim.Adam`.
critic_loss_fn (LossFn): The loss function used to optimize the critic. If
None, defaults to :py:class:`~torch.nn.MSELoss`.
observation_normalizer (NormalizationFn): The function for normalizing
observations
reward_normalizer (NormalizationFn): The function for normalizing rewards
observation_normalizer (MovingAvgNormalizer): The function for
normalizing observations
reward_normalizer (RewardNormalizer): The function for normalizing
rewards
stack_size (int): Number of observations stacked to create the state fed
to the agent.
replay_buffer (OnPolicyReplayBuffer): The replay buffer that the agent will
@@ -123,17 +127,15 @@ def __init__(
actor_net,
critic_net,
)
if observation_normalization_fn is not None:
self._observation_normalization_fn = observation_normalization_fn(
self._state_size
)
if observation_normalizer is not None:
self._observation_normalizer = observation_normalizer(self._state_size)
else:
self._observation_normalization_fn = None
self._observation_normalizer = None

if reward_normalization_fn is not None:
self._reward_normalization_fn = reward_normalization_fn(discount_rate)
if reward_normalizer is not None:
self._reward_normalizer = reward_normalizer(discount_rate)
else:
self._reward_normalization_fn = None
self._reward_normalizer = None

if optimizer_fn is None:
optimizer_fn = torch.optim.Adam
@@ -187,7 +189,7 @@ def create_networks(self, representation_net, actor_net, critic_net):
network = representation_net(self._state_size)

network_output_shape = calculate_output_dim(network, self._state_size)
self._actor_critic = PPOActorCriticNetwork(
self._actor_critic = ActorCriticNetwork(
network,
actor_net,
critic_net,
@@ -215,15 +217,15 @@ def preprocess_update_info(self, update_info, agent_traj_state):
update_info: Contains the information from the current timestep that the
agent should use to update itself.
"""
if self._observation_normalization_fn:
update_info["observation"] = self._observation_normalization_fn(
if self._observation_normalizer:
update_info["observation"] = self._observation_normalizer(
update_info["observation"]
)

done = update_info["terminated"] or update_info["truncated"]
if self._reward_normalization_fn:
self._reward_normalization_fn.update(update_info["reward"], done)
update_info["reward"] = self._reward_normalization_fn(update_info["reward"])
if self._reward_normalizer:
self._reward_normalizer.update(update_info["reward"], done)
update_info["reward"] = self._reward_normalizer(update_info["reward"])

preprocessed_update_info = {
"observation": update_info["observation"],
@@ -278,9 +280,9 @@ def act(self, observation, agent_traj_state=None):
"""
if agent_traj_state is None:
agent_traj_state = {}
if self._observation_normalization_fn:
self._observation_normalization_fn.update(observation)
observation = self._observation_normalization_fn(observation)
if self._observation_normalizer:
self._observation_normalizer.update(observation)
observation = self._observation_normalizer(observation)
action, logprob, value = self.get_action_logprob_value(observation)
agent_traj_state["logprob"] = logprob
agent_traj_state["value"] = value
@@ -305,8 +307,8 @@ def update(self, update_info, agent_traj_state=None):
)

if self._replay_buffer.size() >= self._transitions_per_update - 1:
if self._observation_normalization_fn:
update_info["next_observation"] = self._observation_normalization_fn(
if self._observation_normalizer:
update_info["next_observation"] = self._observation_normalizer(
update_info["next_observation"]
)
_, _, values = self.get_action_logprob_value(
@@ -376,29 +378,33 @@ def update(self, update_info, agent_traj_state=None):
approx_kl = ((ratios - 1) - logratios).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar(
"actor_loss", actor_loss, self._timescale
)
self._logger.log_scalar(
"critic_loss", critic_loss, self._timescale
)
self._logger.log_scalar(
"entropy_loss", entr_loss, self._timescale
)
self._logger.log_scalar("approx_kl", approx_kl, self._timescale)

if self._target_kl is not None:
if approx_kl > self._target_kl:
break
metrics = {
"actor_loss": actor_loss,
"critic_loss": critic_loss,
"entropy_loss": entr_loss,
"approx_kl": approx_kl,
}
self._logger.log_metrics(metrics, prefix=self._timescale)
if self._target_kl is not None and self._target_kl < approx_kl:
break
self._replay_buffer.reset()
return agent_traj_state

def save(self, dname):
torch.save(
state_dict = (
{
"actor_critic": self._actor_critic.state_dict(),
"optimizer": self._optimizer.state_dict(),
},
)
if self._observation_normalizer:
state_dict[
"observation_normalizer"
] = self._observation_normalizer.state_dict()
if self._reward_normalizer:
state_dict["reward_normalizer"] = self._reward_normalizer.state_dict()
torch.save(
state_dict,
os.path.join(dname, "agent.pt"),
)
replay_dir = os.path.join(dname, "replay")
@@ -410,3 +416,9 @@ def load(self, dname):
self._actor_critic.load_state_dict(checkpoint["actor_critic"])
self._optimizer.load_state_dict(checkpoint["optimizer"])
self._replay_buffer.load(os.path.join(dname, "replay"))
if self._observation_normalizer:
self._observation_normalizer.load_state_dict(
checkpoint["observation_normalizer"]
)
if self._reward_normalizer:
self._reward_normalizer.load_state_dict(checkpoint["reward_normalizer"])
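The rewritten update loop above batches the per-loss logging into a single `log_metrics` call and keeps the KL-based early stop. As a reference for what `approx_kl = ((ratios - 1) - logratios).mean()` computes, here is a minimal standalone sketch (not RLHive code; the names `new_logprob`, `old_logprob`, and the `0.015` threshold are illustrative):

```python
from typing import Optional, Tuple

import torch


def approx_kl_and_stop(
    new_logprob: torch.Tensor,
    old_logprob: torch.Tensor,
    target_kl: Optional[float],
) -> Tuple[torch.Tensor, bool]:
    """Estimate KL(old || new) from minibatch log-probs and flag early stopping.

    Uses the low-variance estimator (r - 1) - log(r), whose expectation under the
    old policy equals the KL divergence between the old and new policies.
    """
    logratios = new_logprob - old_logprob          # log(pi_new / pi_old)
    ratios = torch.exp(logratios)                  # pi_new / pi_old
    approx_kl = ((ratios - 1) - logratios).mean()
    stop = target_kl is not None and approx_kl.item() > target_kl
    return approx_kl, stop


# Illustrative use inside the epoch/minibatch loop:
# kl, stop = approx_kl_and_stop(new_logprob, old_logprob, target_kl=0.015)
# if stop:
#     break  # skip the remaining minibatches for this update
```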
@@ -64,8 +64,8 @@ def forward(self, x):
return distribution


class PPOActorCriticNetwork(torch.nn.Module):
"""A module that implements the PPO actor and critic computation. It puts together
class ActorCriticNetwork(torch.nn.Module):
"""A module that implements the actor and critic computation. It puts together
the :obj:`representation_network`, :obj:`actor_net` and :obj:`critic_net`, then
adds two final :py:class:`~torch.nn.Linear` layers to compute the action and state
value."""
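The rename from `PPOActorCriticNetwork` to `ActorCriticNetwork` (and from `ppo_nets` to `ac_nets` in the import above) reflects that the module is not PPO-specific. As a rough sketch of the structure the docstring describes (a shared representation network, separate actor and critic trunks, and two final linear layers for the action logits and state value), something like the following would apply; the hidden sizes and the categorical action head are assumptions, not RLHive's implementation:

```python
import torch
import torch.nn as nn


class SimpleActorCritic(nn.Module):
    """Illustrative actor-critic: shared trunk, then actor and critic heads."""

    def __init__(self, obs_dim: int, num_actions: int, hidden: int = 64):
        super().__init__()
        self.representation = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.actor = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
        self.critic = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())
        self.action_head = nn.Linear(hidden, num_actions)  # action logits
        self.value_head = nn.Linear(hidden, 1)             # state value

    def forward(self, obs: torch.Tensor):
        features = self.representation(obs)
        logits = self.action_head(self.actor(features))
        value = self.value_head(self.critic(features))
        return torch.distributions.Categorical(logits=logits), value
```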
115 changes: 69 additions & 46 deletions hive/agents/qnets/normalizer.py
@@ -1,3 +1,4 @@
import abc
from typing import Tuple

import numpy as np
@@ -46,24 +47,44 @@ def update_mean_var_count_from_moments(

return new_mean, new_var, new_count

def state_dict(self):
"""Returns the state as a dictionary."""
return {"mean": self.mean, "var": self.var, "count": self.count}

class BaseNormalizationFn(object):
"""Implements the base normalization function."""
def load_state_dict(self, state_dict):
"""Loads the state from a dictionary."""
self.mean = state_dict["mean"]
self.var = state_dict["var"]
self.count = state_dict["count"]

def __init__(self, *args, **kwds):
pass

def __call__(self, *args, **kwds):
return NotImplementedError
class Normalizer(Registrable):
"""A wrapper for callables that produce normalization functions.

def update(self, *args, **kwds):
return NotImplementedError
These wrapped callables can be partially initialized through configuration
files or command line arguments.
"""

@classmethod
def type_name(cls):
"""
Returns:
"norm_fn"
"""
return "norm_fn"

@abc.abstractmethod
def state_dict(self):
"""Returns the state of the normalizer as a dictionary."""

class ObservationNormalizationFn(BaseNormalizationFn):
"""Implements a normalization function. Transforms output by
normalising the input data by the running :obj:`mean` and
:obj:`std`, and clipping the normalised data on :obj:`clip`
def load_state_dict(self, state_dict):
"""Loads the normalizer state from a dictionary."""


class MovingAvgNormalizer(Normalizer):
"""Implements a moving average normalization and clipping function. Normalizes
input data with the running mean and std. The normalized data is then clipped
within the specified range.
"""

def __init__(
@@ -76,26 +97,35 @@ def __init__(
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self.obs_rms = MeanStd(epsilon, shape)
self._rms = MeanStd(epsilon, shape)
self._shape = shape
self._epsilon = epsilon
self._clip = clip

def __call__(self, obs):
obs = np.array([obs])
obs = ((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self._epsilon))[0]
def __call__(self, input_data):
input_data = np.array([input_data])
input_data = (
(input_data - self._rms.mean) / np.sqrt(self._rms.var + self._epsilon)
)[0]
if self._clip is not None:
obs = np.clip(obs, -self._clip, self._clip)
return obs
input_data = np.clip(input_data, -self._clip, self._clip)
return input_data

def update(self, input_data):
self._rms.update(input_data)

def update(self, obs):
self.obs_rms.update(obs)
def state_dict(self):
return self._rms.state_dict()

def load_state_dict(self, state_dict):
self._rms.load_state_dict(state_dict)

class RewardNormalizationFn(BaseNormalizationFn):
"""Implements a normalization function. Transforms output by
normalising the input data by the running :obj:`mean` and
:obj:`std`, and clipping the normalised data on :obj:`clip`

class RewardNormalizer(Normalizer):
"""Normalizes and clips rewards from the environment. Applies a discount-based
scaling scheme, where the rewards are divided by the standard deviation of a
rolling discounted sum of the rewards. The scaled rewards are then clipped within
specified range.
"""

def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf):
@@ -106,48 +136,41 @@ def __init__(self, gamma: float, epsilon: float = 1e-4, clip: np.float32 = np.inf):
clip (np.float32): The clip value for the normalised data.
"""
super().__init__()
self.return_rms = MeanStd(epsilon, ())
self._return_rms = MeanStd(epsilon, ())
self._epsilon = epsilon
self._clip = clip
self._gamma = gamma
self._returns = np.zeros(1)

def __call__(self, rew):
rew = np.array([rew])
rew = (rew / np.sqrt(self.return_rms.var + self._epsilon))[0]
rew = (rew / np.sqrt(self._return_rms.var + self._epsilon))[0]
if self._clip is not None:
rew = np.clip(rew, -self._clip, self._clip)
return rew

def update(self, rew, done):
self._returns = self._returns * self._gamma + rew
self.return_rms.update(self._returns)
self._return_rms.update(self._returns)
self._returns *= 1 - done

def state_dict(self):
state_dict = self._return_rms.state_dict()
state_dict["returns"] = self._returns
return state_dict

class NormalizationFn(Registrable):
"""A wrapper for callables that produce normalization functions.

These wrapped callables can be partially initialized through configuration
files or command line arguments.
"""

@classmethod
def type_name(cls):
"""
Returns:
"norm_fn"
"""
return "norm_fn"
def load_state_dict(self, state_dict):
self._returns = state_dict["returns"]
state_dict.pop("returns")
self._return_rms.load_state_dict(state_dict)


registry.register_all(
NormalizationFn,
Normalizer,
{
"BaseNormalization": BaseNormalizationFn,
"RewardNormalization": RewardNormalizationFn,
"ObservationNormalization": ObservationNormalizationFn,
"RewardNormalizer": RewardNormalizer,
"MovingAvgNormalizer": MovingAvgNormalizer,
},
)

get_norm_fn = getattr(registry, f"get_{NormalizationFn.type_name()}")
get_norm_fn = getattr(registry, f"get_{Normalizer.type_name()}")
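For context on what the two registered normalizers do, a standalone sketch of the same pattern is below: observation normalization maintains running mean/variance statistics (updated with the parallel-moments rule that `update_mean_var_count_from_moments` implements) and clips the standardized values, while reward normalization divides rewards by the standard deviation of a rolling discounted return. Class and variable names here are illustrative, not the RLHive API:

```python
import numpy as np


class RunningMeanStd:
    """Running mean/variance via the parallel-moments (Chan et al.) update."""

    def __init__(self, epsilon: float = 1e-4, shape: tuple = ()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon

    def update(self, x) -> None:
        # Reshape input to (batch, *shape) and fold its moments into the running stats.
        x = np.asarray(x, dtype=np.float64).reshape(-1, *self.mean.shape)
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta**2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total


# Observation side: update the running stats, then standardize and clip.
obs_rms = RunningMeanStd(shape=(4,))
obs = np.random.randn(4)
obs_rms.update(obs)
norm_obs = np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + 1e-4), -10, 10)

# Reward side: scale by the std of a rolling discounted return.
ret_rms, returns, gamma = RunningMeanStd(shape=()), 0.0, 0.99
reward, done = 1.0, False
returns = returns * gamma + reward
ret_rms.update(np.array([returns]))
norm_reward = np.clip(reward / np.sqrt(ret_rms.var + 1e-4), -10, 10)
returns *= 1.0 - float(done)
```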
8 changes: 4 additions & 4 deletions hive/configs/mujoco/ppo.yml
@@ -38,12 +38,12 @@ kwargs:
name: 'MLPNetwork'
kwargs:
hidden_units: [64, 64]
observation_normalization_fn:
name: 'ObservationNormalization'
observation_normalizer:
name: 'MovingAvgNormalizer'
kwargs:
clip: 10
reward_normalization_fn:
name: 'RewardNormalization'
reward_normalizer:
name: 'RewardNormalizer'
kwargs:
clip: 10
replay_buffer:
2 changes: 1 addition & 1 deletion hive/envs/gym_env.py
@@ -46,7 +46,7 @@ def create_env(self, env_name, env_wrappers, **kwargs):
Args:
env_name (str): Name of the environment
"""
self._env = gym.make(env_name)
self._env = gym.make(env_name, **kwargs)

if env_wrappers is not None:
self._env = apply_wrappers(self._env, env_wrappers)
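The one-line change above forwards any extra keyword arguments from the environment config to `gym.make`, so env-specific options can be set from configuration. A hedged example of what this enables (assuming Gym >= 0.26 or Gymnasium, where `render_mode` is accepted by `gym.make`):

```python
import gym  # or: import gymnasium as gym

# Keyword arguments beyond the environment id are forwarded to the
# environment constructor by gym.make.
env = gym.make("CartPole-v1", render_mode="rgb_array")
env.reset()
```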