Skip to content

Commit

Permalink
Faster AMG (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpuhrsch authored Dec 1, 2023
1 parent 2ed7d1a commit 97a69cf
Showing 8 changed files with 155 additions and 39 deletions.
46 changes: 39 additions & 7 deletions amg_example/amg_example.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,17 @@
import torch
import matplotlib.pyplot as plt
import cv2
import torch.utils.benchmark as benchmark

def profiler_runner(path, fn, *args, **kwargs):
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
record_shapes=True) as prof:
result = fn(*args, **kwargs)
print(f"Saving trace under {path}")
prof.export_chrome_trace(path)
return result

def show_anns(anns):
if len(anns) == 0:
@@ -22,25 +33,46 @@ def show_anns(anns):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything_fast.tools import apply_eval_dtype_predictor
from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator

sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam, process_batch_size=8)

mask_generator = SamAutomaticMaskGenerator(sam)
mask_generator.predictor = apply_eval_dtype_predictor(mask_generator.predictor, torch.bfloat16)

# Run thrice for warmup
masks = mask_generator.generate(image)
masks = mask_generator.generate(image)
masks = mask_generator.generate(image)

# Save an example
plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.tight_layout()
plt.savefig('dog_mask_fast.png', format='png')

# Benchmark
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(10):
masks = mask_generator.generate(image)
end_event.record()
torch.cuda.synchronize()
print(start_event.elapsed_time(end_event) / 10.)

# Save a GPU trace
profiler_runner(f"amg_example_trace.json.gz", mask_generator.generate, image)

# Write out memory usage
max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
_, total_memory = torch.cuda.mem_get_info()
max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")
Binary file added amg_example/amg_example_trace.json.gz
Binary file not shown.
Binary file modified amg_example/dog_mask_fast.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
66 changes: 52 additions & 14 deletions segment_anything_fast/automatic_mask_generator.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
mask_to_rle_pytorch_2,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
@@ -49,6 +50,7 @@ def __init__(
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
process_batch_size: Optional[int] = None,
) -> None:
"""
Using a SAM model, generates masks for the entire image.
@@ -93,6 +95,10 @@ def __init__(
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
process_batch_size (int or None): Set a batch size for the decoding step.
If None, all points will be batched up at once. Set a small number here
to decrease memory footprint. A smaller number will likely decrease
latency, but also decrease memory usage.
"""

