Commit

Polish APIs and documents.
zjowowen committed Jun 16, 2024
1 parent 3eab74d commit 7dee5ad
Showing 37 changed files with 578 additions and 350 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -86,8 +86,9 @@ Here is an example of how to train a diffusion model for Q-guided policy optimiz

Install the required dependencies:
```bash
pip install gym[box2d]==0.23.1
pip install 'gym[box2d]==0.23.1'
```
(The gym version can be from 0.23 to 0.25 for box2d environments, but it is recommended to use 0.23.1 for compatibility with D4RL.)
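As a quick optional sanity check, you can confirm which gym release is actually installed:
```python
import gym

# The pinned release should be reported here; other packages may silently
# upgrade or downgrade gym, so it is worth checking before training.
print(gym.__version__)
```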

Download dataset from [here](https://drive.google.com/file/d/1YnT-Oeu9LPKuS_ZqNc5kol_pMlJ1DwyG/view?usp=drive_link) and save it as `data.npz` in the current directory.
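Once downloaded, here is a minimal sketch for inspecting the file (assuming it is a standard NumPy `.npz` archive; no particular key names are assumed):
```python
import numpy as np

# List every array stored in the archive together with its shape.
data = np.load("data.npz")
for key in data.files:
    print(key, data[key].shape)
```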

3 changes: 2 additions & 1 deletion README.zh.md
@@ -83,8 +83,9 @@ docker run -it --rm --gpus all opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime

Install the required dependencies:
```bash
pip install gym[box2d]==0.23.1
pip install 'gym[box2d]==0.23.1'
```
(The gym version here can range from 0.23 to 0.25, but 0.23.1 is recommended for compatibility with D4RL.)

Download the dataset from [here](https://drive.google.com/file/d/1YnT-Oeu9LPKuS_ZqNc5kol_pMlJ1DwyG/view?usp=drive_link), place it in the working directory, and name it `data.npz`.

12 changes: 10 additions & 2 deletions docs/source/tutorials/quick_start/index.rst
@@ -1,11 +1,19 @@
Quick Start
===========

GenerativeRL provides a simple and flexible interface for training and deploying reinforcement learning agents powered by generative models. Here's an example of how to use the library to train a Q-guided policy optimization (QGPO) agent on the HalfCheetah environment and deploy it for evaluation.
Generative model in GenerativeRL
--------------------------------

GenerativeRL supports easy-to-use APIs for training and deploying generative models.
We provide a simple example of how to train a diffusion model on the Swiss Roll dataset in `Colab <https://colab.research.google.com/drive/18yHUAmcMh_7xq2U6TBCtcLKX2y4YvNyk?usp=drive_link>`_.

More usage examples can be found in the folder ``grl_pipelines/tutorials/``.
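As a rough illustration, the Swiss Roll data used in that example can be generated with scikit-learn. The sketch below only prepares the data; the actual GenerativeRL training call is shown in the Colab notebook and in the tutorial scripts.

.. code-block:: python

    from sklearn.datasets import make_swiss_roll

    # Sample points from the Swiss Roll manifold, keep the two informative
    # dimensions, and normalize them before fitting a generative model.
    data, _ = make_swiss_roll(n_samples=10000, noise=0.5)
    data = data[:, [0, 2]]
    data = (data - data.mean(axis=0)) / data.std(axis=0)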

Code Example
Reinforcement Learning
----------------------

GenerativeRL provides a simple and flexible interface for training and deploying reinforcement learning agents powered by generative models. Here's an example of how to use the library to train a Q-guided policy optimization (QGPO) agent on the HalfCheetah environment and deploy it for evaluation.

.. code-block:: python

    from grl_pipelines.diffusion_model.configurations.halfcheetah_qgpo import config
38 changes: 38 additions & 0 deletions grl/agents/__init__.py
@@ -1,3 +1,41 @@
from typing import Dict
import torch
import numpy as np


def obs_transform(obs, device):
    """Convert an observation to a float tensor (or dict of float tensors) on the given device."""

    if isinstance(obs, np.ndarray):
        obs = torch.from_numpy(obs).float().to(device)
    elif isinstance(obs, Dict):
        obs = {k: torch.from_numpy(v).float().to(device) for k, v in obs.items()}
    elif isinstance(obs, torch.Tensor):
        obs = obs.float().to(device)
    else:
        raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray")

    return obs


def action_transform(action, return_as_torch_tensor: bool = False):
    """Move an action to the CPU, returning either torch tensors or numpy arrays."""
    if isinstance(action, Dict):
        if return_as_torch_tensor:
            action = {k: v.cpu() for k, v in action.items()}
        else:
            action = {k: v.cpu().numpy() for k, v in action.items()}
    elif isinstance(action, torch.Tensor):
        if return_as_torch_tensor:
            action = action.cpu()
        else:
            action = action.numpy()
    elif isinstance(action, np.ndarray):
        pass
    else:
        raise ValueError("action must be a dict, torch.Tensor, or np.ndarray")

    return action


# Agent implementations are imported after the helpers above so that they can
# reuse them (e.g. ``from grl.agents import obs_transform, action_transform``).
from .base import BaseAgent
from .qgpo import QGPOAgent
from .gm import GPAgent
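A short usage sketch of the two helpers above (the observation and action shapes are illustrative only; they are not taken from this commit):
```python
import numpy as np
import torch

from grl.agents import obs_transform, action_transform

obs = np.zeros((17,), dtype=np.float32)        # e.g. a HalfCheetah observation
obs = obs_transform(obs, torch.device("cpu"))  # float32 tensor on the target device

action = torch.zeros(6)                        # dummy action from a policy
action_np = action_transform(action)           # numpy array by default
action_pt = action_transform(action, return_as_torch_tensor=True)  # keep as a tensor
```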
30 changes: 5 additions & 25 deletions grl/agents/base.py
@@ -1,9 +1,11 @@
from typing import Any, Dict, List, Tuple, Union
from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class BaseAgent:

@@ -39,16 +41,7 @@ def act(
action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
"""

if isinstance(obs, np.ndarray):
obs = torch.from_numpy(obs).float().to(self.device)
elif isinstance(obs, Dict):
obs = {
k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items()
}
elif isinstance(obs, torch.Tensor):
obs = obs.float().to(self.device)
else:
raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray")
obs = obs_transform(obs, self.device)

with torch.no_grad():

@@ -62,19 +55,6 @@ def act(
# Customized inference code ↑
# ---------------------------------------

if isinstance(action, Dict):
if return_as_torch_tensor:
action = {k: v.cpu() for k, v in action.items()}
else:
action = {k: v.cpu().numpy() for k, v in action.items()}
elif isinstance(action, torch.Tensor):
if return_as_torch_tensor:
action = action.cpu()
else:
action = action.numpy()
elif isinstance(action, np.ndarray):
pass
else:
raise ValueError("action must be a dict, torch.Tensor, or np.ndarray")
action = action_transform(action, return_as_torch_tensor)

return action
31 changes: 6 additions & 25 deletions grl/agents/gm.py
@@ -1,14 +1,17 @@
from typing import Any, Dict, List, Tuple, Union
from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class GPAgent:
"""
Overview:
The agent trained for generative policies.
This class is designed to be used with the ``GMPGAlgorithm`` and ``GMPOAlgorithm``.
Interface:
``__init__``, ``act``
"""
@@ -45,16 +48,7 @@ def act(
action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
"""

if isinstance(obs, np.ndarray):
obs = torch.from_numpy(obs).float().to(self.device)
elif isinstance(obs, Dict):
obs = {
k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items()
}
elif isinstance(obs, torch.Tensor):
obs = obs.float().to(self.device)
else:
raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray")
obs = obs_transform(obs, self.device)

with torch.no_grad():

@@ -82,19 +76,6 @@
# Customized inference code ↑
# ---------------------------------------

if isinstance(action, Dict):
if return_as_torch_tensor:
action = {k: v.cpu() for k, v in action.items()}
else:
action = {k: v.cpu().numpy() for k, v in action.items()}
elif isinstance(action, torch.Tensor):
if return_as_torch_tensor:
action = action.cpu()
else:
action = action.numpy()
elif isinstance(action, np.ndarray):
pass
else:
raise ValueError("action must be a dict, torch.Tensor, or np.ndarray")
action = action_transform(action, return_as_torch_tensor)

return action
33 changes: 8 additions & 25 deletions grl/agents/gp.py
@@ -1,14 +1,19 @@
from typing import Any, Dict, List, Tuple, Union
from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class GPAgent:
"""
Overview:
The agent trained for generative policies.
This class is designed to be used with the ``GPAlgorithm``.
``GPAlgorithm`` is an experimental algorithm pipeline that is not included in the official release; it has been split into two parts: ``GMPGAlgorithm`` and ``GMPOAlgorithm``.
This agent will be deprecated in the future.
Interface:
``__init__``, ``act``
"""
@@ -50,16 +55,7 @@ def act(
action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
"""

if isinstance(obs, np.ndarray):
obs = torch.from_numpy(obs).float().to(self.device)
elif isinstance(obs, Dict):
obs = {
k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items()
}
elif isinstance(obs, torch.Tensor):
obs = obs.float().to(self.device)
else:
raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray")
obs = obs_transform(obs, self.device)

with torch.no_grad():

@@ -91,19 +87,6 @@
# Customized inference code ↑
# ---------------------------------------

if isinstance(action, Dict):
if return_as_torch_tensor:
action = {k: v.cpu() for k, v in action.items()}
else:
action = {k: v.cpu().numpy() for k, v in action.items()}
elif isinstance(action, torch.Tensor):
if return_as_torch_tensor:
action = action.cpu()
else:
action = action.numpy()
elif isinstance(action, np.ndarray):
pass
else:
raise ValueError("action must be a dict, torch.Tensor, or np.ndarray")
action = action_transform(action, return_as_torch_tensor)

return action