Skip to content

Commit

Permalink
Polish repository. Add IDQL pipelines. Polish all pipelines.
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Oct 30, 2024
1 parent 7cf37d8 commit 5867ed7
Show file tree
Hide file tree
Showing 195 changed files with 18,858 additions and 3,580 deletions.
4 changes: 4 additions & 0 deletions grl/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict
import torch
import numpy as np
from tensordict import TensorDict


def obs_transform(obs, device):
Expand All @@ -11,6 +12,8 @@ def obs_transform(obs, device):
obs = {k: torch.from_numpy(v).float().to(device) for k, v in obs.items()}
elif isinstance(obs, torch.Tensor):
obs = obs.float().to(device)
elif isinstance(obs, TensorDict):
obs = obs.to(device)
else:
raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray")

Expand Down Expand Up @@ -40,3 +43,4 @@ def action_transform(action, return_as_torch_tensor: bool = False):
from .qgpo import QGPOAgent
from .srpo import SRPOAgent
from .gm import GPAgent
from .idql import IDQLAgent
5 changes: 5 additions & 0 deletions grl/agents/gm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def act(
if self.config.t_span is not None
else None
),
solver_config=(
self.config.solver_config
if hasattr(self.config, "solver_config")
else None
),
)
.squeeze(0)
.cpu()
Expand Down
72 changes: 72 additions & 0 deletions grl/agents/idql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class IDQLAgent:
    """
    Overview:
        Inference-time wrapper around an IDQL model. It keeps the model on the
        configured device and maps one observation to one action.
    Interface:
        ``__init__``, ``act``
    """

    def __init__(
        self,
        config: EasyDict,
        model: Union[torch.nn.Module, torch.nn.ModuleDict],
    ):
        """
        Overview:
            Initialize the agent.
        Arguments:
            config (:obj:`EasyDict`): The configuration. Only ``config.device`` is read here.
            model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. It is moved to
                ``config.device`` and must be indexable with the key ``"IDQLPolicy"`` (see ``act``).
        """

        self.config = config
        self.device = config.device
        self.model = model.to(self.device)

    def act(
        self,
        obs: Union[np.ndarray, torch.Tensor, Dict],
        return_as_torch_tensor: bool = False,
    ) -> Union[np.ndarray, torch.Tensor, Dict]:
        """
        Overview:
            Given a single (unbatched) observation, return a single action.
        Arguments:
            obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation.
            return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor.
                This flag is forwarded unchanged to ``action_transform``.
        Returns:
            action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
        """

        obs = obs_transform(obs, self.device)

        with torch.no_grad():

            # ---------------------------------------
            # Customized inference code ↓
            # ---------------------------------------

            # Add a leading batch axis of size 1 before calling the policy.
            # A plain ``dict`` has no ``unsqueeze`` method, so its entries are
            # expanded one by one; every other supported container is expected
            # to provide ``unsqueeze`` itself.
            # NOTE(review): assumes dict values are tensors after ``obs_transform`` - confirm.
            if isinstance(obs, dict):
                obs = {key: value.unsqueeze(0) for key, value in obs.items()}
            else:
                obs = obs.unsqueeze(0)

            # Query the policy, then drop the batch axis again and hand the
            # result to ``action_transform`` as a NumPy array.
            action = (
                self.model["IDQLPolicy"]
                .get_action(
                    state=obs,
                )
                .squeeze(0)
                .detach()
                .cpu()
                .numpy()
            )

            # ---------------------------------------
            # Customized inference code ↑
            # ---------------------------------------

        action = action_transform(action, return_as_torch_tensor)

        return action
9 changes: 5 additions & 4 deletions grl/agents/srpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class SRPOAgent:
"""
Overview:
The QGPO agent.
The SRPO agent.
Interface:
``__init__``, ``action``
"""
Expand Down Expand Up @@ -54,9 +54,10 @@ def act(
# ---------------------------------------
# Customized inference code ↓
# ---------------------------------------

action = self.model(obs)

obs = obs.unsqueeze(0)
action = (
self.model["SRPOPolicy"].policy(obs).squeeze(0).detach().cpu().numpy()
)
# ---------------------------------------
# Customized inference code ↑
# ---------------------------------------
Expand Down
5 changes: 3 additions & 2 deletions grl/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base import BaseAlgorithm
from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy
from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy
from .gmpo import GMPOAlgorithm, GMPOCritic, GMPOPolicy
from .gmpg import GMPGAlgorithm, GMPGCritic, GMPGPolicy
from .idql import IDQLAlgorithm, IDQLCritic, IDQLPolicy
from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy
from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy
Loading

0 comments on commit 5867ed7

Please sign in to comment.