diff --git a/segment_anything_fast/utils/amg.py b/segment_anything_fast/utils/amg.py index dcae96e..2892cc2 100644 --- a/segment_anything_fast/utils/amg.py +++ b/segment_anything_fast/utils/amg.py @@ -114,7 +114,10 @@ def mask_to_rle_pytorch_2(tensor: torch.Tensor) -> List[Dict[str, Any]]: # Compute change indices diff = tensor[:, 1:] ^ tensor[:, :-1] - a = torch.tensor([[True]]).pin_memory().cuda().expand_as(diff.narrow(1, 0, 1)) + a = torch.tensor([[True]]) + if diff.is_cuda: + a = a.pin_memory().cuda() + a = a.expand_as(diff.narrow(1, 0, 1)) diff = torch.cat([a, diff, a], dim=1) change_indices = diff.nonzero()