Polish sampling method.
zjowowen committed Oct 31, 2024
1 parent 4d9276e commit 1169fc5
Showing 10 changed files with 170 additions and 53 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG
@@ -0,0 +1,14 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]
+### Added
+- New feature for better integration with generative models.
+
+## [0.0.1] - 2024-10-31
+### Added
+- Initial release of `GenerativeRL` supporting Generative Models for RL.
(diff for another changed file; filename not shown)
@@ -218,10 +218,24 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
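Several hunks in this commit make the same change: tile the condition along the batch dimension so that each of the B conditions is repeated N times, handling both plain tensors and TensorDict containers. Below is a minimal standalone sketch of that pattern; it is not code from the repository and assumes only the torch and tensordict packages:

import torch
from tensordict import TensorDict

B, N, D = 4, 3, 8
extra_batch_size = torch.tensor(N)

# Plain tensor condition: repeat each of the B rows N times -> (B*N, D).
condition = torch.randn(B, D)
tiled = torch.repeat_interleave(condition, torch.prod(extra_batch_size), dim=0)
assert tiled.shape == (B * N, D)

# TensorDict condition: tile every leaf and enlarge the batch size to match.
td = TensorDict({"obs": torch.randn(B, D)}, batch_size=[B])
tiled_td = TensorDict(
    {
        key: torch.repeat_interleave(td[key], torch.prod(extra_batch_size), dim=0)
        for key in td.keys()
    },
    batch_size=[B * N],
    device=td.device,
)
assert tiled_td["obs"].shape == (B * N, D)

Note that repeat_interleave produces adjacent copies (a, a, a, b, b, b, ...), unlike torch.Tensor.repeat, which cycles through the batch (a, b, a, b, ...).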
(diff for another changed file; filename not shown)
@@ -206,10 +206,24 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
(diff for another changed file; filename not shown)
@@ -220,10 +220,24 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
(diff for another changed file; filename not shown)
@@ -261,10 +261,24 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
(diff for another changed file; filename not shown)
@@ -212,10 +212,24 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            condition = torch.repeat_interleave(
-                condition, torch.prod(extra_batch_size), dim=0
-            )
-            # condition.shape = (B*N, D)
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
+            else:
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             raise NotImplementedError("Not implemented")
14 changes: 10 additions & 4 deletions grl/generative_models/diffusion_model/diffusion_model.py
@@ -239,10 +239,16 @@ def sample_forward_process(
                 )
                 # condition.shape = (B*N, D)
             elif isinstance(condition, TensorDict):
-                for key in condition.keys():
-                    condition[key] = torch.repeat_interleave(
-                        condition[key], torch.prod(extra_batch_size), dim=0
-                    )
+                condition = TensorDict(
+                    {
+                        key: torch.repeat_interleave(
+                            condition[key], torch.prod(extra_batch_size), dim=0
+                        )
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
+                )
             else:
                 raise NotImplementedError("Not implemented")
 
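A plausible reason this hunk swaps the earlier in-place loop for a freshly constructed TensorDict: tensordict validates assigned leaves against the container's batch_size, so writing (B*N)-row leaves into a container whose batch size is still B is rejected. A small sketch of that behavior, under the stated assumption and not taken from the repository:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])

try:
    # Writing a 6-row leaf into a batch_size=[2] container is rejected.
    td["obs"] = torch.zeros(6, 3)
except RuntimeError as err:
    print("rejected:", err)

# Rebuilding with the enlarged batch size succeeds.
td = TensorDict({"obs": torch.zeros(6, 3)}, batch_size=[6])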
(diff for another changed file; filename not shown)
@@ -342,26 +342,31 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            if isinstance(condition, TensorDict):
-                repeated_condition = TensorDict(
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, treetensor.torch.Tensor):
+                for key in condition.keys():
+                    condition[key] = torch.repeat_interleave(
+                        condition[key], torch.prod(extra_batch_size), dim=0
+                    )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
                     {
                         key: torch.repeat_interleave(
-                            value, torch.prod(extra_batch_size), dim=0
+                            condition[key], torch.prod(extra_batch_size), dim=0
                         )
-                        for key, value in condition.items()
+                        for key in condition.keys()
                     },
-                    batch_size=int(
-                        torch.prod(
-                            torch.tensor([*condition.batch_size, extra_batch_size])
-                        )
-                    ),
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
                 )
-                repeated_condition.to(condition.device)
-                condition = repeated_condition
             else:
-                condition = torch.repeat_interleave(
-                    condition, torch.prod(extra_batch_size), dim=0
-                )
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
+
             def noise_function_with_energy_guidance(t, x, condition):
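The new treetensor branch above can safely repeat leaves in place, since treetensor containers do not carry a separate batch-size record that could go stale. A short sketch assuming the DI-treetensor package (ttorch.randn taking a dict of shapes follows its documented usage; the loop mirrors the diff):

import torch
import treetensor.torch as ttorch

# Hypothetical tree-structured condition with two leaves of batch size 4.
condition = ttorch.randn({"obs": (4, 8), "goal": (4, 2)})
repeats = torch.tensor(3)

# Repeat every leaf in place, as the treetensor branch of the diff does.
for key in condition.keys():
    condition[key] = torch.repeat_interleave(condition[key], repeats, dim=0)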
33 changes: 19 additions & 14 deletions grl/generative_models/diffusion_model/guided_diffusion_model.py
@@ -233,25 +233,30 @@ def sample_forward_process(
         # x.shape = (B*N, D)
 
         if condition is not None:
-            if isinstance(condition, TensorDict):
-                repeated_condition = TensorDict(
+            if isinstance(condition, torch.Tensor):
+                condition = torch.repeat_interleave(
+                    condition, torch.prod(extra_batch_size), dim=0
+                )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, treetensor.torch.Tensor):
+                for key in condition.keys():
+                    condition[key] = torch.repeat_interleave(
+                        condition[key], torch.prod(extra_batch_size), dim=0
+                    )
+                # condition.shape = (B*N, D)
+            elif isinstance(condition, TensorDict):
+                condition = TensorDict(
                     {
                         key: torch.repeat_interleave(
-                            value, torch.prod(extra_batch_size), dim=0
+                            condition[key], torch.prod(extra_batch_size), dim=0
                         )
-                        for key, value in condition.items()
-                    }
+                        for key in condition.keys()
+                    },
+                    batch_size=torch.prod(extra_batch_size) * condition.shape,
+                    device=condition.device,
                 )
-                repeated_condition.batch_size = torch.Size(
-                    [torch.prod(extra_batch_size).item()]
-                )
-                repeated_condition.to(condition.device)
-                condition = repeated_condition
             else:
-                condition = torch.repeat_interleave(
-                    condition, torch.prod(extra_batch_size), dim=0
-                )
-                # condition.shape = (B*N, D)
+                raise NotImplementedError("Not implemented")
 
         if isinstance(solver, DPMSolver):
             # Note: DPMSolver does not support t_span argument assignment
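A likely bug in the removed lines above: TensorDict.to(...) returns the converted tensordict rather than guaranteeing an in-place move, so repeated_condition.to(condition.device) discarded its result. Passing device= at construction, as the replacement code does, sidesteps this. A minimal sketch, assuming the tensordict package:

import torch
from tensordict import TensorDict

device = "cuda" if torch.cuda.is_available() else "cpu"
td = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2])

td.to(device)        # returns the moved copy; without rebinding, it is lost
td = td.to(device)   # correct: keep the returned tensordict

# Or fix the device at construction time, as the new code does:
td2 = TensorDict({"obs": torch.zeros(2, 3)}, batch_size=[2], device=device)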
19 changes: 18 additions & 1 deletion setup.py
@@ -1,10 +1,17 @@
 from setuptools import setup, find_packages
 
+with open('README.md', 'r', encoding="utf-8") as f:
+    readme = f.read()
+
 setup(
     name='GenerativeRL',
     version='0.0.1',
     description='PyTorch implementations of generative reinforcement learning algorithms',
+    long_description=readme,
+    long_description_content_type='text/markdown',
     author='OpenDILab',
+    author_email="[email protected]",
+    url="https://github.com/opendilab/GenerativeRL",
 
     packages=find_packages(
         exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
@@ -47,5 +54,15 @@
            'black',
            'isort',
        ],
-    }
+    },
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+        "License :: OSI Approved :: Apache Software License",
+        "Operating System :: OS Independent",
+    ],
+    license="Apache-2.0",
 )
