Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] TensorDictMap #2306

Merged
merged 6 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,26 @@ The following classes are deprecated and just point to the classes above:
UnboundedContinuousTensorSpec
UnboundedDiscreteTensorSpec

Trees and Forests
-----------------

TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently.

.. currentmodule:: torchrl.data

.. autosummary::
:toctree: generated/
:template: rl_template.rst

BinaryToDecimal
HashToInt
QueryModule
RandomProjectionHash
SipHash
TensorDictMap
TensorMap


Reinforcement Learning From Human Feedback (RLHF)
-------------------------------------------------

Expand Down
125 changes: 124 additions & 1 deletion test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import functools
import importlib.util

import pytest

import torch

from tensordict import TensorDict
from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash
from torchrl.data import LazyTensorStorage, ListStorage
from torchrl.data.map import (
BinaryToDecimal,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
)
from torchrl.envs import GymEnv

_has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
"gym", None
Expand Down Expand Up @@ -114,6 +123,120 @@ def test_query(self, clone, index_key):
for i in range(1, 3):
assert res[index_key][i].item() != res[index_key][i + 1].item()

def test_query_module(self):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = LazyTensorStorage(23)

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)

tensor_dict_storage[index] = value
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()


class TesttTensorDictMap:
@pytest.mark.parametrize(
"storage_type",
[
functools.partial(ListStorage, 1000),
functools.partial(LazyTensorStorage, 1000),
],
)
def test_map(self, storage_type):
query_module = QueryModule(
in_keys=["key1", "key2"],
index_key="index",
hash_module=SipHash(),
)

embedding_storage = storage_type()

tensor_dict_storage = TensorDictMap(
query_module=query_module,
storage=embedding_storage,
)

index = TensorDict(
{
"key1": torch.Tensor([[-1], [1], [3], [-3]]),
"key2": torch.Tensor([[0], [2], [4], [-4]]),
},
batch_size=(4,),
)

value = TensorDict(
{"index": torch.Tensor([[10], [20], [30], [40]])}, batch_size=(4,)
)
assert not hasattr(tensor_dict_storage, "out_keys")

tensor_dict_storage[index] = value
if isinstance(embedding_storage, LazyTensorStorage):
assert hasattr(tensor_dict_storage, "out_keys")
else:
assert not hasattr(tensor_dict_storage, "out_keys")
assert tensor_dict_storage._has_lazy_out_keys()
assert torch.sum(tensor_dict_storage.contains(index)).item() == 4

new_index = index.clone(True)
new_index["key3"] = torch.Tensor([[4], [5], [6], [7]])
retrieve_value = tensor_dict_storage[new_index]

assert (retrieve_value["index"] == value["index"]).all()

@pytest.mark.skipif(not _has_gym, reason="gym not installed")
def test_map_rollout(self):
torch.manual_seed(0)
env = GymEnv("CartPole-v1")
env.set_seed(0)
rollout = env.rollout(100)
source, dest = rollout.exclude("next"), rollout.get("next")
storage = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation", "action"],
)
storage_indices = TensorDictMap.from_tensordict_pair(
source,
dest,
in_keys=["observation"],
out_keys=["_index"],
)
# maps the (obs, action) tuple to a corresponding next state
storage[source] = dest
storage_indices[source] = source
contains = storage.contains(source)
assert len(contains) == rollout.shape[-1]
assert contains.all()
contains = storage.contains(torch.cat([source, source + 1]))
assert len(contains) == rollout.shape[-1] * 2
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
10 changes: 9 additions & 1 deletion torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
from .map import (
BinaryToDecimal,
HashToInt,
QueryModule,
RandomProjectionHash,
SipHash,
TensorDictMap,
TensorMap,
)
from .postprocs import MultiStep
from .replay_buffers import (
Flat2TED,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/map/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@

from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
from .query import HashToInt, QueryModule
from .tdstorage import TensorDictMap, TensorMap
Loading
Loading