diff --git a/CHANGELOG b/CHANGELOG new file mode 100644 index 0000000..a6f1e32 --- /dev/null +++ b/CHANGELOG @@ -0,0 +1,14 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] +### Added +- New feature for better integration with generative models. + +## [0.0.1] - 2024-10-31 +### Added +- Initial release of `GenerativeRL` supporting Generative Models for RL. \ No newline at end of file diff --git a/grl/generative_models/bridge_flow_model/guided_bridge_conditional_flow_model.py b/grl/generative_models/bridge_flow_model/guided_bridge_conditional_flow_model.py index 26d6522..8199d92 100644 --- a/grl/generative_models/bridge_flow_model/guided_bridge_conditional_flow_model.py +++ b/grl/generative_models/bridge_flow_model/guided_bridge_conditional_flow_model.py @@ -218,10 +218,24 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) + else: + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/bridge_flow_model/schrodinger_bridge_conditional_flow_model.py b/grl/generative_models/bridge_flow_model/schrodinger_bridge_conditional_flow_model.py index 50abdec..cd59f5d 100644 --- a/grl/generative_models/bridge_flow_model/schrodinger_bridge_conditional_flow_model.py +++ b/grl/generative_models/bridge_flow_model/schrodinger_bridge_conditional_flow_model.py @@ -206,10 +206,24 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) + else: + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/conditional_flow_model/guided_conditional_flow_model.py b/grl/generative_models/conditional_flow_model/guided_conditional_flow_model.py index cbd9e78..8b76da9 100644 --- a/grl/generative_models/conditional_flow_model/guided_conditional_flow_model.py +++ b/grl/generative_models/conditional_flow_model/guided_conditional_flow_model.py @@ -220,10 +220,24 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) + else: + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py b/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py index f9e908c..b37f228 100644 --- a/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py +++ b/grl/generative_models/conditional_flow_model/independent_conditional_flow_model.py @@ -261,10 +261,24 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) + else: + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py b/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py index 3798985..5ccc70f 100644 --- a/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py +++ b/grl/generative_models/conditional_flow_model/optimal_transport_conditional_flow_model.py @@ -212,10 +212,24 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) + else: + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/diffusion_model/diffusion_model.py b/grl/generative_models/diffusion_model/diffusion_model.py index dd0dcf3..7093379 100644 --- a/grl/generative_models/diffusion_model/diffusion_model.py +++ b/grl/generative_models/diffusion_model/diffusion_model.py @@ -239,10 +239,16 @@ def sample_forward_process( ) # condition.shape = (B*N, D) elif isinstance(condition, TensorDict): - for key in condition.keys(): - condition[key] = torch.repeat_interleave( - condition[key], torch.prod(extra_batch_size), dim=0 - ) + condition = TensorDict( + { + key: torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, + ) else: raise NotImplementedError("Not implemented") diff --git a/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py b/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py index b65f843..22c4ffc 100644 --- a/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py +++ b/grl/generative_models/diffusion_model/energy_conditional_diffusion_model.py @@ -342,26 +342,31 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - if isinstance(condition, TensorDict): - repeated_condition = TensorDict( + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, treetensor.torch.Tensor): + for key in condition.keys(): + condition[key] = torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( { key: torch.repeat_interleave( - value, torch.prod(extra_batch_size), dim=0 + condition[key], torch.prod(extra_batch_size), dim=0 ) - for key, value in condition.items() + for key in condition.keys() }, - batch_size=int( - torch.prod( - torch.tensor([*condition.batch_size, extra_batch_size]) - ) - ), + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, ) - repeated_condition.to(condition.device) - condition = repeated_condition else: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) + raise NotImplementedError("Not implemented") + if isinstance(solver, DPMSolver): def noise_function_with_energy_guidance(t, x, condition): diff --git a/grl/generative_models/diffusion_model/guided_diffusion_model.py b/grl/generative_models/diffusion_model/guided_diffusion_model.py index 487d65c..5b9c62f 100644 --- a/grl/generative_models/diffusion_model/guided_diffusion_model.py +++ b/grl/generative_models/diffusion_model/guided_diffusion_model.py @@ -233,25 +233,30 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - if isinstance(condition, TensorDict): - repeated_condition = TensorDict( + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, treetensor.torch.Tensor): + for key in condition.keys(): + condition[key] = torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, TensorDict): + condition = TensorDict( { key: torch.repeat_interleave( - value, torch.prod(extra_batch_size), dim=0 + condition[key], torch.prod(extra_batch_size), dim=0 ) - for key, value in condition.items() - } + for key in condition.keys() + }, + batch_size=torch.prod(extra_batch_size) * condition.shape, + device=condition.device, ) - repeated_condition.batch_size = torch.Size( - [torch.prod(extra_batch_size).item()] - ) - repeated_condition.to(condition.device) - condition = repeated_condition else: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + raise NotImplementedError("Not implemented") if isinstance(solver, DPMSolver): # Note: DPMSolver does not support t_span argument assignment diff --git a/setup.py b/setup.py index 9d8d78f..f1f22be 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,17 @@ from setuptools import setup, find_packages +with open('README.md', 'r', encoding="utf-8") as f: + readme = f.read() + setup( name='GenerativeRL', version='0.0.1', description='PyTorch implementations of generative reinforcement learning algorithms', + long_description=readme, + long_description_content_type='text/markdown', author='OpenDILab', + author_email="zjowowen@gmail.com", + url="https://github.com/opendilab/GenerativeRL", packages=find_packages( exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), @@ -47,5 +54,15 @@ 'black', 'isort', ], - } + }, + classifiers=[ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + license="Apache-2.0", )