
[BugFix] Vectorized priority update in replay buffers #1598

Merged: 6 commits merged into pytorch:main from fix_prioritised_buffer on Oct 4, 2023

Conversation

@matteobettini (Contributor):

This is a patch for #1574.

It fixes the core problem highlighted in that issue, but some points will still need attention in a future refactoring of this class, as I am not sure it is compatible with all the cases it aims to support. I am happy to elaborate on this if needed.

Signed-off-by: Matteo Bettini <[email protected]>
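
A minimal sketch of the idea behind the fix, to make the "vectorized priority update" concrete. The priority key "td_error" and both helper names are assumptions for illustration, not the actual torchrl internals; this only shows the loop-vs-tensor-op contrast.

import torch
from tensordict import TensorDict

# Hypothetical priority key, chosen for the example only.
PRIORITY_KEY = "td_error"

def priorities_per_item(data: TensorDict) -> torch.Tensor:
    # Loop-based path: iterate over the stacked data and extract one
    # scalar priority per element with a Python list comprehension.
    return torch.tensor(
        [td.get(PRIORITY_KEY).mean().item() for td in data],
        dtype=torch.float,
        device=data.device,
    )

def priorities_vectorized(data: TensorDict) -> torch.Tensor:
    # Vectorized path: read the priority entry once, then reduce over
    # all non-batch dims, handling the whole batch in one tensor op.
    p = data.get(PRIORITY_KEY)
    return p.reshape(data.shape[0], -1).mean(-1).to(torch.float)

data = TensorDict({PRIORITY_KEY: torch.rand(8, 1)}, batch_size=[8])
assert torch.allclose(priorities_per_item(data), priorities_vectorized(data))

Both paths produce one priority per sampled item; the vectorized one avoids the per-element Python loop that #1574 flagged.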
@facebook-github-bot added the CLA Signed label on Oct 3, 2023
@vmoens added the bug and performance labels on Oct 3, 2023
@vmoens (Contributor) left a comment:


LGTM, see my few minor comments

Two resolved review comments on torchrl/data/replay_buffers/replay_buffers.py (outdated).
Comment on lines 778 to 782
# Builds the priority tensor with a Python loop, one priority per item:
priority = torch.tensor(
    [self._get_priority_item(td) for td in data],
    dtype=torch.float,
    device=data.device,
)
Contributor:

I guess we assume that the stack dim is 0, but it might not be (?)
I think we can assume the priority can be stacked, no? At the end of the day it's supposed to be one priority per item. Maybe I'm missing something.

@matteobettini (Contributor, Author), Oct 3, 2023:

This was the previous treatment, so I am not entirely sure what was going on or why things were this way.

My guess is that it assumes 0 as the stack dim because expand stacks on 0, and also because that is the dim of the indices and the priority.

Are you suggesting we completely remove this and always go vectorized? I am happy to try.

@matteobettini (Contributor, Author):

Maybe it was this way because the priority can have different shapes along the stack dim?

That is the only explanation I can come up with.
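
As a plain-PyTorch illustration of the stack-dim discussion above (the shapes and the mean reduction are assumptions for the example, not the torchrl code): one priority per item is obtained by reducing over every dim except the stack dim, which is why the stack-dim-0 assumption matters.

import torch

# Assume the sampled batch has its stack dim at 0 and the priority entry
# carries extra trailing dims (e.g. a per-step TD error).
batch_size, time_steps, extra = 8, 5, 1
td_error = torch.rand(batch_size, time_steps, extra)

# Vectorized: flatten everything after the stack dim and reduce, which
# yields exactly one priority per sampled item.
priority = td_error.reshape(batch_size, -1).mean(-1)
print(priority.shape)  # torch.Size([8])

# If the stack dim were not 0, it would need to be moved there first:
stack_dim = 1
td_error_t = torch.rand(time_steps, batch_size, extra)
priority_t = td_error_t.movedim(stack_dim, 0).reshape(batch_size, -1).mean(-1)
print(priority_t.shape)  # torch.Size([8])

If items had genuinely different priority shapes along the stack dim (the heterogeneous case raised above), this single reshape would not apply and a per-item reduction would still be needed.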

Another resolved review comment on torchrl/data/replay_buffers/replay_buffers.py (outdated).
@vmoens (Contributor) left a comment:

LGTM thanks!

@vmoens merged commit 3d2c161 into pytorch:main on Oct 4, 2023
55 of 59 checks passed
@matteobettini deleted the fix_prioritised_buffer branch on October 4, 2023 at 07:38
@vmoens added a commit to hyerra/rl that referenced this pull request on Oct 10, 2023