Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added multithreading to OTPlanSampler for "exact" solver #131

Merged
merged 2 commits into from
Aug 9, 2024

Conversation

yashizhang
Copy link
Contributor

What does this PR do?

This PR adds support for OpenMP multithreading via the OTPlanSampler on initialization for method == "exact". Here's the link for Python OT's documentation of the pot.emd method.

On my HPC (AMD EPYC 7742 CPU, A100 GPU), running num_threads=2 or num_threads=4 gave up to 2x speedups. This is hardware dependent so the default value will be num_threads=1, as it has been.

If you would like to test threading performance yourself, here's the code that I used:

import argparse 
import time 
import torch 
import numpy as np 
import multiprocessing
from functools import partial
from torchcfm.optimal_transport import OTPlanSampler 
from torchcfm.utils import sample_8gaussians, sample_moons
from tqdm import tqdm 


def parse_args():
   parser = argparse.ArgumentParser()
   parser.add_argument("--num_samples", type=int, default=500)
   parser.add_argument("--num_rounds", type=int, default=50)
   return parser.parse_args()


if __name__ == "__main__":
   args = parse_args()


   samplers = {}
   for num_threads in [1, 2, 4, 8, 16, 32, 64]:
       samplers[num_threads] = OTPlanSampler(method="exact", num_threads=num_threads)
   #samplers[-1] = OTPlanSampler(method="exact", num_threads="max")


   torch.manual_seed(42)
   np.random.seed(42)


   samples = []
   for _ in range(args.num_rounds):
       x0 = sample_8gaussians(args.num_samples)
       x1 = sample_moons(args.num_samples)
       samples.append((x0, x1))


   # Test with POT implementation
   times = {}
   for num_threads, sampler in samplers.items():
       start = time.time()
       for x0, x1 in tqdm(samples):
           sampler.get_map(x0, x1)
       times[num_threads] = (time.time() - start)
   print("="*50)
   print(f"Using POT Implementation, Num. Samples: {args.num_samples}, Num. Rounds: {args.num_rounds}") 
   print("="*50)
   for num_threads, _time in times.items():
       print(f"Num. Threads: {num_threads}, Time: {_time:.2f}s")

Before submitting

  • Did you make sure title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with pytest command?
  • Did you run pre-commit hooks with pre-commit run -a command?

@yashizhang
Copy link
Contributor Author

yashizhang commented Aug 9, 2024

I forgot that Python versions < 3.10 do not support type | type, so my type hint int | str needs to be modified. Added from typing import Union and changed to Union[int, str]

@atong01
Copy link
Owner

atong01 commented Aug 9, 2024

LGTM thank you for the PR.

Just going to put the timing information here on your system for reference.

AMD EPYC 7742 CPU, A100 GPU

Using POT Implementation, Num. Samples: 500, Num. Rounds: 50

Num. Threads: 1, Time: 14.23s
Num. Threads: 2, Time: 7.58s
Num. Threads: 4, Time: 10.67s
Num. Threads: 8, Time: 16.80s
Num. Threads: 16, Time: 23.88s
Num. Threads: 32, Time: 42.76s
Num. Threads: 64, Time: 108.10s

@atong01 atong01 merged commit b4525b5 into atong01:main Aug 9, 2024
31 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants