Autotuning not on A100, instructions, single line API for model optimizations. (#67)
cpuhrsch authored Nov 17, 2023
1 parent 7cd6ba3 commit 6a420ed
Showing 6 changed files with 117 additions and 21 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -33,6 +33,19 @@ The package acts like a drop-in replacement for segment-anything.

So, for example, if you're currently doing `from segment_anything import sam_model_registry` you should be able to do `from segment_anything_fast import sam_model_registry`.

However, you're likely here because you want to try a fast inference version. So we also created a `sam_model_fast_registry` (see the example below) that automatically
- sets `eval` mode
- uses `bfloat16`
- enables `torch.compile` with max-autotune
- uses a custom Triton kernel that implements SDPA for relative positional encodings for long sequence lengths
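
For example, here is a minimal sketch of the one-line setup. The checkpoint filename is a placeholder for whichever SAM checkpoint you downloaded, `"vit_h"` can be swapped for `"vit_l"` or `"vit_b"`, and the dummy image only exists to make the snippet self-contained:

```python
import numpy as np

from segment_anything_fast import sam_model_fast_registry, SamAutomaticMaskGenerator

# Placeholder path: point this at the SAM checkpoint you downloaded.
checkpoint = "sam_vit_h_4b8939.pth"

# One call builds the model with the optimizations listed above applied.
sam = sam_model_fast_registry["vit_h"](checkpoint=checkpoint)
sam.to("cuda")

# From here on, the usual segment-anything style API applies.
mask_generator = SamAutomaticMaskGenerator(sam)
image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # stand-in for a real HxWx3 uint8 image
masks = mask_generator.generate(image)
```

`SamPredictor` can be used in the same way if you prefer prompted masks.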

The custom Triton kernel in particular was written for the A100. If you're not using an A100, we will try to rerun autotuning on your device and save the best configs locally.
You might still run into performance issues, so you can disable the kernel by setting the environment variable `SEGMENT_ANYTHING_FAST_USE_FLASH_4=0`.
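
For example, here is a minimal sketch of disabling it from Python. This assumes the flag is read once when `segment_anything_fast` is first imported, so it has to be set before the import; exporting the variable in your shell works just as well:

```python
import os

# Assumption: the flag is read once at import time, so set it before importing.
os.environ["SEGMENT_ANYTHING_FAST_USE_FLASH_4"] = "0"

import segment_anything_fast  # the custom Triton kernel is now disabled
```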

Please also note that the first time you run this model you'll likely need to wait a bit for it to compile.

For details on how to reproduce all results, please see the README in the experiments folder above.

Please don't hesitate to open a GitHub issue if you're missing functionality or run into a problem. Thank you.

## Results
11 changes: 5 additions & 6 deletions experiments/README.md
@@ -46,16 +46,15 @@ These experiments were run on an Amazon p4d.24xlarge instance. See the Product
### Installation instructions

```
$ conda create -n nightly20231023py310
$ conda activate nightly20231023py310
$ conda create -n nightly20231117py310
$ conda activate nightly20231117py310
$ conda install python=3.10
$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231023%2Bcu121-cp310-cp310-linux_x86_64.whl
$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231023%2Bcu121-cp310-cp310-linux_x86_64.whl
$ cd /scratch/cpuhrsch/dev
$ pip install https://download.pytorch.org/whl/nightly/cu121/torch-2.2.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
$ pip install https://download.pytorch.org/whl/nightly/cu121/torchvision-0.17.0.dev20231117%2Bcu121-cp310-cp310-linux_x86_64.whl
$ git clone https://github.com/cpuhrsch/segment-anything.git
$ cd segment-anything
$ pip install -e .
$ cd /scratch/cpuhrsch/dev
$ cd ..
$ git clone https://github.com/pytorch-labs/segment-anything-fast.git
$ cd segment-anything-fast
$ pip install -e .
40 changes: 28 additions & 12 deletions experiments/run_experiments.py
@@ -144,7 +144,8 @@ def run(batch_size,
traces_dir=None,
num_workers=32,
print_header=True,
capture_output=True):
capture_output=True,
local_fork_only=False):

assert model == "vit_b" or model == "vit_h"

@@ -161,23 +162,38 @@ def run(batch_size,
assert traces_dir is not None
rt = functools.partial(run_traces_fn, traces_dir, pytorch_path, rexp)

rt("fp32", "default", print_header=print_header)
rt("fp16", "codesign", use_half="bfloat16")
rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
if local_fork_only:
rt("fp32", "local-fork", print_header=print_header)
rt("fp16", "local-fork", use_half="bfloat16")
rt("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune")
# The local fork already uses SDPA + Triton for all of the above experiments.
# local_fork_only mainly exists to ablate the order in which we apply
# techniques and cannot be used to reproduce the experimental results
else:
rt("fp32", "default", print_header=print_header)
rt("fp16", "codesign", use_half="bfloat16")
rt("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
rt("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
rt("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
if batch_size > 1:
rt("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True)
rt("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="dynamic_quant")
rt("sparse", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=True, compress="sparse")

if run_experiments:
rexp("fp32", "default", print_header=print_header)
print_header = False
rexp("bf16", "codesign", use_half="bfloat16")
rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
if local_fork_only:
rexp("fp32", "local-fork", print_header=print_header)
rexp("bf16", "local-fork", use_half="bfloat16")
rexp("compile", "local-fork", use_half="bfloat16", use_compile="max-autotune")
# The local fork already uses SDPA + Triton for all of the above experiments.
# local_fork_only mainly exists to ablate the order in which we apply
# techniques and cannot be used to reproduce the experimental results
else:
rexp("fp32", "default", print_header=print_header)
rexp("bf16", "codesign", use_half="bfloat16")
rexp("compile", "codesign", use_half="bfloat16", use_compile="max-autotune")
rexp("SDPA", "sdpa-decoder", use_half="bfloat16", use_compile="max-autotune")
rexp("Triton", "local-fork", use_half="bfloat16", use_compile="max-autotune")
if batch_size > 1:
rexp("NT", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1))
rexp("int8", "local-fork", use_half="bfloat16", use_compile="max-autotune", use_nested_tensor=(batch_size > 1), compress="dynamic_quant")
5 changes: 5 additions & 0 deletions segment_anything_fast/__init__.py
@@ -10,6 +10,11 @@
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry,
build_sam_fast,
build_sam_fast_vit_h,
build_sam_fast_vit_l,
build_sam_fast_vit_b,
sam_model_fast_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
40 changes: 40 additions & 0 deletions segment_anything_fast/build_sam.py
@@ -51,6 +51,46 @@ def build_sam_vit_b(checkpoint=None):
"vit_b": build_sam_vit_b,
}

def _apply_eval_dtype_sam(model, dtype=None):

def prep_model(model, dtype):
if dtype is not None:
return model.eval().to(dtype)
return model.eval()

model.image_encoder = prep_model(model.image_encoder, dtype)
model.prompt_encoder = prep_model(model.prompt_encoder, dtype)
model.mask_decoder = prep_model(model.mask_decoder, dtype)

return model

def build_sam_fast_vit_h(checkpoint=None):
sam = build_sam_vit_h(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
return sam

build_sam_fast = build_sam_fast_vit_h

def build_sam_fast_vit_l(checkpoint=None):
sam = build_sam_vit_l(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
return sam

def build_sam_fast_vit_b(checkpoint=None):
sam = build_sam_vit_b(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
return sam

sam_model_fast_registry = {
"default": build_sam_fast_vit_h,
"vit_h": build_sam_fast_vit_h,
"vit_l": build_sam_fast_vit_l,
"vit_b": build_sam_fast_vit_b,
}


def _build_sam(
encoder_embed_dim,
29 changes: 26 additions & 3 deletions segment_anything_fast/flash_4.py
@@ -23,6 +23,9 @@
import triton
import triton.language as tl

import os
import pathlib


@triton.jit
def _fwd_kernel_aligned(
@@ -220,9 +223,18 @@ def _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,


def _load_best_configs():
device_name = torch.cuda.get_device_name()
if not device_name.startswith('NVIDIA A100'):
print("Warning: Custom flash attention kernels were written specifically for A100.")
import importlib
saved_configs = importlib.resources.files("segment_anything_fast")
saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p"
if not device_name.startswith('NVIDIA A100'):
cwd = pathlib.Path.cwd()
saved_configs = cwd / "flash_4_configs.p"
print(f"We will try to read previously created kernel configurations from {saved_configs}.")
print("You can disable this kernel by setting SEGMENT_ANYTHING_FAST_USE_FLASH_4=0")
return None
if saved_configs.is_file():
import pickle
with open(saved_configs, 'rb') as f:
@@ -234,6 +246,11 @@ def _save_best_configs(best_configs):
import importlib
saved_configs = importlib.resources.files("segment_anything_fast")
saved_configs = saved_configs / "configs" / "flash_4_configs_a100.p"
device_name = torch.cuda.get_device_name()
if not device_name.startswith('NVIDIA A100'):
saved_configs = pathlib.Path.cwd() / "flash_4_configs.p"
print("Warning: Custom flash attention kernels were written specifically for A100.")
print(f"Storing configs for {device_name} locally under {saved_configs}")
with open(saved_configs, 'wb') as f:
import pickle
print(f"Saving best configs to file {saved_configs}")
@@ -277,7 +294,7 @@ def _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale):
BEST_CONFIGS = _load_best_configs()
key = _create_best_configs_key(q, k, v, rel_h_w, o)
if key not in BEST_CONFIGS:
print("key ", key, " not found. Running autotune")
print("key ", key, " not found. Running autotune. This might take a while.")
import functools
import itertools
configs = []
@@ -309,6 +326,9 @@ def _attention_rel_h_rel_w_kernel_aligned(q, k, v, rel_h_w, sm_scale):
return o


USE_CUSTOM_KERNEL = bool(int(os.environ.get('SEGMENT_ANYTHING_FAST_USE_FLASH_4', 1)))


def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
"""
Writing this as a composite allows torch.compile to fuse
@@ -320,15 +340,18 @@ def _attention_rel_h_rel_w(q_, k_, v_, rel_h_, rel_w_):
sm_scale = 1. / math.sqrt(q_.size(-1))
# Check if second last dimension is multiple of 256
q_size_2_padded = (((q_.size(-2) + 256 - 1) // 256) * 256) - q_.size(-2)

def kernel_guards(q_, k_, v_):
return (q_.dtype == torch.bfloat16 or q_.dtype == torch.float16) and q_.dtype == k_.dtype and k_.dtype == v_.dtype and USE_CUSTOM_KERNEL
# vit_b and vit_l
if q_size_2_padded == 0 and q_.size(-1) == 64:
if q_size_2_padded == 0 and q_.size(-1) == 64 and kernel_guards(q_, k_, v_):
rel_h_w = torch.cat([rel_h_.squeeze(-1), rel_w_.squeeze(-2)], dim=-1)
o = torch.ops.customflash.custom_flash_aligned(
q_, k_, v_, rel_h_w, sm_scale)
if o.numel() > 0:
return o
# vit_h
if q_size_2_padded == 0 and q_.size(-1) == 80:
if q_size_2_padded == 0 and q_.size(-1) == 80 and kernel_guards(q_, k_, v_):
# Only support multiples of 64, so need to pad
q = torch.nn.functional.pad(q_, (0, 128 - 80, 0, 0), "constant", 0)
k = torch.nn.functional.pad(k_, (0, 128 - 80, 0, 0), "constant", 0)
