From 5dc3b3c0f65f464d04ba5a26e7c6a4ef6dc507f4 Mon Sep 17 00:00:00 2001 From: Christian Puhrsch Date: Wed, 13 Dec 2023 20:57:40 +0000 Subject: [PATCH] Support CPU for mask_to_rle_pytorch_2 --- segment_anything_fast/utils/amg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/segment_anything_fast/utils/amg.py b/segment_anything_fast/utils/amg.py index dcae96e..2892cc2 100644 --- a/segment_anything_fast/utils/amg.py +++ b/segment_anything_fast/utils/amg.py @@ -114,7 +114,10 @@ def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]: # Compute change indices diff = tensor[:, 1:] ^ tensor[:, :-1] - a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1)) + a = torch.tensor([[True]]) + if diff.is_cuda: + a = a.pin_memory().cuda() + a = a.expand_as(diff.narrow(1, 0, 1)) diff = torch.cat([a, diff, a], dim=1) change_indices = diff.nonzero()