From 7dee5adf0dd4552bf5032e0abd5b1501022425ad Mon Sep 17 00:00:00 2001 From: zjowowen Date: Sun, 16 Jun 2024 18:49:59 +0800 Subject: [PATCH] Polish APIs and documents. --- README.md | 3 +- README.zh.md | 3 +- docs/source/tutorials/quick_start/index.rst | 12 +- grl/agents/__init__.py | 38 +++++ grl/agents/base.py | 30 +--- grl/agents/gm.py | 31 +--- grl/agents/gp.py | 33 ++--- grl/agents/qgpo.py | 136 +----------------- grl/agents/srpo.py | 30 +--- grl/algorithms/gmpg.py | 6 + grl/algorithms/gmpo.py | 18 +++ grl/algorithms/gp.py | 8 ++ grl/algorithms/qgpo.py | 20 ++- grl/algorithms/srpo.py | 98 +++++++++++-- grl/datasets/minari_dataset.py | 118 +++++++-------- ...ptimal_transport_conditional_flow_model.py | 8 +- grl/generative_models/diffusion_process.py | 2 - grl/generative_models/sro.py | 2 +- grl/generative_models/stochastic_process.py | 8 +- grl/numerical_methods/monte_carlo.py | 1 - grl/numerical_methods/ode.py | 2 - grl/numerical_methods/sde.py | 3 - .../simulators/gym_env_simulator.py | 3 +- .../value_network/one_shot_value_function.py | 4 +- grl/rl_modules/value_network/q_network.py | 12 ++ grl/rl_modules/value_network/value_network.py | 12 ++ grl/unittest/agents/functions.py | 90 ++++++++++++ .../tutorials/customized_modules.py | 30 ++++ .../tutorials/dict_tensor_ode.py | 24 ++++ .../swiss_roll/swiss_roll_diffusion.py | 5 +- .../swiss_roll/swiss_roll_dpmsolver.py | 17 ++- .../swiss_roll/swiss_roll_energy_condition.py | 65 ++++++++- .../swiss_roll/swiss_roll_icfm.py | 5 +- .../swiss_roll/swiss_roll_likelihood.py | 39 ++++- .../swiss_roll/swiss_roll_otcfm.py | 5 +- .../swiss_roll/swiss_roll_sdesolver.py | 1 - .../swiss_roll/swiss_roll_sf2m.py | 6 +- 37 files changed, 578 insertions(+), 350 deletions(-) create mode 100644 grl/unittest/agents/functions.py rename grl/test/test_customized_modules.py => grl_pipelines/tutorials/customized_modules.py (86%) rename grl/test/test_dict_tensor_ode.py => grl_pipelines/tutorials/dict_tensor_ode.py (88%) rename grl/test/test_swiss_roll.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py (96%) rename grl/test/test_swiss_roll_dpmsolver.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py (92%) rename grl/test/test_swiss_roll_energy_condition.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py (90%) rename grl/test/test_swiss_roll_icfm.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py (96%) rename grl/test/test_swiss_roll_likelihood.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py (86%) rename grl/test/test_swiss_roll_otcfm.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py (96%) rename grl/test/test_swiss_roll_sdesolver.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py (99%) rename grl/test/test_swiss_roll_SchrodingerBridge.py => grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py (97%) diff --git a/README.md b/README.md index e0217d9..3bb1ea4 100644 --- a/README.md +++ b/README.md @@ -86,8 +86,9 @@ Here is an example of how to train a diffusion model for Q-guided policy optimiz Install the required dependencies: ```bash -pip install gym[box2d]==0.23.1 +pip install 'gym[box2d]==0.23.1' ``` +(The gym version can be from 0.23 to 0.25 for box2d environments, but it is recommended to use 0.23.1 for compatibility with D4RL.) 
Download dataset from [here](https://drive.google.com/file/d/1YnT-Oeu9LPKuS_ZqNc5kol_pMlJ1DwyG/view?usp=drive_link) and save it as `data.npz` in the current directory.
 
diff --git a/README.zh.md b/README.zh.md
index 22c9abf..66ef425 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -83,8 +83,9 @@ docker run -it --rm --gpus all opendilab/grl:torch2.3.0-cuda12.1-cudnn8-runtime
 安装所需依赖:
 
 ```bash
-pip install gym[box2d]==0.23.1
+pip install 'gym[box2d]==0.23.1'
 ```
+(此处的 gym 版本可以为'0.23~0.25', 但为了同时兼容 D4RL 推荐使用 '0.23.1'。)
 
 数据集可以从 [这里](https://drive.google.com/file/d/1YnT-Oeu9LPKuS_ZqNc5kol_pMlJ1DwyG/view?usp=drive_link) 下载,请将其置于工作路径下,并命名为 `data.npz`。
 
diff --git a/docs/source/tutorials/quick_start/index.rst b/docs/source/tutorials/quick_start/index.rst
index b2cb00e..6ef1203 100644
--- a/docs/source/tutorials/quick_start/index.rst
+++ b/docs/source/tutorials/quick_start/index.rst
@@ -1,11 +1,19 @@
 Quick Start
 ===========
 
-GenerativeRL provides a simple and flexible interface for training and deploying reinforcement learning agents powered by generative models. Here's an example of how to use the library to train a Q-guided policy optimization (QGPO) agent on the HalfCheetah environment and deploy it for evaluation.
+Generative Models in GenerativeRL
+----------------------------------------
+
+GenerativeRL offers easy-to-use APIs for training and deploying generative models.
+We provide a simple example of how to train a diffusion model on the Swiss Roll dataset in `Colab <https://colab.research.google.com/drive/18yHUAmcMh_7xq2U6TBCtcLKX2y4YvNyk?usp=drive_link>`_.
+
+More usage examples can be found in the ``grl_pipelines/tutorials/`` folder.
 
-Code Example
+Reinforcement Learning
 ------------
 
+GenerativeRL provides a simple and flexible interface for training and deploying reinforcement learning agents powered by generative models. Here's an example of how to use the library to train a Q-guided policy optimization (QGPO) agent on the HalfCheetah environment and deploy it for evaluation.
+
 .. 
code-block:: python from grl_pipelines.diffusion_model.configurations.halfcheetah_qgpo import config diff --git a/grl/agents/__init__.py b/grl/agents/__init__.py index 6581f46..6960d75 100644 --- a/grl/agents/__init__.py +++ b/grl/agents/__init__.py @@ -1,3 +1,41 @@ +from typing import Dict +import torch +import numpy as np + + +def obs_transform(obs, device): + + if isinstance(obs, np.ndarray): + obs = torch.from_numpy(obs).float().to(device) + elif isinstance(obs, Dict): + obs = {k: torch.from_numpy(v).float().to(device) for k, v in obs.items()} + elif isinstance(obs, torch.Tensor): + obs = obs.float().to(device) + else: + raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + + return obs + + +def action_transform(action, return_as_torch_tensor: bool = False): + if isinstance(action, Dict): + if return_as_torch_tensor: + action = {k: v.cpu() for k, v in action.items()} + else: + action = {k: v.cpu().numpy() for k, v in action.items()} + elif isinstance(action, torch.Tensor): + if return_as_torch_tensor: + action = action.cpu() + else: + action = action.numpy() + elif isinstance(action, np.ndarray): + pass + else: + raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + + return action + + from .base import BaseAgent from .qgpo import QGPOAgent from .gm import GPAgent diff --git a/grl/agents/base.py b/grl/agents/base.py index 0c76f0e..7689fff 100644 --- a/grl/agents/base.py +++ b/grl/agents/base.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, Union import numpy as np import torch from easydict import EasyDict +from grl.agents import obs_transform, action_transform + class BaseAgent: @@ -39,16 +41,7 @@ def act( action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + obs = obs_transform(obs, self.device) with torch.no_grad(): @@ -62,19 +55,6 @@ def act( # Customized inference code ↑ # --------------------------------------- - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + action = action_transform(action, return_as_torch_tensor) return action diff --git a/grl/agents/gm.py b/grl/agents/gm.py index cb6910e..e900f10 100644 --- a/grl/agents/gm.py +++ b/grl/agents/gm.py @@ -1,14 +1,17 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, Union import numpy as np import torch from easydict import EasyDict +from grl.agents import obs_transform, action_transform + class GPAgent: """ Overview: The agent trained for generative policies. + This class is designed to be used with the ``GMPGAlgorithm`` and ``GMPOAlgorithm``. Interface: ``__init__``, ``action`` """ @@ -45,16 +48,7 @@ def act( action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 
""" - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + obs = obs_transform(obs, self.device) with torch.no_grad(): @@ -82,19 +76,6 @@ def act( # Customized inference code ↑ # --------------------------------------- - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + action = action_transform(action, return_as_torch_tensor) return action diff --git a/grl/agents/gp.py b/grl/agents/gp.py index c5c15d2..bf04182 100644 --- a/grl/agents/gp.py +++ b/grl/agents/gp.py @@ -1,14 +1,19 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, Union import numpy as np import torch from easydict import EasyDict +from grl.agents import obs_transform, action_transform + class GPAgent: """ Overview: The agent trained for generative policies. + This class is designed to be used with the ``GPAlgorithm``. + ``GPAlgorithm`` is an experimental algorithm pipeline that is not included in the official release, which is divided into two parts: ``GMPGAlgorithm`` and ``GMPOAlgorithm``. + And this agent is going to be deprecated in the future. Interface: ``__init__``, ``action`` """ @@ -50,16 +55,7 @@ def act( action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + obs = obs_transform(obs, self.device) with torch.no_grad(): @@ -91,19 +87,6 @@ def act( # Customized inference code ↑ # --------------------------------------- - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + action = action_transform(action, return_as_torch_tensor) return action diff --git a/grl/agents/qgpo.py b/grl/agents/qgpo.py index e4f8ed5..9deb8a2 100644 --- a/grl/agents/qgpo.py +++ b/grl/agents/qgpo.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, Union import numpy as np import torch from easydict import EasyDict +from grl.agents import obs_transform, action_transform + class QGPOAgent: """ @@ -50,16 +52,7 @@ def act( action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 
""" - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + obs = obs_transform(obs, self.device) with torch.no_grad(): @@ -89,125 +82,6 @@ def act( # Customized inference code ↑ # --------------------------------------- - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") - - return action - - -class QGPOISAgent: - """ - Overview: - The QGPO agent trained by importance sampling. - Interface: - ``__init__``, ``action`` - """ - - def __init__( - self, - config: EasyDict, - model: Union[torch.nn.Module, torch.nn.ModuleDict], - ): - """ - Overview: - Initialize the agent. - Arguments: - config (:obj:`EasyDict`): The configuration. - model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. - """ - - self.config = config - self.device = config.device - self.model = model.to(self.device) - - if hasattr(self.config, "guidance_scale"): - self.guidance_scale = self.config.guidance_scale - else: - self.guidance_scale = 1.0 - - def act( - self, - obs: Union[np.ndarray, torch.Tensor, Dict], - return_as_torch_tensor: bool = False, - ) -> Union[np.ndarray, torch.Tensor, Dict]: - """ - Overview: - Given an observation, return an action. - Arguments: - obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. - return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. - Returns: - action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 
- """ - - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") - - with torch.no_grad(): - - # --------------------------------------- - # Customized inference code ↓ - # --------------------------------------- - - obs = obs.unsqueeze(0) - action = ( - self.model["GuidedPolicy"] - .sample( - base_model=self.model["QGPOPolicy"].diffusion_model.model, - guided_model=self.model[ - "QGPOPolicy" - ].diffusion_model_important_sampling.model, - state=obs, - t_span=( - torch.linspace(0.0, 1.0, self.config.t_span).to(obs.device) - if self.config.t_span is not None - else None - ), - guidance_scale=self.guidance_scale, - ) - .squeeze(0) - .cpu() - .detach() - .numpy() - ) - - # --------------------------------------- - # Customized inference code ↑ - # --------------------------------------- - - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + action = action_transform(action, return_as_torch_tensor) return action diff --git a/grl/agents/srpo.py b/grl/agents/srpo.py index ea45ead..bf5d121 100644 --- a/grl/agents/srpo.py +++ b/grl/agents/srpo.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, Union import numpy as np import torch from easydict import EasyDict +from grl.agents import obs_transform, action_transform + class SRPOAgent: """ @@ -45,16 +47,7 @@ def act( action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. """ - if isinstance(obs, np.ndarray): - obs = torch.from_numpy(obs).float().to(self.device) - elif isinstance(obs, Dict): - obs = { - k: torch.from_numpy(v).float().to(self.device) for k, v in obs.items() - } - elif isinstance(obs, torch.Tensor): - obs = obs.float().to(self.device) - else: - raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") + obs = obs_transform(obs, self.device) with torch.no_grad(): @@ -68,19 +61,6 @@ def act( # Customized inference code ↑ # --------------------------------------- - if isinstance(action, Dict): - if return_as_torch_tensor: - action = {k: v.cpu() for k, v in action.items()} - else: - action = {k: v.cpu().numpy() for k, v in action.items()} - elif isinstance(action, torch.Tensor): - if return_as_torch_tensor: - action = action.cpu() - else: - action = action.numpy() - elif isinstance(action, np.ndarray): - pass - else: - raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") + action = action_transform(action, return_as_torch_tensor) return action diff --git a/grl/algorithms/gmpg.py b/grl/algorithms/gmpg.py index 7ee1c10..9f79c57 100644 --- a/grl/algorithms/gmpg.py +++ b/grl/algorithms/gmpg.py @@ -492,6 +492,12 @@ def policy_gradient_loss_by_REINFORCE_softmax( class GMPGAlgorithm: + """ + Overview: + The Generative Model Policy Gradient(GMPG) algorithm. 
+ Interfaces: + ``__init__``, ``train``, ``deploy`` + """ def __init__( self, diff --git a/grl/algorithms/gmpo.py b/grl/algorithms/gmpo.py index 1e53800..fa640c5 100644 --- a/grl/algorithms/gmpo.py +++ b/grl/algorithms/gmpo.py @@ -113,8 +113,20 @@ def iql_q_loss(self, state, action, reward, done, next_v, discount): class GMPOPolicy(nn.Module): + """ + Overview: + GMPO policy network for GMPO algorithm, which includes the base model (optinal), the guided model and the critic. + Interfaces: + ``__init__``, ``forward``, ``sample``, ``compute_q``, ``behaviour_policy_loss``, ``policy_optimization_loss_by_advantage_weighted_regression``, ``policy_optimization_loss_by_advantage_weighted_regression_softmax`` + """ def __init__(self, config: EasyDict): + """ + Overview: + Initialize the GMPO policy network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ super().__init__() self.config = config self.device = config.device @@ -392,6 +404,12 @@ def policy_optimization_loss_by_advantage_weighted_regression_softmax( class GMPOAlgorithm: + """ + Overview: + The Generative Model Policy Optimization(GMPO) algorithm. + Interfaces: + ``__init__``, ``train``, ``deploy`` + """ def __init__( self, diff --git a/grl/algorithms/gp.py b/grl/algorithms/gp.py index 3f65201..fd35265 100644 --- a/grl/algorithms/gp.py +++ b/grl/algorithms/gp.py @@ -902,6 +902,14 @@ def q_loss( class GPAlgorithm: + """ + Overview: + The algorithm pipeline of the Generative Policy algorithm. + ``GPAlgorithm`` is an experimental algorithm pipeline that is not included in the official release, which is divided into two parts: ``GMPGAlgorithm`` and ``GMPOAlgorithm`` for clarity. + And this agent is going to be deprecated in the future. + Interfaces: + ``__init__``, ``train``, ``deploy`` + """ def __init__( self, diff --git a/grl/algorithms/qgpo.py b/grl/algorithms/qgpo.py index 6e61739..86e8c1b 100644 --- a/grl/algorithms/qgpo.py +++ b/grl/algorithms/qgpo.py @@ -3,7 +3,7 @@ ############################################################# import copy -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import List, Tuple, Union import numpy as np import torch @@ -126,6 +126,12 @@ def q_loss( class QGPOPolicy(nn.Module): + """ + Overview: + QGPO policy network. + Interfaces: + ``__init__``, ``forward``, ``sample``, ``behaviour_policy_sample``, ``compute_q``, ``behaviour_policy_loss``, ``energy_guidance_loss``, ``q_loss`` + """ def __init__(self, config: EasyDict): super().__init__() @@ -279,6 +285,12 @@ def q_loss( class QGPOAlgorithm: + """ + Overview: + Q-guided policy optimization (QGPO) algorithm, which is an offline reinforcement learning algorithm that uses energy-based diffusion model for policy modeling. + Interfaces: + ``__init__``, ``train``, ``deploy`` + """ def __init__( self, @@ -593,6 +605,12 @@ def policy(obs: np.ndarray) -> np.ndarray: wandb.finish() def deploy(self, config: EasyDict = None) -> QGPOAgent: + """ + Overview: + Deploy the model using the given configuration. + Arguments: + config (:obj:`EasyDict`): The deployment configuration. + """ if config is not None: config = merge_two_dicts_into_newone(self.config.deploy, config) diff --git a/grl/algorithms/srpo.py b/grl/algorithms/srpo.py index f39c900..c278480 100644 --- a/grl/algorithms/srpo.py +++ b/grl/algorithms/srpo.py @@ -27,7 +27,14 @@ class Dirac_Policy(nn.Module): - def __init__(self, action_dim, state_dim, layer=2): + """ + Overview: + The deterministic policy network used in SRPO algorithm. 
+ Interfaces: + ``__init__``, ``forward``, ``select_actions`` + """ + + def __init__(self, action_dim: int, state_dim: int, layer: int = 2): super().__init__() self.net = MultiLayerPerceptron( hidden_sizes=[state_dim] + [256 for _ in range(layer)], @@ -36,19 +43,39 @@ def __init__(self, action_dim, state_dim, layer=2): final_activation="tanh", ) - def forward(self, state): + def forward(self, state: torch.Tensor): return self.net(state) - def select_actions(self, state): + def select_actions(self, state: torch.Tensor): return self(state) def asymmetric_l2_loss(u, tau): + """ + Overview: + Calculate the asymmetric L2 loss, which is used in Implicit Q-Learning. + Arguments: + u (:obj:`torch.Tensor`): The input tensor. + tau (:obj:`float`): The threshold. + """ return torch.mean(torch.abs(tau - (u < 0).float()) * u**2) class ValueFunction(nn.Module): - def __init__(self, state_dim): + """ + Overview: + The value network used in SRPO algorithm. + Interfaces: + ``__init__``, ``forward`` + """ + + def __init__(self, state_dim: int): + """ + Overview: + Initialize the value network. + Arguments: + state_dim (:obj:`int`): The dimension of the state. + """ super().__init__() self.v = MultiLayerPerceptron( hidden_sizes=[state_dim, 256, 256], @@ -57,17 +84,43 @@ def __init__(self, state_dim): ) def forward(self, state): + """ + Overview: + Forward pass of the value network. + Arguments: + state (:obj:`torch.Tensor`): The input state. + """ return self.v(state) class SRPOCritic(nn.Module): + """ + Overview: + The critic network used in SRPO algorithm. + Interfaces: + ``__init__``, ``v_loss``, ``q_loss + """ + def __init__(self, config) -> None: + """ + Overview: + Initialize the critic network. + Arguments: + config (:obj:`EasyDict`): The configuration. + """ super().__init__() self.q0 = DoubleQNetwork(config.DoubleQNetwork) self.q0_target = copy.deepcopy(self.q0).requires_grad_(False) self.vf = ValueFunction(config.sdim) def v_loss(self, data, tau): + """ + Overview: + Calculate the value loss. + Arguments: + data (:obj:`Dict`): The input data. + tau (:obj:`float`): The threshold. + """ s = data["s"] a = data["a"] r = data["r"] @@ -83,6 +136,14 @@ def v_loss(self, data, tau): return v_loss, next_v def q_loss(self, data, next_v, discount): + """ + Overview: + Calculate the Q loss. + Arguments: + data (:obj:`Dict`): The input data. + next_v (:obj:`torch.Tensor`): The input next state value. + discount (:obj:`float`): The discount factor. + """ # Update Q function s = data["s"] a = data["a"] @@ -95,7 +156,20 @@ def q_loss(self, data, next_v, discount): class SRPOPolicy(nn.Module): + """ + Overview: + The SRPO policy network. + Interfaces: + ``__init__``, ``forward``, ``behaviour_policy_loss``, ``v_loss``, ``q_loss``, ``srpo_actor_loss`` + """ + def __init__(self, config: EasyDict): + """ + Overview: + Initialize the SRPO policy network. + Arguments: + config (:obj:`EasyDict`): The configuration. 
+ """ super().__init__() self.config = config self.device = config.device @@ -138,8 +212,8 @@ def behaviour_policy_loss( def v_loss( self, - data, - tau=0.9, + data: Dict[str, torch.Tensor], + tau: int = 0.9, ) -> torch.Tensor: """ Overview: @@ -157,8 +231,8 @@ def v_loss( def q_loss( self, - data, - next_v, + data: Dict[str, torch.Tensor], + next_v: torch.Tensor, discount_factor: float = 1.0, ) -> torch.Tensor: """ @@ -179,7 +253,7 @@ def q_loss( def srpo_actor_loss( self, - data, + data: Dict[str, torch.Tensor], ) -> torch.Tensor: """ Overview: @@ -568,6 +642,12 @@ def policy(obs: np.ndarray) -> np.ndarray: wandb.finish() def deploy(self, config: EasyDict = None) -> SRPOAgent: + """ + Overview: + Deploy the model using the given configuration. + Arguments: + config (:obj:`EasyDict`): The deployment configuration. + """ if config is not None: config = merge_two_dicts_into_newone(self.config.deploy, config) diff --git a/grl/datasets/minari_dataset.py b/grl/datasets/minari_dataset.py index c44e78f..c1aed5b 100644 --- a/grl/datasets/minari_dataset.py +++ b/grl/datasets/minari_dataset.py @@ -10,74 +10,7 @@ class MinariDataset(torch.utils.data.Dataset): """ Overview: - Dataset for QGPO && SRPOAlgorithm algorithm. The training of QGPO && SRPOAlgorithm algorithm is based on contrastive energy prediction, \ - which needs true action and fake action. The true action is sampled from the dataset, and the fake action \ - is sampled from the action support generated by the behaviour policy. - Interface: - ``__init__``, ``__getitem__``, ``__len__``. - """ - - def __init__(self): - """ - Overview: - Initialization method of MinariDataset class - """ - pass - - def __getitem__(self, index): - """ - Overview: - Get data by index - Arguments: - index (:obj:`int`): Index of data - Returns: - data (:obj:`dict`): Data dict - - .. note:: - The data dict contains the following keys: - - s (:obj:`torch.Tensor`): State - a (:obj:`torch.Tensor`): Action - r (:obj:`torch.Tensor`): Reward - s_ (:obj:`torch.Tensor`): Next state - d (:obj:`torch.Tensor`): Is finished - fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \ - (fake action is sampled from the action support generated by the behaviour policy) - fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \ - (fake action is sampled from the action support generated by the behaviour policy) - """ - - data = { - "s": self.states[index % self.len], - "a": self.actions[index % self.len], - "r": self.rewards[index % self.len], - "s_": self.next_states[index % self.len], - "d": self.is_finished[index % self.len], - "fake_a": ( - self.fake_actions[index % self.len] - if hasattr(self, "fake_actions") - else 0.0 - ), # self.fake_actions - "fake_a_": ( - self.fake_next_actions[index % self.len] - if hasattr(self, "fake_next_actions") - else 0.0 - ), # self.fake_next_actions - } - return data - - def __len__(self): - return self.len - - @abstractmethod - def return_range(self, dataset, max_episode_steps): - raise NotImplementedError - - -class MinariDataset(MinariDataset): - """ - Overview: - Dataset for QGPO && SRPOAlgorithm algorithm. The training of QGPO && SRPOAlgorithm algorithm is based on contrastive energy prediction, \ + Minari Dataset for QGPO && SRPOAlgorithm algorithm. The training of QGPO && SRPOAlgorithm algorithm is based on contrastive energy prediction, \ which needs true action and fake action. 
The true action is sampled from the dataset, and the fake action \ is sampled from the action support generated by the behaviour policy. Interface: @@ -155,6 +88,55 @@ def __init__( self.len = self.states.shape[0] log.info(f"{self.len} data loaded in MinariDataset") + def __getitem__(self, index): + """ + Overview: + Get data by index + Arguments: + index (:obj:`int`): Index of data + Returns: + data (:obj:`dict`): Data dict + + .. note:: + The data dict contains the following keys: + + s (:obj:`torch.Tensor`): State + a (:obj:`torch.Tensor`): Action + r (:obj:`torch.Tensor`): Reward + s_ (:obj:`torch.Tensor`): Next state + d (:obj:`torch.Tensor`): Is finished + fake_a (:obj:`torch.Tensor`): Fake action for contrastive energy prediction and qgpo training \ + (fake action is sampled from the action support generated by the behaviour policy) + fake_a_ (:obj:`torch.Tensor`): Fake next action for contrastive energy prediction and qgpo training \ + (fake action is sampled from the action support generated by the behaviour policy) + """ + + data = { + "s": self.states[index % self.len], + "a": self.actions[index % self.len], + "r": self.rewards[index % self.len], + "s_": self.next_states[index % self.len], + "d": self.is_finished[index % self.len], + "fake_a": ( + self.fake_actions[index % self.len] + if hasattr(self, "fake_actions") + else 0.0 + ), # self.fake_actions + "fake_a_": ( + self.fake_next_actions[index % self.len] + if hasattr(self, "fake_next_actions") + else 0.0 + ), # self.fake_next_actions + } + return data + + def __len__(self): + return self.len + + @abstractmethod + def return_range(self, dataset, max_episode_steps): + raise NotImplementedError + def return_range(dataset, max_episode_steps): returns, lengths = [], [] ep_ret, ep_len = 0.0, 0 diff --git a/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py b/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py index e06e750..3798985 100644 --- a/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py +++ b/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py @@ -8,13 +8,8 @@ from easydict import EasyDict from tensordict import TensorDict -from grl.generative_models.diffusion_process import DiffusionProcess from grl.generative_models.intrinsic_model import IntrinsicModel -from grl.generative_models.model_functions.data_prediction_function import ( - DataPredictionFunction, -) -from grl.generative_models.model_functions.noise_function import NoiseFunction -from grl.generative_models.model_functions.score_function import ScoreFunction + from grl.generative_models.model_functions.velocity_function import VelocityFunction from grl.generative_models.random_generator import gaussian_random_variable from grl.generative_models.stochastic_process import StochasticProcess @@ -27,7 +22,6 @@ from grl.numerical_methods.numerical_solvers.sde_solver import SDESolver from grl.numerical_methods.probability_path import ( ConditionalProbabilityPath, - GaussianConditionalProbabilityPath, ) from grl.utils import find_parameters diff --git a/grl/generative_models/diffusion_process.py b/grl/generative_models/diffusion_process.py index 22aaf6c..5275a2f 100644 --- a/grl/generative_models/diffusion_process.py +++ b/grl/generative_models/diffusion_process.py @@ -3,9 +3,7 @@ import torch import torch.nn as nn import treetensor -from easydict import EasyDict from tensordict import TensorDict -from torch.distributions 
import Distribution from grl.numerical_methods.ode import ODE from grl.numerical_methods.probability_path import GaussianConditionalProbabilityPath diff --git a/grl/generative_models/sro.py b/grl/generative_models/sro.py index f5bf733..e8de100 100644 --- a/grl/generative_models/sro.py +++ b/grl/generative_models/sro.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Union import torch import torch.nn as nn diff --git a/grl/generative_models/stochastic_process.py b/grl/generative_models/stochastic_process.py index e5fc3bb..efbef24 100644 --- a/grl/generative_models/stochastic_process.py +++ b/grl/generative_models/stochastic_process.py @@ -1,18 +1,12 @@ -from typing import Callable, Union +from typing import Union import torch -import torch.nn as nn import treetensor -from easydict import EasyDict from tensordict import TensorDict -from torch.distributions import Distribution -from grl.numerical_methods.ode import ODE from grl.numerical_methods.probability_path import ( ConditionalProbabilityPath, - GaussianConditionalProbabilityPath, ) -from grl.numerical_methods.sde import SDE class StochasticProcess: diff --git a/grl/numerical_methods/monte_carlo.py b/grl/numerical_methods/monte_carlo.py index 36b094f..3edd2a9 100644 --- a/grl/numerical_methods/monte_carlo.py +++ b/grl/numerical_methods/monte_carlo.py @@ -4,7 +4,6 @@ import matplotlib.pyplot as plt import torch import torch.distributions.uniform as uniform -import torch.nn.functional as F class MonteCarloSampler: diff --git a/grl/numerical_methods/ode.py b/grl/numerical_methods/ode.py index 82ec889..467d6b2 100644 --- a/grl/numerical_methods/ode.py +++ b/grl/numerical_methods/ode.py @@ -1,7 +1,5 @@ from typing import Callable, Union -import torch -from tensordict import TensorDict from torch import nn diff --git a/grl/numerical_methods/sde.py b/grl/numerical_methods/sde.py index cfe0da1..424d24f 100644 --- a/grl/numerical_methods/sde.py +++ b/grl/numerical_methods/sde.py @@ -1,8 +1,5 @@ from typing import Callable, Union -import torch -from easydict import EasyDict -from tensordict import TensorDict from torch import nn diff --git a/grl/rl_modules/simulators/gym_env_simulator.py b/grl/rl_modules/simulators/gym_env_simulator.py index b657895..b6054f7 100644 --- a/grl/rl_modules/simulators/gym_env_simulator.py +++ b/grl/rl_modules/simulators/gym_env_simulator.py @@ -1,8 +1,7 @@ -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Union import gym import torch -from easydict import EasyDict class GymEnvSimulator: diff --git a/grl/rl_modules/value_network/one_shot_value_function.py b/grl/rl_modules/value_network/one_shot_value_function.py index 4214544..6711b42 100644 --- a/grl/rl_modules/value_network/one_shot_value_function.py +++ b/grl/rl_modules/value_network/one_shot_value_function.py @@ -1,5 +1,5 @@ import copy -from typing import Any, Dict, List, Tuple, Union +from typing import Tuple, Union import torch import torch.nn as nn @@ -12,7 +12,7 @@ class OneShotValueFunction(nn.Module): """ Overview: - Value network for one-shot cases. + Value network for one-shot cases, which means that no Bellman backup is needed for training. 
Interfaces: ``__init__``, ``forward`` """ diff --git a/grl/rl_modules/value_network/q_network.py b/grl/rl_modules/value_network/q_network.py index da1a970..bbfa25b 100644 --- a/grl/rl_modules/value_network/q_network.py +++ b/grl/rl_modules/value_network/q_network.py @@ -10,8 +10,20 @@ class QNetwork(nn.Module): + """ + Overview: + Q network, which is used to approximate the Q value. + Interfaces: + ``__init__``, ``forward`` + """ def __init__(self, config: EasyDict): + """ + Overview: + Initialization of Q network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ super().__init__() self.config = config self.model = torch.nn.ModuleDict() diff --git a/grl/rl_modules/value_network/value_network.py b/grl/rl_modules/value_network/value_network.py index 7d1fbf7..a9888b7 100644 --- a/grl/rl_modules/value_network/value_network.py +++ b/grl/rl_modules/value_network/value_network.py @@ -10,8 +10,20 @@ class VNetwork(nn.Module): + """ + Overview: + Value network, which is used to approximate the value function. + Interfaces: + ``__init__``, ``forward`` + """ def __init__(self, config: EasyDict): + """ + Overview: + Initialization of value network. + Arguments: + config (:obj:`EasyDict`): The configuration dict. + """ super().__init__() self.config = config self.model = torch.nn.ModuleDict() diff --git a/grl/unittest/agents/functions.py b/grl/unittest/agents/functions.py new file mode 100644 index 0000000..c4a4d27 --- /dev/null +++ b/grl/unittest/agents/functions.py @@ -0,0 +1,90 @@ +import unittest +import numpy as np +import torch +from grl.agents import obs_transform, action_transform + +# Assume obs_transform and action_transform are defined in the same module or imported properly here. + + +class TestTransforms(unittest.TestCase): + + def setUp(self): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def test_obs_transform_numpy(self): + obs = np.array([1, 2, 3], dtype=np.float32) + transformed = obs_transform(obs, self.device) + self.assertIsInstance(transformed, torch.Tensor) + self.assertTrue(transformed.is_floating_point()) + self.assertEqual(transformed.device, self.device) + np.testing.assert_array_equal(transformed.cpu().numpy(), obs) + + def test_obs_transform_dict(self): + obs = { + "a": np.array([1, 2, 3], dtype=np.float32), + "b": np.array([4, 5, 6], dtype=np.float32), + } + transformed = obs_transform(obs, self.device) + self.assertIsInstance(transformed, dict) + for k, v in transformed.items(): + self.assertIsInstance(v, torch.Tensor) + self.assertTrue(v.is_floating_point()) + self.assertEqual(v.device, self.device) + np.testing.assert_array_equal(v.cpu().numpy(), obs[k]) + + def test_obs_transform_tensor(self): + obs = torch.tensor([1, 2, 3], dtype=torch.float32) + transformed = obs_transform(obs, self.device) + self.assertIsInstance(transformed, torch.Tensor) + self.assertTrue(transformed.is_floating_point()) + self.assertEqual(transformed.device, self.device) + self.assertTrue(torch.equal(transformed.cpu(), obs)) + + def test_obs_transform_invalid(self): + obs = [1, 2, 3] + with self.assertRaises(ValueError): + obs_transform(obs, self.device) + + def test_action_transform_dict(self): + action = { + "a": torch.tensor([1, 2, 3], dtype=torch.float32), + "b": torch.tensor([4, 5, 6], dtype=torch.float32), + } + transformed = action_transform(action, return_as_torch_tensor=True) + self.assertIsInstance(transformed, dict) + for k, v in transformed.items(): + self.assertIsInstance(v, torch.Tensor) + self.assertFalse(v.is_cuda) + 
self.assertTrue(torch.equal(v, action[k].cpu())) + + transformed = action_transform(action, return_as_torch_tensor=False) + self.assertIsInstance(transformed, dict) + for k, v in transformed.items(): + self.assertIsInstance(v, np.ndarray) + np.testing.assert_array_equal(v, action[k].cpu().numpy()) + + def test_action_transform_tensor(self): + action = torch.tensor([1, 2, 3], dtype=torch.float32).to(self.device) + transformed = action_transform(action, return_as_torch_tensor=True) + self.assertIsInstance(transformed, torch.Tensor) + self.assertFalse(transformed.is_cuda) + self.assertTrue(torch.equal(transformed, action.cpu())) + + transformed = action_transform(action, return_as_torch_tensor=False) + self.assertIsInstance(transformed, np.ndarray) + np.testing.assert_array_equal(transformed, action.cpu().numpy()) + + def test_action_transform_numpy(self): + action = np.array([1, 2, 3], dtype=np.float32) + transformed = action_transform(action) + self.assertIsInstance(transformed, np.ndarray) + np.testing.assert_array_equal(transformed, action) + + def test_action_transform_invalid(self): + action = [1, 2, 3] + with self.assertRaises(ValueError): + action_transform(action) + + +if __name__ == "__main__": + unittest.main() diff --git a/grl/test/test_customized_modules.py b/grl_pipelines/tutorials/customized_modules.py similarity index 86% rename from grl/test/test_customized_modules.py rename to grl_pipelines/tutorials/customized_modules.py index 33f10ff..3840d3f 100644 --- a/grl/test/test_customized_modules.py +++ b/grl_pipelines/tutorials/customized_modules.py @@ -1,3 +1,33 @@ +################################################################################################ +# This script demonstrates how to use customized neural network modules in GRL. +# +# In this example, we define a customized neural network module named `MyModule` +# and use it in the DiffusionModel. For convenience, we redefine `MyModule` by +# reusing `TemporalSpatialResidualNet` in this script. +# +# We can call `register_module` to register the customized module, such as: +# ``` +# register_module(TemporalSpatialResidualNet, "MyModule") +# ``` +# The `register_module` function is used to register the customized module to the +# module registry. The module registry is a global dictionary that stores the mapping +# from the module name to the module class. The module registry is used to create +# the module instance by the module name. +# +# The module name is used to specify the module type in the configuration file, such as: +# ``` +# backbone=dict( +# type="MyModule", +# args=dict( +# hidden_sizes=[512, 256, 128], +# output_dim=x_size, +# t_dim=t_embedding_dim, +# ), +# ), +# ``` +# The module type is used to create the module instance in the `DiffusionModel`. +################################################################################################ + import os import signal import sys diff --git a/grl/test/test_dict_tensor_ode.py b/grl_pipelines/tutorials/dict_tensor_ode.py similarity index 88% rename from grl/test/test_dict_tensor_ode.py rename to grl_pipelines/tutorials/dict_tensor_ode.py index c642d4a..1d0eae4 100644 --- a/grl/test/test_dict_tensor_ode.py +++ b/grl_pipelines/tutorials/dict_tensor_ode.py @@ -1,3 +1,27 @@ +################################################################################################ +# This script demonstrates how to use a dictionary tensor in the ODE solver for the diffusion model. 
+# +# We create a customized neural network module named `MyModule` and use it in the DiffusionModel, +# which gets a dictionary tensor as input. This module is registered in the module registry by +# calling `register_module`. +# +# We also use the `DictTensorODESolver` in the diffusion model to solve the ODE with a dictionary tensor, +# which requires "torchdyn" as the library. +# +# The data is generated by wrapping the original data with a dictionary tensor, such as: +# ``` +# batch_data = treetensor.torch.tensor(dict(x=batch_data)) +# batch_data = batch_data.to(config.device) +# ``` +# +# The training process is similar to the original diffusion model, but the input data is a dictionary tensor. +# For example, using score matching loss or flow matching loss: +# ``` +# loss = diffusion_model.score_matching_loss(batch_data) +# ``` +# It is worth noting that this dictionary tensor sampled from the diffusion model does not support sampling with automatic differentiation. +################################################################################################ + import os import signal import sys diff --git a/grl/test/test_swiss_roll.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py similarity index 96% rename from grl/test/test_swiss_roll.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py index 51be298..80c54db 100644 --- a/grl/test/test_swiss_roll.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_diffusion.py @@ -1,3 +1,7 @@ +################################################################################################ +# This script demonstrates how to use a diffusion model to train Swiss Roll dataset. +################################################################################################ + import os import signal import sys @@ -11,7 +15,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_dpmsolver.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py similarity index 92% rename from grl/test/test_swiss_roll_dpmsolver.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py index 48af997..3ac363b 100644 --- a/grl/test/test_swiss_roll_dpmsolver.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_dpmsolver.py @@ -1,3 +1,19 @@ +################################################################################################ +# This script demonstrates how to use DPM solver in the DiffusionModel for training Swiss Roll dataset. +# We can change the solver type to DPM solver in the configuration file, such as: +# ``` +# solver=dict( +# type="DPMSolver", +# args=dict( +# order=2, +# device=device, +# steps=17, +# ), +# ), +# ``` +# The DPM solver is a high-order solver that can solve the diffusion process with high accuracy and fast speed theorectically. 
+################################################################################################ + import os import signal import sys @@ -11,7 +27,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_energy_condition.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py similarity index 90% rename from grl/test/test_swiss_roll_energy_condition.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py index 1e89713..43bcf17 100644 --- a/grl/test/test_swiss_roll_energy_condition.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_energy_condition.py @@ -1,7 +1,68 @@ +################################################################################################ +# This script demonstrates how to use an energy-conditioned diffusion model to train the Swiss Roll +# dataset with artificial values. We can model the energy function with a value function model and +# use the energy function to guide the diffusion process. We use a OneShotValueFunction as the value +# function model in this script. +# +# Configuration for OneShotValueFunction: +# +# value_function_model=dict( +# device=device, +# v_alpha=1.0, +# DoubleVNetwork=dict( +# state_encoder=dict( +# type="GaussianFourierProjectionEncoder", +# args=dict( +# embed_dim=128, +# x_shape=[x_size], +# scale=0.5, +# ), +# ), +# backbone=dict( +# type="ConcatenateMLP", +# args=dict( +# hidden_sizes=[128 * x_size, 256, 256], +# output_size=1, +# activation="silu", +# ), +# ), +# ), +# ), +# +# Then we can use the value function model to guide the diffusion process in the energy-conditioned +# diffusion model. An energy-conditioned diffusion model is a diffusion model that is conditioned on +# the energy function, which has an extra intermediate energy guidance module. +# +# Configuration for energy guidance: +# +# energy_guidance=dict( +# t_encoder=t_encoder, +# backbone=dict( +# type="ConcatenateMLP", +# args=dict( +# hidden_sizes=[x_size + t_embedding_dim, 256, 256], +# output_size=1, +# activation="silu", +# ), +# ), +# ), +# +# We can train the energy-conditioned diffusion model with the energy guidance loss and the score +# matching loss, such as: +# +# energy_guidance_loss = energy_conditioned_diffusion_model.energy_guidance_loss( +# x=train_fake_x, +# ) +# energy_guidance_optimizer.zero_grad() +# energy_guidance_loss.backward() +# energy_guidance_optimizer.step() +# +# The fake_x is sampled from the energy-conditioned diffusion model in a way of data augmentation. +################################################################################################ + + import multiprocessing as mp import os -import signal -import sys import matplotlib import numpy as np diff --git a/grl/test/test_swiss_roll_icfm.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py similarity index 96% rename from grl/test/test_swiss_roll_icfm.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py index 7375081..1cc6f3f 100644 --- a/grl/test/test_swiss_roll_icfm.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_icfm.py @@ -1,3 +1,7 @@ +################################################################################################ +# This script demonstrates how to use an Independent Conditional Flow Matching (ICFM), which is a flow model, to train Swiss Roll dataset. 
+################################################################################################ + import os import signal import sys @@ -11,7 +15,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_likelihood.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py similarity index 86% rename from grl/test/test_swiss_roll_likelihood.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py index 3c94afa..2f14e81 100644 --- a/grl/test/test_swiss_roll_likelihood.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_likelihood.py @@ -1,3 +1,41 @@ +################################################################################################ +# This script demonstrates how to train a diffusion model on the Swiss Roll dataset and evaluate +# the log-likelihood. We can use the API `compute_likelihood` to compute the log-likelihood of the +# diffusion model on the Swiss Roll dataset. +# +# To compute the log-likelihood using the Hutchinson trace estimator, set the argument +# `using_Hutchinson_trace_estimator` to True: +# +# logp = compute_likelihood( +# model=diffusion_model, +# x=torch.tensor(data).to(config.device), +# using_Hutchinson_trace_estimator=True, +# ) +# +# To compute the log-likelihood using the exact trace estimator, set the argument +# `using_Hutchinson_trace_estimator` to False: +# +# logp = compute_likelihood( +# model=diffusion_model, +# x=torch.tensor(data).to(config.device), +# using_Hutchinson_trace_estimator=False, +# ) +# +# We support computing the log-likelihood with automatic differentiation by setting the argument +# `with_grad` to True: +# +# logp = compute_likelihood( +# model=diffusion_model, +# x=torch.tensor(data).to(config.device), +# using_Hutchinson_trace_estimator=True, +# with_grad=True, +# ) +# +# However, for numerical stability, only the model type `velocity_function` is supported for +# computing the log-likelihood with automatic differentiation. Otherwise, the parameters of the +# model will soon become NaN. +################################################################################################ + import os import signal import sys @@ -11,7 +49,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_otcfm.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py similarity index 96% rename from grl/test/test_swiss_roll_otcfm.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py index 0ffc24d..a6d903e 100644 --- a/grl/test/test_swiss_roll_otcfm.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_otcfm.py @@ -1,3 +1,7 @@ +################################################################################################ +# This script demonstrates how to use an Optimal Transport Conditional Flow Matching (OT-CFM), which is a flow model, to train Swiss Roll dataset. 
+################################################################################################ + import os import signal import sys @@ -11,7 +15,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_sdesolver.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py similarity index 99% rename from grl/test/test_swiss_roll_sdesolver.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py index 4e67103..5c7a5f3 100644 --- a/grl/test/test_swiss_roll_sdesolver.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sdesolver.py @@ -11,7 +11,6 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation diff --git a/grl/test/test_swiss_roll_SchrodingerBridge.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py similarity index 97% rename from grl/test/test_swiss_roll_SchrodingerBridge.py rename to grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py index 204280f..5f4b0c1 100644 --- a/grl/test/test_swiss_roll_SchrodingerBridge.py +++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_sf2m.py @@ -1,3 +1,7 @@ +################################################################################################ +# This script demonstrates how to use SF2M, which is a bridge model, to train Swiss Roll dataset. +################################################################################################ + import os import signal import sys @@ -11,14 +15,12 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import torch -import torch.nn as nn from easydict import EasyDict from matplotlib import animation from grl.generative_models.bridge_flow_model.schrodinger_bridge_conditional_flow_model import ( SchrodingerBridgeConditionalFlowModel, ) -from grl.generative_models.metric import compute_likelihood from grl.utils import set_seed from grl.utils.log import log
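All agents in this patch now route observation and action conversion through the shared `obs_transform` and `action_transform` helpers defined in `grl/agents/__init__.py`, instead of duplicating that logic in every `act` method. Below is a minimal sketch of how a custom agent can reuse the same pattern; the `ToyAgent` class, its linear model, and the placeholder inference step are hypothetical illustrations, not part of the GenerativeRL API.

```python
from typing import Dict, Union

import numpy as np
import torch
from easydict import EasyDict

from grl.agents import obs_transform, action_transform


class ToyAgent:
    """Hypothetical agent that mirrors the refactored ``act`` flow of BaseAgent/QGPOAgent/GPAgent."""

    def __init__(self, config: EasyDict, model: torch.nn.Module):
        self.config = config
        self.device = config.device
        self.model = model.to(self.device)

    def act(
        self,
        obs: Union[np.ndarray, torch.Tensor, Dict],
        return_as_torch_tensor: bool = False,
    ) -> Union[np.ndarray, torch.Tensor, Dict]:
        # Normalize the observation (np.ndarray, dict, or torch.Tensor) to float tensors on self.device.
        obs = obs_transform(obs, self.device)

        with torch.no_grad():
            # ---------------------------------------
            # Customized inference code ↓ (placeholder: pass the observation through a linear model)
            # ---------------------------------------
            action = self.model(obs.unsqueeze(0)).squeeze(0).cpu()
            # ---------------------------------------
            # Customized inference code ↑
            # ---------------------------------------

        # Convert the action back to np.ndarray (default) or keep it as a CPU torch.Tensor.
        return action_transform(action, return_as_torch_tensor)


if __name__ == "__main__":
    agent = ToyAgent(
        config=EasyDict(device="cuda" if torch.cuda.is_available() else "cpu"),
        model=torch.nn.Linear(3, 2),
    )
    action = agent.act(np.zeros(3, dtype=np.float32))
    assert isinstance(action, np.ndarray)
```

Calling `agent.act(obs, return_as_torch_tensor=True)` returns a CPU `torch.Tensor` (or a dict of tensors) instead of an `np.ndarray`, which is the behaviour exercised by the new tests in `grl/unittest/agents/functions.py`.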