
Optimal transport with torch + GPU? #136

Open
spinjo opened this issue Sep 28, 2024 · 1 comment


spinjo commented Sep 28, 2024

In the current implementation, tensors are moved to NumPy on the CPU before calling the optimal transport solver; see e.g. https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py#L88.

Since version 0.8, the POT package has supported backends beyond NumPy, as well as GPU acceleration; see https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends-on-cpu-gpu. This can speed up the OT solver, especially for large batch sizes, and enables new features such as differentiating through the solver. Is there a reason why torchcfm sticks with the NumPy + CPU approach?

I am successfully using POT's torch + GPU support and would be happy to file a PR if there is interest in including this in torchcfm.
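For reference, a minimal sketch of what I have in mind; the batch size, uniform marginals, and squared-Euclidean cost below are purely illustrative, not taken from torchcfm:

```python
import torch
import ot  # POT >= 0.8 dispatches on the type of the input tensors

# Illustrative batches; in torchcfm these would be the minibatch samples.
x0 = torch.randn(256, 2, device="cuda")
x1 = torch.randn(256, 2, device="cuda")

# Uniform marginals, kept on the GPU as torch tensors.
a = torch.full((256,), 1 / 256, device="cuda")
b = torch.full((256,), 1 / 256, device="cuda")

# Squared Euclidean cost matrix, computed on the GPU.
M = torch.cdist(x0, x1) ** 2

# With the torch backend, the coupling comes back as a torch tensor,
# so no explicit numpy/CPU conversion is needed in user code.
pi = ot.emd(a, b, M)
```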

atong01 (Owner) commented Sep 30, 2024

Hi,

Thanks for opening this issue! I haven't tried this and am very curious whether you've found any performance benefits.

From my understanding, there are no fast exact OT solvers on the GPU, only Sinkhorn-based (entropic) ones. If you look under the hood of POT, I believe that even with a GPU-capable backend, exact OT is still solved on the CPU:
https://pythonot.github.io/quickstart.html#gpu-acceleration
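To illustrate the distinction (the sizes and regularization strength below are arbitrary): with the torch backend, the Sinkhorn solver can run entirely on the GPU, whereas `ot.emd` still calls the CPU network-simplex solver internally even when handed GPU tensors:

```python
import torch
import ot

n = 1024
x0 = torch.randn(n, 2, device="cuda")
x1 = torch.randn(n, 2, device="cuda")
a = torch.full((n,), 1 / n, device="cuda")
b = torch.full((n,), 1 / n, device="cuda")
M = torch.cdist(x0, x1) ** 2

# Exact OT: the plan comes back as a torch tensor, but the network-simplex
# solve itself happens on the CPU, so the cost matrix is copied off-device.
pi_exact = ot.emd(a, b, M)

# Entropic OT: the Sinkhorn iterations are plain tensor ops and stay on the GPU.
pi_entropic = ot.sinkhorn(a, b, M, reg=0.05)
```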

When I implemented this package, there was a bug in POT that allocated extra GPU memory when using the PyTorch backend (see PythonOT/POT#523). I also found that under DDP, POT seemed to perform its operations on only one of the GPUs (incurring GPU-to-GPU transfer costs). So I didn't explore this option further at the time and manually disabled the backend.

I haven't checked whether they are doing anything smarter than what I implemented here. I would be happy to include this in torchcfm, at least as an option, if you file a PR. Before making it the default, I would like to know that:

  1. It's at least not slower.
  2. It works seamlessly on multi-GPU setups.

Do you know of any methods that make use of back-propping through the solver in this setting? As you mention, it could be useful there.
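(For what it's worth, with POT's torch backend, back-propping through the entropic solver would look roughly like the sketch below; the shapes and regularization value are made up for illustration.)

```python
import torch
import ot

x0 = torch.randn(128, 2, device="cuda", requires_grad=True)
x1 = torch.randn(128, 2, device="cuda")
a = torch.full((128,), 1 / 128, device="cuda")
b = torch.full((128,), 1 / 128, device="cuda")

# Entropic OT cost computed with autograd-tracked torch ops.
M = torch.cdist(x0, x1) ** 2
loss = ot.sinkhorn2(a, b, M, reg=0.1)

# Gradients flow through the Sinkhorn iterations back to the samples.
loss.backward()
print(x0.grad.shape)  # torch.Size([128, 2])
```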

Thanks again for your interest,
Alex
