Add CPU support and update README (#119)
yanbing-j authored May 14, 2024
1 parent 2ad6b94 commit bd37672
Showing 4 changed files with 57 additions and 22 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -17,6 +17,10 @@ For example:
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
```
or
```
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
```

Installation instructions vary by platform. Please see the website https://pytorch.org/
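As a quick sanity check (not part of the original instructions), you can confirm which build was installed; a CPU-only nightly reports a version ending in `+cpu` and no visible CUDA device:
```
import torch

# CPU-only nightly builds report a version ending in "+cpu" and no CUDA device.
print(torch.__version__)
print(torch.cuda.is_available())
```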

21 changes: 17 additions & 4 deletions experiments/README.md
@@ -37,6 +37,7 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product
- 1152 GiB of RAM
- Software

Meanwhile, these experiments (fp32, bf16, compile, SDPA, Triton, NT) can also run on a CPU platform. Experiment results will be published in the near future.

### Versions

@@ -47,11 +48,17 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product
### Installation instructions

```
$ conda create -n nightly20231117py310
$ conda activate nightly20231117py310
$ conda create -n nightlypy310
$ conda activate nightlypy310
$ conda install python=3.10
$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
# For GPU:
$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
# For CPU:
$ pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240509%2Bcpu-cp310-cp310-linux_x86_64.whl
$ pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240509%2Bcpu-cp310-cp310-linux_x86_64.whl
$ pip install triton
$ git clone https://github.com/cpuhrsch/segment-anything.git
$ cd segment-anything
$ pip install -e .
@@ -66,10 +73,16 @@ If you plan to run the scripts that run the experiments from segment-anything-fa

### How to run experiments

For a GPU platform:
```
$ python run_experiments.py 16 vit_b <pytorch_github> <segment-anything_github> <path_to_experiments_data> --run-experiments --num-workers 32
```

For a CPU platform, set SEGMENT_ANYTHING_FAST_USE_FLASH_4 to 0, since the custom flash attention kernels were written specifically for the A100:
```
$ SEGMENT_ANYTHING_FAST_USE_FLASH_4=0 python run_experiments.py 16 vit_b <pytorch_github> <segment-anything_github> <path_to_experiments_data> --run-experiments --num-workers 32 --device cpu
```
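If you drive the run from Python rather than the shell, the same flag can be set before anything imports the package. A minimal sketch under that assumption (the flag name comes from above; the surrounding script is illustrative):
```
import os

# Disable the A100-specific flash attention kernel before the library is imported.
os.environ["SEGMENT_ANYTHING_FAST_USE_FLASH_4"] = "0"

import segment_anything_fast  # noqa: E402
```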

If at any point you run into an issue, note that you can increase verbosity by adding `--capture_output False` to the above command. Also, please don't hesitate to open an issue.


44 changes: 29 additions & 15 deletions experiments/eval_combo.py
@@ -5,6 +5,7 @@
from data import build_data, setup_coco_img_ids
import math
import segment_anything_fast
import time

torch._dynamo.config.cache_size_limit = 50000

@@ -64,10 +65,13 @@ def build_results_batch_nested(predictor, batch, batch_size, pad_input_image_bat
# We explicitly exclude data transfers from the timing to focus
# only on the kernel performance.
# Next we synchronize and set two events to start timing.
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if torch.cuda.is_available():
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
else:
    t0 = time.time()

with torch.autograd.profiler.record_function("timed region"):
with torch.autograd.profiler.record_function("image encoder"):
@@ -93,9 +97,12 @@ def build_results_batch_nested(predictor, batch, batch_size, pad_input_image_bat
# the amount of time spent on the GPU. This is a fairly tight measurement
# around the launched GPU kernels and excludes data movement from host
# to device.
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
if torch.cuda.is_available():
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time = start_event.elapsed_time(end_event)
else:
    elapsed_time = time.time() - t0
return sum(result_batch, []), orig_input_image_batch_size, elapsed_time

def build_results_batch(predictor, batch, batch_size, pad_input_image_batch):
@@ -123,10 +130,13 @@ def build_results_batch(predictor, batch, batch_size, pad_input_image_batch):
# We explicitly exclude data transfers from the timing to focus
# only on the kernel performance.
# Next we synchronize and set two events to start timing.
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if torch.cuda.is_available():
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
else:
    t0 = time.time()

with torch.autograd.profiler.record_function("timed region"):
with torch.autograd.profiler.record_function("image encoder"):
@@ -157,9 +167,12 @@ def build_results_batch(predictor, batch, batch_size, pad_input_image_batch):
# the amount of time spent on the GPU. This is a fairly tight measurement
# around the launched GPU kernels and excludes data movement from host
# to device.
end_event.record()
torch.cuda.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
if torch.cuda.is_available():
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time = start_event.elapsed_time(end_event)
else:
    elapsed_time = time.time() - t0
return result_batch, orig_input_image_batch_size, elapsed_time
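Both helpers now share a device-agnostic timing pattern: CUDA events when a GPU is present, wall-clock time otherwise. A standalone sketch of the idea (the `timed` helper and the matmul example are illustrative, not part of eval_combo.py; note that CUDA `elapsed_time` reports milliseconds while `time.time()` differences are in seconds):
```
import time
import torch

def timed(work):
    # Measure work() with CUDA events on GPU, wall clock on CPU; return (result, ms).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = work()
        end.record()
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)
    else:
        t0 = time.time()
        result = work()
        elapsed_ms = (time.time() - t0) * 1000.0  # convert seconds to milliseconds
    return result, elapsed_ms

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
_, ms = timed(lambda: x @ x)
print(f"{ms:.2f} ms")
```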


@@ -290,6 +303,7 @@ def run(
memory_path=None,
use_local_sam_fork=False,
use_compiler_settings=False,
device="cuda"
):
from torch._inductor import config as inductorconfig
inductorconfig.triton.unique_kernel_names = True
@@ -327,7 +341,7 @@
else:
from segment_anything import sam_model_registry, SamPredictor
checkpoint_path = model_type_to_checkpoint[sam_model_type]
sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).cuda()
sam = sam_model_registry[sam_model_type](checkpoint=checkpoint_path).to(torch.device(device))
predictor = SamPredictor(sam)

from segment_anything_fast import tools
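The switch from `.cuda()` to `.to(torch.device(device))` is what lets one `--device` string select the backend. A minimal illustration (with a stand-in `nn.Linear` model rather than SAM):
```
import torch

device = "cpu"  # or "cuda"
model = torch.nn.Linear(4, 4).to(torch.device(device))  # works for either backend
x = torch.randn(2, 4, device=device)
print(model(x).shape)
```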
10 changes: 7 additions & 3 deletions experiments/run_experiments.py
@@ -45,7 +45,8 @@ def run_experiment(experiments_data,
limit=None,
profile_path=None,
profile_top=False,
memory_path=None):
memory_path=None,
device="cuda"):
root_cmd = ["python", "eval_combo.py",
"--coco_root_dir",
f"{experiments_data}/datasets/coco2017",
@@ -84,6 +85,7 @@ def run_experiment(experiments_data,
args = args + ["--memory-path", memory_path]
if extra_args is None:
extra_args = []
args = args + ["--device", device]
args = args + extra_args
if print_header:
args = args + ["--print_header", "True"]
@@ -145,7 +147,8 @@ def run(batch_size,
num_workers=32,
print_header=True,
capture_output=True,
local_fork_only=False):
local_fork_only=False,
device="cuda"):

assert model == "vit_b" or model == "vit_h"

@@ -155,7 +158,8 @@
model,
batch_size=batch_size,
num_workers=num_workers,
capture_output=capture_output)
capture_output=capture_output,
device=device)

print_header = True
if run_traces:
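run_experiments.py forwards the new device choice to eval_combo.py as a subprocess argument; a stripped-down sketch of that forwarding (hypothetical helper, not the script's exact code):
```
import subprocess

def run_eval(device="cuda", capture_output=True):
    # Build the eval_combo.py command and pass the device choice through.
    args = ["python", "eval_combo.py", "--device", device]
    return subprocess.run(args, capture_output=capture_output)
```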
