-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Polish repository. Add IDQL pipelines. Polish all pipelines.
- Loading branch information
Showing
195 changed files
with
18,858 additions
and
3,580 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import Dict, Union | ||
|
||
import numpy as np | ||
import torch | ||
from easydict import EasyDict | ||
|
||
from grl.agents import obs_transform, action_transform | ||
|
||
|
||
class IDQLAgent:
    """
    Overview:
        The IDQL agent. Holds a model on the configured device and uses its
        ``"IDQLPolicy"`` entry to produce an action for a single observation.
    Interface:
        ``__init__``, ``act``
    """

    def __init__(
        self,
        config: EasyDict,
        model: Union[torch.nn.Module, torch.nn.ModuleDict],
    ):
        """
        Overview:
            Initialize the agent.
        Arguments:
            config (:obj:`EasyDict`): The configuration. Its ``device`` attribute selects \
                where the model and the observations are placed.
            model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. It is moved to \
                ``config.device`` and must be indexable with the key ``"IDQLPolicy"`` when \
                ``act`` is called.
        """

        self.config = config
        self.device = config.device
        self.model = model.to(self.device)

    def act(
        self,
        obs: Union[np.ndarray, torch.Tensor, Dict],
        return_as_torch_tensor: bool = False,
    ) -> Union[np.ndarray, torch.Tensor, Dict]:
        """
        Overview:
            Given an observation, return an action.
        Arguments:
            obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation.
            return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor.
        Returns:
            action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
        """

        obs = obs_transform(obs, self.device)

        with torch.no_grad():

            # ---------------------------------------
            # Customized inference code ↓
            # ---------------------------------------

            # NOTE(review): ``unsqueeze`` requires the transformed observation to be
            # tensor-like; a plain ``dict`` observation would fail here even though the
            # annotation allows ``Dict`` - confirm against ``obs_transform``.
            # A leading dimension of size 1 is added before querying the policy and
            # removed again from the policy output.
            obs = obs.unsqueeze(0)
            action = (
                self.model["IDQLPolicy"]
                .get_action(
                    state=obs,
                )
                .squeeze(0)
                .detach()
                .cpu()
                .numpy()
            )

        # The policy output is a NumPy array at this point; ``action_transform``
        # produces the final return value according to ``return_as_torch_tensor``.
        action = action_transform(action, return_as_torch_tensor)

        return action
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .base import BaseAlgorithm | ||
from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy | ||
from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy | ||
from .gmpo import GMPOAlgorithm, GMPOCritic, GMPOPolicy | ||
from .gmpg import GMPGAlgorithm, GMPGCritic, GMPGPolicy | ||
from .idql import IDQLAlgorithm, IDQLCritic, IDQLPolicy | ||
from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy | ||
from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy |
Oops, something went wrong.