
[Performance] Prioritised TensorDict replay buffers use for loops over the batch dimension #1574

Closed
matteobettini opened this issue Sep 25, 2023 · 6 comments
Labels: bug (Something isn't working)

@matteobettini (Contributor)

In prioritised TensorDict replay buffers, the `update_tensordict_priority` method performs a for loop over the batch dimension:

```python
priority = torch.tensor(
    [self._get_priority(td) for td in data],
    dtype=torch.float,
    device=data.device,
)
```
This causes significant slowdowns, since this is the vectorised dimension used in the training pipelines and it can grow very large.

This method is called every time the buffer is extended or the priorities are updated.
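To illustrate the cost, here is a minimal, self-contained sketch (not TorchRL's actual code) contrasting the per-element Python loop described above with a single batched reduction. The `td_error` tensor and the max-reduction are illustrative assumptions standing in for whatever `_get_priority` computes per sample.

```python
import torch

def priorities_looped(td_error: torch.Tensor) -> torch.Tensor:
    # One Python iteration per batch element, mirroring the reported code path.
    return torch.tensor([float(e.abs().max()) for e in td_error], dtype=torch.float)

def priorities_vectorized(td_error: torch.Tensor) -> torch.Tensor:
    # Same result computed in a single batched op over the leading (batch) dim.
    return td_error.abs().flatten(1).max(dim=1).values.to(torch.float)

errs = torch.randn(1024, 8)
assert torch.allclose(priorities_looped(errs), priorities_vectorized(errs))
```

The batched version does constant Python work regardless of batch size, which is the gap this issue is about.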

@matteobettini added the bug (Something isn't working) label on Sep 25, 2023
@vmoens (Contributor) commented Oct 1, 2023

Any update on this? Can I help?

@matteobettini (Contributor, Author)

I am not sure about all the cases that were envisioned when the component was designed this way, or why it was done. If you can give me some context, I can try to fix it; or if you already know an easy fix, feel free to do it.

@vmoens (Contributor) commented Oct 1, 2023

We just wanted to be able to handle lists of tensordicts, I guess; it should work with a list storage and data that you can't necessarily stack well. No more insight than that, I'm afraid.
But in general I think that if all tests pass we should be covered.

@matteobettini (Contributor, Author)

I don't think this is currently the case: `update_tensordict_priority` expects a `TensorDictBase`, since its code calls methods such as `ndim`, `get`, and so on.

While `extend` is compatible with lists, if `extend` is called with a list that is not stackable as a lazy stack, it will fail when the data is passed to `update_tensordict_priority`.

Furthermore, if `extend` is called with a list, the `_data` and `index` fields are not set, so it is not clear where `update_tensordict_priority` is supposed to find them.

I can try to do some patch work to make `update_tensordict_priority` and `_get_priority` work vectorized, but I am quite confused about how this class is supposed to work and what its contracts are.
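For reference, a vectorized `_get_priority` replacement could look like the hedged sketch below. It assumes the input behaves like a tensordict holding a per-sample `"td_error"` entry (modelled here as a plain dict of tensors), and the `alpha`/`eps` parameters and mean-reduction are hypothetical stand-ins for the actual priority formula.

```python
import torch

def get_priority_vectorized(
    td: dict, key: str = "td_error", alpha: float = 0.6, eps: float = 1e-8
) -> torch.Tensor:
    # Hypothetical batched replacement for a per-element _get_priority loop:
    # reduce every non-batch dimension in one shot.
    err = td[key]
    # Collapse all dims except the leading batch dim, then reduce per sample.
    per_sample = err.abs().reshape(err.shape[0], -1).mean(dim=1)
    return (per_sample + eps) ** alpha

td = {"td_error": torch.randn(512, 4, 3)}
priorities = get_priority_vectorized(td)
assert priorities.shape == (512,)
```

The point is only that the whole batch can be handled with tensor ops once the input is guaranteed to be a (stacked) tensordict rather than an arbitrary list.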

@vmoens (Contributor) commented Oct 3, 2023

If the method is broken with lists, it won't be BC-breaking if we stop supporting lists, and any fix that enables passing lists to `extend` will be a "new feature" (since right now it isn't a feature).
In other words: I'm OK with assuming that everything passed to that method is a tensordict of some sort.

@matteobettini (Contributor, Author)

OK, I'll do this in #1598.

@vmoens vmoens closed this as completed Oct 4, 2023