assert (points_per_side is None) != (
@@ -132,6 +138,7 @@ def __init__(
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
self.process_batch_size = process_batch_size

@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
@@ -241,10 +248,13 @@ def _process_crop(

# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
all_points = [points for (points,) in batch_iterator(self.points_per_batch, points_for_image)]
process_batch_size = len(all_points) if self.process_batch_size is None else self.process_batch_size
for i in range(0, len(all_points), process_batch_size):
some_points = all_points[i:i+process_batch_size]
batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
data["rles"] = mask_to_rle_pytorch_2(data["masks"])
self.predictor.reset_image()

# Remove duplicates within this crop.
@@ -265,24 +275,50 @@ def _process_crop(

def _process_batch(
self,
points: np.ndarray,
all_points: List[np.ndarray],
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size

# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
nt_in_points = []
for points in all_points:
# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points) #, device=self.predictor.device)
nt_in_points.append(in_points)

nt_in_points = torch.nested.nested_tensor(nt_in_points, layout=torch.jagged, pin_memory=True).to(device=self.predictor.device, non_blocking=True)
# The call to prod is a workaround to share jagged sizes between two NestedTensors.
nt_in_labels = torch.ones_like(nt_in_points, dtype=torch.int).prod(dim=-1, keepdim=True)
nt_in_points = nt_in_points.unsqueeze(2)

self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
point_coords=nt_in_points,
point_labels=nt_in_labels,
multimask_output=True,
return_logits=True,
)

data = MaskData()
for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
batch_data = self._process_batch_2(masks, iou_preds, points, im_size, crop_box, orig_size)
data.cat(batch_data)
return data

# TODO: Batch this up
def _process_batch_2(
self,
masks: torch.Tensor,
iou_preds: torch.Tensor,
points: torch.Tensor,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
@@ -315,8 +351,10 @@ def _process_batch(

# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
# Doing this once at the end across all masks.
# data["rles"] = mask_to_rle_pytorch(data["masks"].cpu())
# Keeping the masks around is faster, even though it uses more memory.
# del data["masks"]

return data

20 changes: 10 additions & 10 deletions segment_anything_fast/build_sam.py
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ def build_sam_vit_b(checkpoint=None):
"vit_b": build_sam_vit_b,
}

def _apply_eval_dtype_sam(model, dtype=None):
def _apply_eval_dtype_sam(model, dtype):

def prep_model(model, dtype):
if dtype is not None:
@@ -64,24 +64,24 @@ def prep_model(model, dtype):

return model

def build_sam_fast_vit_h(checkpoint=None):
def build_sam_fast_vit_h(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
sam = build_sam_vit_h(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
sam = _apply_eval_dtype_sam(sam, dtype)
sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
return sam

build_sam_fast = build_sam_fast_vit_h

def build_sam_fast_vit_l(checkpoint=None):
def build_sam_fast_vit_l(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
sam = build_sam_vit_l(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
sam = _apply_eval_dtype_sam(sam, dtype)
sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
return sam

def build_sam_fast_vit_b(checkpoint=None):
def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
sam = build_sam_vit_b(checkpoint)
sam = _apply_eval_dtype_sam(sam)
sam.image_encoder = torch.compile(sam.image_encoder, mode='max-autotune')
sam = _apply_eval_dtype_sam(sam, dtype)
sam.image_encoder = torch.compile(sam.image_encoder, mode=compile_mode)
return sam

sam_model_fast_registry = {
8 changes: 1 addition & 7 deletions segment_anything_fast/modeling/prompt_encoder.py
Original file line number Diff line number Diff line change
@@ -157,13 +157,10 @@ def forward(
torch.Tensor: dense embeddings for the masks, in the shape
Bx(embed_dim)x(embed_H)x(embed_W)
"""
return_dtype = None
bs = self._get_batch_size(points, boxes, masks)
if points is not None:
coords, labels = points
sparse_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
if sparse_embeddings.dtype != coords.dtype:
return_dtype = coords.dtype
if boxes is not None:
sparse_embeddings = self._embed_boxes(boxes)

@@ -183,10 +180,7 @@ def forward(
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])

r0, r1 = sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings
if return_dtype is None:
return r0, r1
return r0.to(return_dtype), r1.to(return_dtype)
return sparse_embeddings.to(dense_embeddings.dtype), dense_embeddings


class PositionEmbeddingRandom(nn.Module):
36 changes: 35 additions & 1 deletion segment_anything_fast/utils/amg.py
Original file line number Diff line number Diff line change
@@ -72,7 +72,7 @@ def cat(self, new_stats: "MaskData") -> None:
def to_numpy(self) -> None:
for k, v in self._stats.items():
if isinstance(v, torch.Tensor):
self._stats[k] = v.detach().cpu().numpy()
self._stats[k] = v.detach().cpu().float().numpy()


def is_box_near_crop_edge(
@@ -103,6 +103,40 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]

def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
Encodes masks to an uncompressed RLE, in the format expected by
pycoco tools.
"""
# Put in fortran order and flatten h,w
b, h, w = tensor.shape
tensor = tensor.permute(0, 2, 1).flatten(1)

# Compute change indices
diff = tensor[:, 1:] ^ tensor[:, :-1]
a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1))
diff = torch.cat([a, diff, a], dim=1)
change_indices = diff.nonzero()

alt_lens = diff.sum(dim=1).tolist()

all_cur_idx = change_indices[:, 1]
all_btw_idx = torch.cat([all_cur_idx[1:], all_cur_idx[:1]]) - all_cur_idx
all_btw_idx = all_btw_idx.detach().cpu().tolist()

# Encode run length
out = []
counts_init = (tensor[:, 0] == 0).tolist()
offset = 0
for i, ci in zip(range(b), counts_init):
btw_idxs = all_btw_idx[offset:offset + alt_lens[i]][:-1]
offset += alt_lens[i]
counts = [] if ci else [0]
counts.extend(btw_idxs)
out.append({"size": [h, w], "counts": counts})

return out


def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
"""
18 changes: 18 additions & 0 deletions test/test_mask_to_rle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
import itertools
from segment_anything_fast.utils.amg import (
mask_to_rle_pytorch,
mask_to_rle_pytorch_2,
)

def test_masks(masks):
rles_0 = mask_to_rle_pytorch(masks)
rles_2 = mask_to_rle_pytorch_2(masks)

for i in range(len(rles_0)):
torch.testing.assert_close(torch.tensor(rles_0[i]['counts']), torch.tensor(rles_2[i]['counts']))

for b, w, h in itertools.product([1, 5], [50, 128], [50, 128]):
test_masks(torch.randn(b, w, h).clamp(min=0).bool().cuda())
test_masks(torch.randn(b, w, h).mul(0).bool().cuda())
test_masks(torch.randn(b, w, h).mul(0).add(1).bool().cuda())

0 comments on commit 97a69cf

Please sign in to comment.