From 4f919b78f47dd55b176ae2a3161ecfb4c72967b2 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 7 Jun 2024 10:18:04 +0900 Subject: [PATCH 01/31] Add initial version of SAM wrappers --- model_api/python/model_api/models/__init__.py | 3 + .../python/model_api/models/sam_models.py | 195 ++++++++++++++++++ model_api/python/model_api/models/utils.py | 25 +++ .../model_api/models/visual_prompting.py | 86 ++++++++ 4 files changed, 309 insertions(+) create mode 100644 model_api/python/model_api/models/sam_models.py create mode 100644 model_api/python/model_api/models/visual_prompting.py diff --git a/model_api/python/model_api/models/__init__.py b/model_api/python/model_api/models/__init__.py index bea47b39..1a5e6e3a 100644 --- a/model_api/python/model_api/models/__init__.py +++ b/model_api/python/model_api/models/__init__.py @@ -39,6 +39,7 @@ from .segmentation import SalientObjectDetectionModel, SegmentationModel from .ssd import SSD from .ultra_lightweight_face_detection import UltraLightweightFaceDetection +from .sam_models import SAMDecoder, SAMImageEncoder from .utils import ( AnomalyResult, ClassificationResult, @@ -118,6 +119,8 @@ "YOLOv8", "YOLOF", "YOLOX", + "SAMDecoder", + "SAMImageEncoder", "ClassificationResult", "Detection", "DetectionResult", diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py new file mode 100644 index 00000000..4e9df8c6 --- /dev/null +++ b/model_api/python/model_api/models/sam_models.py @@ -0,0 +1,195 @@ +""" + Copyright (C) 2024 Intel Corporation + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+""" + +from copy import deepcopy +from typing import Any + +import numpy as np +from model_api.models import ImageModel, SegmentationModel +from model_api.models.types import BooleanValue, NumericalValue, StringValue + +from model_api.adapters.inference_adapter import InferenceAdapter + + +class SAMImageEncoder(ImageModel): + """Image Encoder for SAM: https://arxiv.org/abs/2304.02643""" + + __model__ = "sam_image_encoder" + + def __init__( + self, + inference_adapter: InferenceAdapter, + configuration: dict[str, Any] | None = None, + preload: bool = False, + ): + super().__init__(inference_adapter, configuration, preload) + + @classmethod + def parameters(cls) -> dict[str, Any]: + parameters = super().parameters() + parameters.update( + { + "image_size": NumericalValue( + value_type=int, default_value=1024, min=0, max=2048 + ), + }, + ) + return parameters + + def preprocess( + self, inputs: np.ndarray + ) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + """Update meta for image encoder.""" + dict_inputs, meta = super().preprocess(inputs) + meta["resize_type"] = self.resize_type + return dict_inputs, meta + + +class SAMDecoder(SegmentationModel): + """Image Decoder for SAM: https://arxiv.org/abs/2304.02643""" + + __model__ = "sam_decoder" + + def __init__( + self, + model_adapter: InferenceAdapter, + configuration: dict | None = None, + preload: bool = False, + ): + super().__init__(model_adapter, configuration, preload) + + self.mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) + self.has_mask_input = np.zeros((1, 1), dtype=np.float32) + + @classmethod + def parameters(cls) -> dict[str, Any]: + parameters = super().parameters() + parameters.update( + { + "image_size": NumericalValue( + value_type=int, default_value=1024, min=0, max=2048 + ) + } + ) + parameters.update( + { + "mask_threshold": NumericalValue( + value_type=float, default_value=0.0, min=0, max=1 + ) + } + ) + parameters.update( + { + "embed_dim": NumericalValue( + value_type=int, default_value=256, min=0, max=512 + ) + } + ) + parameters.update({"embedded_processing": BooleanValue(default_value=True)}) + return parameters + + def _get_outputs(self) -> str: + return "upscaled_masks" + + def preprocess(self, inputs: dict[str, Any]) -> list[dict[str, Any]]: + """Preprocess prompts.""" + processed_prompts: list[dict[str, Any]] = [] + for prompt_name in ["bboxes", "points"]: + if (prompts := inputs.get(prompt_name, None)) is None or ( + labels := inputs["labels"].get(prompt_name, None) + ) is None: + continue + + for prompt, label in zip(prompts, labels): + if prompt_name == "bboxes": + point_coords = self.apply_coords( + prompt.reshape(-1, 2, 2), inputs["orig_size"] + ) + point_labels = np.array([2, 3], dtype=np.float32).reshape(-1, 2) + else: + point_coords = self.apply_coords( + prompt.reshape(-1, 1, 2), inputs["orig_size"] + ) + point_labels = np.array([1], dtype=np.float32).reshape(-1, 1) + + processed_prompts.append( + { + "point_coords": point_coords, + "point_labels": point_labels, + "mask_input": self.mask_input, + "has_mask_input": self.has_mask_input, + "orig_size": np.array( + inputs["orig_size"], dtype=np.int64 + ).reshape(-1, 2), + "label": label, + }, + ) + return processed_prompts + + def apply_coords( + self, coords: np.ndarray, orig_size: np.ndarray | list[int] | tuple[int, int] + ) -> np.ndarray: + """Process coords according to preprocessed image size using image meta.""" + old_h, old_w = orig_size + new_h, new_w = self._get_preprocess_shape(old_h, old_w, self.image_size) + coords = 
deepcopy(coords).astype(np.float32) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def _get_preprocess_shape( + self, old_h: int, old_w: int, image_size: int + ) -> tuple[int, int]: + """Compute the output size given input size and target image size.""" + scale = image_size / max(old_h, old_w) + new_h, new_w = old_h * scale, old_w * scale + new_w = int(new_w + 0.5) + new_h = int(new_h + 0.5) + return (new_h, new_w) + + def _check_io_number( + self, number_of_inputs: int | tuple[int], number_of_outputs: int | tuple[int] + ) -> None: + pass + + def _get_inputs(self) -> tuple[list[str], list[str]]: + """Get input layer name and shape.""" + image_blob_names = list(self.inputs.keys()) + image_info_blob_names: list = [] + return image_blob_names, image_info_blob_names + + def postprocess( + self, outputs: dict[str, np.ndarray], meta: dict[str, Any] + ) -> dict[str, np.ndarray]: + """Postprocess to convert soft prediction to hard prediction. + + Args: + outputs (dict[str, np.ndarray]): The output of the model. + meta (dict[str, Any]): Contain label and original size. + + Returns: + (dict[str, np.ndarray]): The postprocessed output of the model. + """ + probability = np.clip(outputs["scores"], 0.0, 1.0) + hard_prediction = ( + outputs[self.output_blob_name].squeeze(1) > self.mask_threshold + ).astype(np.uint8) + soft_prediction = hard_prediction * probability + + outputs["hard_prediction"] = hard_prediction + outputs["soft_prediction"] = soft_prediction + + return outputs diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index 592e6f1b..a8e649b6 100644 --- a/model_api/python/model_api/models/utils.py +++ b/model_api/python/model_api/models/utils.py @@ -136,6 +136,31 @@ def __str__(self): return f"{obj_str}; {filled}; [{','.join(str(i) for i in self.feature_vector.shape)}]" +class VisualPromptingResult(NamedTuple): + upscaled_masks: list[np.ndarray] | None = None + low_res_masks: list[np.ndarray] | None = None + iou_predictions: list[np.ndarray] | None = None + scores: list[np.ndarray] | None = None + labels: list[np.ndarray] | None = None + hard_predictions: list[np.ndarray] | None = None + soft_predictions: list[np.ndarray] | None = None + + def _compute_min_max(self, tensor: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + return tensor.min(), tensor.max() + + def __str__(self) -> str: + assert self.hard_predictions is not None + assert self.upscaled_masks is not None + upscaled_masks_min, upscaled_masks_max = self._compute_min_max( + self.upscaled_masks[0] + ) + + return ( + f"upscaled_masks min:{upscaled_masks_min} max:{upscaled_masks_max};" + f"hard_predictions shape:{self.hard_predictions[0].shape};" + ) + + def add_rotated_rects(segmented_objects): objects_with_rects = [] for segmented_object in segmented_objects: diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py new file mode 100644 index 00000000..554ec85b --- /dev/null +++ b/model_api/python/model_api/models/visual_prompting.py @@ -0,0 +1,86 @@ +""" + Copyright (C) 2024 Intel Corporation + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +from typing import Any + +import numpy as np + +from model_api.models import SAMImageEncoder, SAMDecoder +from model_api.models.utils import VisualPromptingResult + + +class SAMVisualPrompter: + def __init__( + self, + encoder_model: SAMImageEncoder, + decoder_model: SAMDecoder, + ): + self.encoder_model = encoder_model + self.decoder_model = decoder_model + + def infer( + self, + image: np.ndarray, + boxes: np.ndarray | None, + points: np.ndarray | None, + labels: dict[str, np.ndarray] | None, + ) -> VisualPromptingResult: + outputs: list[dict[str, Any]] = [] + + processed_image, meta = self.encoder_model.preprocess(image) + image_embeddings = self.encoder_model.infer_sync(processed_image) + processed_prompts = self.decoder_model.preprocess( + { + "bboxes": boxes, + "points": points, + "labels": labels, + "orig_size": meta["original_shape"][:2], + }, + ) + + for prompt in processed_prompts: + label = prompt.pop("label") + prompt.update(**image_embeddings) + + prediction = self.decoder_model.infer_sync(prompt) + prediction["scores"] = prediction["iou_predictions"] + prediction["labels"] = label + processed_prediction = self.decoder_model.postprocess(prediction, meta) + outputs.append(processed_prediction) + + return VisualPromptingResult( + upscaled_masks=[item["upscaled_masks"] for item in outputs], + low_res_masks=[item["low_res_masks"] for item in outputs], + iou_predictions=[item["iou_predictions"] for item in outputs], + scores=[item["scores"] for item in outputs], + labels=[item["labels"] for item in outputs], + hard_predictions=[item["hard_prediction"] for item in outputs], + soft_predictions=[item["soft_prediction"] for item in outputs], + ) + + def __call__(self, + image: np.ndarray, + boxes: np.ndarray | None, + points: np.ndarray | None, + labels: dict[str, np.ndarray] | None, + ) -> VisualPromptingResult: + return self.infer(image, boxes, points, labels) + + +class SAMLearnableVisualPrompter(SAMVisualPrompter): + def learn(self, image, prompts, reset_ref_featires: bool = False): + if reset_ref_featires or self.ref_embeddings_state is None: + self._reset_inference_features() + + def infer(self, image, reference_features): + pass From c2a0361b46b668e11f3d96535536e37057c39b12 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 8 Jun 2024 11:00:00 +0900 Subject: [PATCH 02/31] Add initial implementation of ZSL --- .../model_api/models/visual_prompting.py | 587 +++++++++++++++++- 1 file changed, 576 insertions(+), 11 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 554ec85b..d2c67de3 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -12,8 +12,12 @@ """ from typing import Any +from copy import deepcopy +from collections import defaultdict +from itertools import product import numpy as np +import cv2 from model_api.models import SAMImageEncoder, SAMDecoder from model_api.models.utils import VisualPromptingResult @@ -68,19 +72,580 @@ def infer( soft_predictions=[item["soft_prediction"] for item in 
outputs], ) - def __call__(self, - image: np.ndarray, - boxes: np.ndarray | None, - points: np.ndarray | None, - labels: dict[str, np.ndarray] | None, + def __call__( + self, + image: np.ndarray, + boxes: np.ndarray | None, + points: np.ndarray | None, + labels: dict[str, np.ndarray] | None, ) -> VisualPromptingResult: return self.infer(image, boxes, points, labels) -class SAMLearnableVisualPrompter(SAMVisualPrompter): - def learn(self, image, prompts, reset_ref_featires: bool = False): - if reset_ref_featires or self.ref_embeddings_state is None: - self._reset_inference_features() +class SAMLearnableVisualPrompter: + def __init__( + self, + encoder_model: SAMImageEncoder, + decoder_model: SAMDecoder, + reference_features: np.ndarray | None = None, + ): + self.encoder_model = encoder_model + self.decoder_model = decoder_model + self.reference_features = reference_features + self.used_indices = None + + self.point_labels_box = np.array([[2, 3]], dtype=np.float32) + self.has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] + + self.is_cascade: bool = True + self.threshold: float = 0.0 + self.num_bg_points: int = 1 + self.default_threshold_target: float = 0.65 + self.image_size: int = self.encoder_model.image_size + self.downsizing: int = 64 + self.default_threshold_reference: float = 0.3 + + if self.reference_features is None: + self.reset_reference_info() + + def has_reference_features(self) -> bool: + return self.reference_features is not None + + def learn( + self, + image: np.ndarray, + boxes: np.ndarray | None, + points: np.ndarray | None, + labels: dict[str, np.ndarray] | None, + ): + processed_image, meta = self.encoder_model.preprocess(image) + processed_prompts = self.decoder_model.preprocess( + { + "bboxes": boxes, + "points": points, + "labels": labels, + "orig_size": meta["original_shape"][:2], + }, + ) + + processed_prompts_w_labels = self._gather_prompts_with_labels(processed_prompts) + largest_label: int = max([int(p) for p in processed_prompts_w_labels] + [0]) + self._expand_reference_info(largest_label) + + original_shape = np.array(meta["original_shape"][:2]) + + # forward image encoder + image_embeddings = self.encoder_model.infer_sync(processed_image) + processed_embedding = ( + image_embeddings["image_embeddings"].squeeze().transpose(1, 2, 0) + ) + + # get reference masks + ref_masks: np.ndarray = np.zeros( + (largest_label + 1, *original_shape), dtype=np.uint8 + ) + for label, input_prompts in processed_prompts_w_labels.items(): + ref_mask: np.ndarray = np.zeros(original_shape, dtype=np.uint8) + for inputs_decoder in input_prompts: + label = inputs_decoder.pop("label") # noqa: PLW2901 + if "point_coords" in inputs_decoder: + # bboxes and points + inputs_decoder.update(image_embeddings) + prediction = self._predict_masks( + inputs_decoder, original_shape, is_cascade=self.is_cascade + ) + masks = prediction["upscaled_masks"] + else: + # log.warning("annotation and polygon will be supported.") + continue + ref_mask[masks] += 1 + ref_mask = np.clip(ref_mask, 0, 1) + + ref_feat: np.ndarray | None = None + cur_default_threshold_reference = deepcopy(self.default_threshold_reference) + while ref_feat is None: + # log.info(f"[*] default_threshold_reference : {cur_default_threshold_reference:.4f}") + ref_feat = self._generate_masked_features( + feats=processed_embedding, + masks=ref_mask, + threshold_mask=cur_default_threshold_reference, + image_size=self.encoder_model.image_size, + ) + cur_default_threshold_reference -= 0.05 + + self.reference_feats[label] = ref_feat + 
self.used_indices: np.ndarray = np.concatenate((self.used_indices, label)) + ref_masks[label] = ref_mask + + self.used_indices = np.unique(self.used_indices) + + return { + "reference_feats": self.reference_feats, + "used_indices": self.used_indices, + }, ref_masks + + def reset_reference_info(self) -> None: + """Initialize reference information.""" + self.reference_feats = np.zeros( + (0, 1, self.decoder_model.embed_dim), dtype=np.float32 + ) + self.used_indices = np.array([], dtype=np.int64) + + def _gather_prompts_with_labels( + self, + image_prompts: list[dict[str, np.ndarray]], + ) -> dict[int, list[dict[str, np.ndarray]]]: + """Gather prompts according to labels.""" + + processed_prompts: defaultdict[int, list[dict[str, np.ndarray]]] = defaultdict( + list + ) + for prompt in image_prompts: + processed_prompts[int(prompt["label"])].append(prompt) + + return dict(sorted(processed_prompts.items(), key=lambda x: x)) + + def _expand_reference_info(self, new_largest_label: int) -> None: + """Expand reference info dimensions if newly given processed prompts have more lables.""" + if new_largest_label > (cur_largest_label := len(self.reference_feats) - 1): + diff = new_largest_label - cur_largest_label + self.reference_feats = np.pad( + self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 + ) + + def _generate_masked_features( + self, + feats: np.ndarray, + masks: np.ndarray, + threshold_mask: float, + image_size: int = 1024, + ) -> np.ndarray | None: + """Generate masked features. + + Args: + feats (np.ndarray): Raw reference features. It will be filtered with masks. + masks (np.ndarray): Reference masks used to filter features. + threshold_mask (float): Threshold to control masked region. + image_size (int): Input image size. + + Returns: + (np.ndarray): Masked features. + """ + target_shape = image_size / max(masks.shape) * np.array(masks.shape) + target_shape = target_shape[::-1].astype(np.int32) + + # Post-process masks + masks = cv2.resize(masks, target_shape, interpolation=cv2.INTER_LINEAR) + masks = self._pad_to_square(masks, image_size) + masks = cv2.resize(masks, feats.shape[:2][::-1], interpolation=cv2.INTER_LINEAR) + + # Target feature extraction + if (masks > threshold_mask).sum() == 0: + # (for stability) there is no area to be extracted + return None + + masked_feat = feats[masks > threshold_mask] + masked_feat = masked_feat.mean(0)[None] + return masked_feat / np.linalg.norm(masked_feat, axis=-1, keepdims=True) + + def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray: + """Pad to a square input. + + Args: + x (np.ndarray): Mask to be padded. + + Returns: + (np.ndarray): Padded mask. 
+ """ + h, w = x.shape[-2:] + padh = image_size - h + padw = image_size - w + return np.pad(x, ((0, padh), (0, padw)), constant_values=0.0) + + def _predict_masks( + self, + inputs: dict[str, np.ndarray], + original_size: np.ndarray, + is_cascade: bool = False, + ) -> dict[str, np.ndarray]: + """Process function of OpenVINO Visual Prompting Inferencer.""" + masks: np.ndarray + logits: np.ndarray + scores: np.ndarray + num_iter = 3 if is_cascade else 1 + for i in range(num_iter): + if i == 0: + # First-step prediction + mask_input = np.zeros( + (1, 1, *(x * 4 for x in inputs["image_embeddings"].shape[2:])), + dtype=np.float32, + ) + has_mask_input = self.has_mask_inputs[0] + + elif i == 1: + # Cascaded Post-refinement-1 + mask_input, masks = self._decide_masks( + masks, logits, scores, is_single=True + ) # noqa: F821 + if masks.sum() == 0: + return {"upscaled_masks": masks} + + has_mask_input = self.has_mask_inputs[1] + + elif i == 2: + # Cascaded Post-refinement-2 + mask_input, masks = self._decide_masks( + masks, logits, scores + ) # noqa: F821 + if masks.sum() == 0: + return {"upscaled_masks": masks} + + has_mask_input = self.has_mask_inputs[1] + y, x = np.nonzero(masks) + box_coords = self.decoder_model.apply_coords( + np.array( + [[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32 + ), + original_size, + ) + box_coords = np.expand_dims(box_coords, axis=0) + inputs.update( + { + "point_coords": np.concatenate( + (inputs["point_coords"], box_coords), axis=1 + ), + "point_labels": np.concatenate( + (inputs["point_labels"], self.point_labels_box), axis=1 + ), + }, + ) + + inputs.update({"mask_input": mask_input, "has_mask_input": has_mask_input}) + prediction = self.decoder_model.infer_sync(inputs) + upscaled_masks, scores, logits = ( + prediction["upscaled_masks"], + prediction["iou_predictions"], + prediction["low_res_masks"], + ) + masks = upscaled_masks > self.decoder_model.mask_threshold + + _, masks = self._decide_masks(masks, logits, scores) + return {"upscaled_masks": masks} + + def _decide_masks( + self, + masks: np.ndarray, + logits: np.ndarray, + scores: np.ndarray, + is_single: bool = False, + ) -> tuple[np.ndarray, ...] | tuple[None, np.ndarray]: + """Post-process logits for resized masks according to best index based on scores.""" + if is_single: + best_idx = 0 + else: + # skip the first index components + scores, masks, logits = (x[:, 1:] for x in (scores, masks, logits)) + + # filter zero masks + while ( + len(scores[0]) > 0 + and masks[0, (best_idx := np.argmax(scores[0]))].sum() == 0 + ): + scores, masks, logits = ( + np.concatenate((x[:, :best_idx], x[:, best_idx + 1 :]), axis=1) + for x in (scores, masks, logits) + ) + + if len(scores[0]) == 0: + # all predicted masks were zero masks, ignore them. + return None, np.zeros(masks.shape[-2:]) + + best_idx = np.argmax(scores[0]) + return logits[:, [best_idx]], masks[0, best_idx] + + def infer( + self, + image: np.ndarray, + reference_features: np.ndarray | None, + used_indices: np.ndarray | None, + ): + if reference_features is None: + if self.reference_features is None: + raise RuntimeError( + "Reference features are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" + ) + else: + reference_features = self.reference_features + + if used_indices is None: + if self.used_indices is None: + raise RuntimeError( + "Used indices are not defined. 
This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" + ) + else: + used_indices = self.used_indices + + processed_image, meta = self.encoder_model.preprocess(image) + original_shape = np.array(meta["original_shape"][:2]) + + image_embeddings = self.encoder_model.infer_sync(processed_image) + + total_points_scores, total_bg_coords = self._get_prompt_candidates( + image_embeddings=image_embeddings["image_embeddings"], + reference_feats=reference_features, + used_indices=used_indices, + original_shape=original_shape, + threshold=self.threshold, + num_bg_points=self.num_bg_points, + default_threshold_target=self.default_threshold_target, + image_size=self.image_size, + downsizing=self.downsizing, + ) + + predicted_masks: defaultdict[int, list] = defaultdict(list) + used_points: defaultdict[int, list] = defaultdict(list) + for label in total_points_scores: + points_scores = total_points_scores[label] + bg_coords = total_bg_coords[label] + for points_score in points_scores: + if points_score[-1] in [-1.0, 0.0]: + continue + + x, y = points_score[:2] + is_done = False + for pm in predicted_masks.get(label, []): + # check if that point is already assigned + if pm[int(y), int(x)] > 0: + is_done = True + break + if is_done: + continue + + point_coords = np.concatenate( + (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32 + ) + point_coords = self.decoder_model.apply_coords( + point_coords, original_shape + ) + point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) + inputs_decoder = { + "point_coords": point_coords[None], + "point_labels": point_labels[None], + "orig_size": original_shape[None], + } + inputs_decoder.update(image_embeddings) + + prediction = self._predict_masks( + inputs_decoder, original_shape, self.is_cascade + ) + prediction.update({"scores": points_score[-1]}) + + predicted_masks[label].append( + prediction[self.decoder_model.output_blob_name] + ) + used_points[label].append(points_score) + + # check overlapping area between different label masks + self._inspect_overlapping_areas(predicted_masks, used_points) + + return (predicted_masks, used_points) + + def _get_prompt_candidates( + self, + image_embeddings: np.ndarray, + reference_feats: np.ndarray, + used_indices: np.ndarray, + original_shape: np.ndarray, + threshold: float = 0.0, + num_bg_points: int = 1, + default_threshold_target: float = 0.65, + image_size: int = 1024, + downsizing: int = 64, + ) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: + """Get prompt candidates.""" + target_feat = image_embeddings.squeeze() + c_feat, h_feat, w_feat = target_feat.shape + target_feat = target_feat / np.linalg.norm(target_feat, axis=0, keepdims=True) + target_feat = target_feat.reshape(c_feat, h_feat * w_feat) + + total_points_scores: dict[int, np.ndarray] = {} + total_bg_coords: dict[int, np.ndarray] = {} + for label in used_indices: + sim = reference_feats[label] @ target_feat + sim = sim.reshape(h_feat, w_feat) + sim = self._resize_to_original_shape(sim, image_size, original_shape) + + threshold = (threshold == 0) * default_threshold_target + threshold + points_scores, bg_coords = self._point_selection( + mask_sim=sim, + original_shape=original_shape, + threshold=threshold, + num_bg_points=num_bg_points, + image_size=image_size, + downsizing=downsizing, + ) + + if points_scores is not None: + total_points_scores[label] = points_scores + total_bg_coords[label] = bg_coords + return total_points_scores, total_bg_coords + + def _point_selection( + 
self, + mask_sim: np.ndarray, + original_shape: np.ndarray, + threshold: float = 0.0, + num_bg_points: int = 1, + image_size: int = 1024, + downsizing: int = 64, + ) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]: + """Select point used as point prompts.""" + _, w_sim = mask_sim.shape + + # Top-first point selection + point_coords = np.where(mask_sim > threshold) + fg_coords_scores = np.stack( + point_coords[::-1] + (mask_sim[point_coords],), axis=0 + ).T + + ## skip if there is no point coords + if len(fg_coords_scores) == 0: + return None, None + + ratio = image_size / original_shape.max() + width = (original_shape[1] * ratio).astype(np.int64) + n_w = width // downsizing + + ## get grid numbers + idx_grid = ( + fg_coords_scores[:, 1] * ratio // downsizing * n_w + + fg_coords_scores[:, 0] * ratio // downsizing + ) + idx_grid_unique = np.unique(idx_grid.astype(np.int64)) + + ## get matched indices + matched_matrix = ( + np.expand_dims(idx_grid, axis=-1) == idx_grid_unique + ) # (totalN, uniqueN) + + ## sample fg_coords_scores matched by matched_matrix + matched_grid = np.expand_dims(fg_coords_scores, axis=1) * np.expand_dims( + matched_matrix, axis=-1 + ) + + ## sample the highest score one of the samples that are in the same grid + matched_indices = self._topk_numpy( + matched_grid[..., -1], k=1, axis=0, largest=True + )[1][0].astype(np.int64) + points_scores = matched_grid[matched_indices].diagonal().T + + ## sort by the highest score + sorted_points_scores_indices = np.flip( + np.argsort(points_scores[:, -1]), axis=-1 + ).astype(np.int64) + points_scores = points_scores[sorted_points_scores_indices] + + # Top-last point selection + bg_indices = self._topk_numpy(mask_sim.flatten(), num_bg_points, largest=False)[ + 1 + ] + bg_x = np.expand_dims(bg_indices // w_sim, axis=0) + bg_y = bg_indices - bg_x * w_sim + bg_coords = np.concatenate((bg_y, bg_x), axis=0).transpose(1, 0) + bg_coords = bg_coords.astype(np.float32) + + return points_scores, bg_coords + + def _resize_to_original_shape( + self, masks: np.ndarray, image_size: int, original_shape: np.ndarray + ) -> np.ndarray: + """Resize feature size to original shape.""" + # resize feature size to input size + masks = cv2.resize( + masks, (image_size, image_size), interpolation=cv2.INTER_LINEAR + ) + + # remove pad + prepadded_size = self._get_prepadded_size(original_shape, image_size) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] + + # resize unpadded one to original shape + original_shape = original_shape.astype(np.int64) + h, w = original_shape[0], original_shape[1] + return cv2.resize(masks, (w, h), interpolation=cv2.INTER_LINEAR) + + def _get_prepadded_size(self, original_shape: int, image_size: int) -> np.ndarray: + """Get pre-padded size.""" + scale = image_size / np.max(original_shape) + transformed_size = scale * original_shape + return np.floor(transformed_size + 0.5).astype(np.int64) + + def _topk_numpy( + self, x: np.ndarray, k: int, axis: int = -1, largest: bool = True + ) -> tuple[np.ndarray, np.ndarray]: + """Top-k function for numpy same with torch.topk.""" + if largest: + k = -k + indices = range(k, 0) + else: + indices = range(k) + partitioned_ind = np.argpartition(x, k, axis=axis).take( + indices=indices, axis=axis + ) + partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis) + sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis) + if largest: + sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis) + ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis) + 
scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis) + return scores, ind + + def _inspect_overlapping_areas( + self, + predicted_masks: dict[int, list[np.ndarray]], + used_points: dict[int, list[np.ndarray]], + threshold_iou: float = 0.8, + ) -> None: + def _calculate_mask_iou( + mask1: np.ndarray, mask2: np.ndarray + ) -> tuple[float, np.ndarray | None]: + assert mask1.ndim == 2 # noqa: S101 + assert mask2.ndim == 2 # noqa: S101 + # Avoid division by zero + if (union := np.logical_or(mask1, mask2).sum().item()) == 0: + return 0.0, None + intersection = np.logical_and(mask1, mask2) + return intersection.sum().item() / union, intersection + + for (label, masks), (other_label, other_masks) in product( + predicted_masks.items(), predicted_masks.items() + ): + if other_label <= label: + continue + + overlapped_label = [] + overlapped_other_label = [] + for (im, mask), (jm, other_mask) in product( + enumerate(masks), enumerate(other_masks) + ): + _mask_iou, _intersection = _calculate_mask_iou(mask, other_mask) + if _mask_iou > threshold_iou: + if used_points[label][im][2] > used_points[other_label][jm][2]: + overlapped_other_label.append(jm) + else: + overlapped_label.append(im) + elif _mask_iou > 0: + # refine the slightly overlapping region + overlapped_coords = np.where(_intersection) + if used_points[label][im][2] > used_points[other_label][jm][2]: + other_mask[overlapped_coords] = 0.0 + else: + mask[overlapped_coords] = 0.0 + + for im in sorted(set(overlapped_label), reverse=True): + masks.pop(im) + used_points[label].pop(im) - def infer(self, image, reference_features): - pass + for jm in sorted(set(overlapped_other_label), reverse=True): + other_masks.pop(jm) + used_points[other_label].pop(jm) From 39d6622206d149035669a3fb91afb334f424175e Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 12 Jun 2024 05:02:00 +0900 Subject: [PATCH 03/31] Update type annotation --- model_api/python/model_api/models/sam_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index 4e9df8c6..c20583f7 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -15,11 +15,11 @@ """ from copy import deepcopy -from typing import Any +from typing import Any, Dict import numpy as np from model_api.models import ImageModel, SegmentationModel -from model_api.models.types import BooleanValue, NumericalValue, StringValue +from model_api.models.types import BooleanValue, NumericalValue from model_api.adapters.inference_adapter import InferenceAdapter @@ -32,7 +32,7 @@ class SAMImageEncoder(ImageModel): def __init__( self, inference_adapter: InferenceAdapter, - configuration: dict[str, Any] | None = None, + configuration: Dict[str, Any] | None = dict(), preload: bool = False, ): super().__init__(inference_adapter, configuration, preload) @@ -66,7 +66,7 @@ class SAMDecoder(SegmentationModel): def __init__( self, model_adapter: InferenceAdapter, - configuration: dict | None = None, + configuration: Dict[str, Any] | None = dict(), preload: bool = False, ): super().__init__(model_adapter, configuration, preload) From 8612ca98b2baa91bc5f72f605afe57b41a211585 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 12 Jun 2024 05:04:07 +0900 Subject: [PATCH 04/31] Fix isort --- model_api/python/model_api/models/__init__.py | 2 +- model_api/python/model_api/models/sam_models.py | 3 +-- 
model_api/python/model_api/models/visual_prompting.py | 9 ++++----- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/model_api/python/model_api/models/__init__.py b/model_api/python/model_api/models/__init__.py index 1a5e6e3a..7efce583 100644 --- a/model_api/python/model_api/models/__init__.py +++ b/model_api/python/model_api/models/__init__.py @@ -36,10 +36,10 @@ from .nanodet import NanoDet, NanoDetPlus from .open_pose import OpenPose from .retinaface import RetinaFace, RetinaFacePyTorch +from .sam_models import SAMDecoder, SAMImageEncoder from .segmentation import SalientObjectDetectionModel, SegmentationModel from .ssd import SSD from .ultra_lightweight_face_detection import UltraLightweightFaceDetection -from .sam_models import SAMDecoder, SAMImageEncoder from .utils import ( AnomalyResult, ClassificationResult, diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index c20583f7..e6d331b5 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -18,11 +18,10 @@ from typing import Any, Dict import numpy as np +from model_api.adapters.inference_adapter import InferenceAdapter from model_api.models import ImageModel, SegmentationModel from model_api.models.types import BooleanValue, NumericalValue -from model_api.adapters.inference_adapter import InferenceAdapter - class SAMImageEncoder(ImageModel): """Image Encoder for SAM: https://arxiv.org/abs/2304.02643""" diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index d2c67de3..7ce427f2 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -11,15 +11,14 @@ limitations under the License. 
""" -from typing import Any -from copy import deepcopy from collections import defaultdict +from copy import deepcopy from itertools import product +from typing import Any -import numpy as np import cv2 - -from model_api.models import SAMImageEncoder, SAMDecoder +import numpy as np +from model_api.models import SAMDecoder, SAMImageEncoder from model_api.models.utils import VisualPromptingResult From fd74455cceb54894d8538a58ad3df1c2ad53fe6f Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 12 Jun 2024 05:09:38 +0900 Subject: [PATCH 05/31] Fix imports --- model_api/python/model_api/models/sam_models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index e6d331b5..b30df94c 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -19,9 +19,11 @@ import numpy as np from model_api.adapters.inference_adapter import InferenceAdapter -from model_api.models import ImageModel, SegmentationModel from model_api.models.types import BooleanValue, NumericalValue +from .image_model import ImageModel +from .segmentation import SegmentationModel + class SAMImageEncoder(ImageModel): """Image Encoder for SAM: https://arxiv.org/abs/2304.02643""" @@ -31,7 +33,7 @@ class SAMImageEncoder(ImageModel): def __init__( self, inference_adapter: InferenceAdapter, - configuration: Dict[str, Any] | None = dict(), + configuration: Dict[str, Any] = dict(), preload: bool = False, ): super().__init__(inference_adapter, configuration, preload) @@ -65,7 +67,7 @@ class SAMDecoder(SegmentationModel): def __init__( self, model_adapter: InferenceAdapter, - configuration: Dict[str, Any] | None = dict(), + configuration: Dict[str, Any] = dict(), preload: bool = False, ): super().__init__(model_adapter, configuration, preload) From bf972cdcc0d4840afa54fb1da3831f5e75c0ffde Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 13 Jun 2024 07:34:01 +0900 Subject: [PATCH 06/31] Fix labels concat on learn --- model_api/python/model_api/models/visual_prompting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 7ce427f2..48cb9690 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -146,7 +146,7 @@ def learn( for label, input_prompts in processed_prompts_w_labels.items(): ref_mask: np.ndarray = np.zeros(original_shape, dtype=np.uint8) for inputs_decoder in input_prompts: - label = inputs_decoder.pop("label") # noqa: PLW2901 + inputs_decoder.pop("label") if "point_coords" in inputs_decoder: # bboxes and points inputs_decoder.update(image_embeddings) @@ -173,7 +173,7 @@ def learn( cur_default_threshold_reference -= 0.05 self.reference_feats[label] = ref_feat - self.used_indices: np.ndarray = np.concatenate((self.used_indices, label)) + self.used_indices: np.ndarray = np.concatenate((self.used_indices, [label])) ref_masks[label] = ref_mask self.used_indices = np.unique(self.used_indices) From 16865032b5a0e6358e23ae25f7287fd8c5bf2d59 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 13 Jun 2024 10:33:35 +0900 Subject: [PATCH 07/31] Update image encoder --- .../python/model_api/models/sam_models.py | 6 +++++ .../model_api/models/visual_prompting.py | 22 ++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) 
diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index b30df94c..acc4f549 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -37,6 +37,7 @@ def __init__( preload: bool = False, ): super().__init__(inference_adapter, configuration, preload) + self.output_name: str = list(self.outputs.keys())[0] @classmethod def parameters(cls) -> dict[str, Any]: @@ -58,6 +59,11 @@ def preprocess( meta["resize_type"] = self.resize_type return dict_inputs, meta + def postprocess( + self, outputs: dict[str, np.ndarray], meta: dict[str, Any] + ) -> np.ndarray: + return outputs[self.output_name] + class SAMDecoder(SegmentationModel): """Image Decoder for SAM: https://arxiv.org/abs/2304.02643""" diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 48cb9690..8c10a976 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -117,13 +117,12 @@ def learn( points: np.ndarray | None, labels: dict[str, np.ndarray] | None, ): - processed_image, meta = self.encoder_model.preprocess(image) processed_prompts = self.decoder_model.preprocess( { "bboxes": boxes, "points": points, "labels": labels, - "orig_size": meta["original_shape"][:2], + "orig_size": image.shape[:2], }, ) @@ -131,13 +130,11 @@ def learn( largest_label: int = max([int(p) for p in processed_prompts_w_labels] + [0]) self._expand_reference_info(largest_label) - original_shape = np.array(meta["original_shape"][:2]) + original_shape = np.array(image.shape[:2]) # forward image encoder - image_embeddings = self.encoder_model.infer_sync(processed_image) - processed_embedding = ( - image_embeddings["image_embeddings"].squeeze().transpose(1, 2, 0) - ) + image_embeddings = self.encoder_model(image) + processed_embedding = image_embeddings.squeeze().transpose(1, 2, 0) # get reference masks ref_masks: np.ndarray = np.zeros( @@ -149,7 +146,7 @@ def learn( inputs_decoder.pop("label") if "point_coords" in inputs_decoder: # bboxes and points - inputs_decoder.update(image_embeddings) + inputs_decoder["image_embeddings"] = image_embeddings prediction = self._predict_masks( inputs_decoder, original_shape, is_cascade=self.is_cascade ) @@ -384,13 +381,12 @@ def infer( else: used_indices = self.used_indices - processed_image, meta = self.encoder_model.preprocess(image) - original_shape = np.array(meta["original_shape"][:2]) + original_shape = np.array(image.shape[:2]) - image_embeddings = self.encoder_model.infer_sync(processed_image) + image_embeddings = self.encoder_model(image) total_points_scores, total_bg_coords = self._get_prompt_candidates( - image_embeddings=image_embeddings["image_embeddings"], + image_embeddings=image_embeddings, reference_feats=reference_features, used_indices=used_indices, original_shape=original_shape, @@ -432,7 +428,7 @@ def infer( "point_labels": point_labels[None], "orig_size": original_shape[None], } - inputs_decoder.update(image_embeddings) + inputs_decoder["image_embeddings"] = image_embeddings prediction = self._predict_masks( inputs_decoder, original_shape, self.is_cascade From 57c8c946b02899e44fa953043845e0fd5b566b13 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 13 Jun 2024 11:05:12 +0900 Subject: [PATCH 08/31] Move aux methods out of the class vpt scope --- .../model_api/models/visual_prompting.py | 560 +++++++++--------- 1 file changed, 282 
insertions(+), 278 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 8c10a976..955a0e12 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -161,7 +161,7 @@ def learn( cur_default_threshold_reference = deepcopy(self.default_threshold_reference) while ref_feat is None: # log.info(f"[*] default_threshold_reference : {cur_default_threshold_reference:.4f}") - ref_feat = self._generate_masked_features( + ref_feat = _generate_masked_features( feats=processed_embedding, masks=ref_mask, threshold_mask=cur_default_threshold_reference, @@ -209,55 +209,6 @@ def _expand_reference_info(self, new_largest_label: int) -> None: self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 ) - def _generate_masked_features( - self, - feats: np.ndarray, - masks: np.ndarray, - threshold_mask: float, - image_size: int = 1024, - ) -> np.ndarray | None: - """Generate masked features. - - Args: - feats (np.ndarray): Raw reference features. It will be filtered with masks. - masks (np.ndarray): Reference masks used to filter features. - threshold_mask (float): Threshold to control masked region. - image_size (int): Input image size. - - Returns: - (np.ndarray): Masked features. - """ - target_shape = image_size / max(masks.shape) * np.array(masks.shape) - target_shape = target_shape[::-1].astype(np.int32) - - # Post-process masks - masks = cv2.resize(masks, target_shape, interpolation=cv2.INTER_LINEAR) - masks = self._pad_to_square(masks, image_size) - masks = cv2.resize(masks, feats.shape[:2][::-1], interpolation=cv2.INTER_LINEAR) - - # Target feature extraction - if (masks > threshold_mask).sum() == 0: - # (for stability) there is no area to be extracted - return None - - masked_feat = feats[masks > threshold_mask] - masked_feat = masked_feat.mean(0)[None] - return masked_feat / np.linalg.norm(masked_feat, axis=-1, keepdims=True) - - def _pad_to_square(self, x: np.ndarray, image_size: int = 1024) -> np.ndarray: - """Pad to a square input. - - Args: - x (np.ndarray): Mask to be padded. - - Returns: - (np.ndarray): Padded mask. - """ - h, w = x.shape[-2:] - padh = image_size - h - padw = image_size - w - return np.pad(x, ((0, padh), (0, padw)), constant_values=0.0) - def _predict_masks( self, inputs: dict[str, np.ndarray], @@ -280,7 +231,7 @@ def _predict_masks( elif i == 1: # Cascaded Post-refinement-1 - mask_input, masks = self._decide_masks( + mask_input, masks = _decide_masks( masks, logits, scores, is_single=True ) # noqa: F821 if masks.sum() == 0: @@ -290,7 +241,7 @@ def _predict_masks( elif i == 2: # Cascaded Post-refinement-2 - mask_input, masks = self._decide_masks( + mask_input, masks = _decide_masks( masks, logits, scores ) # noqa: F821 if masks.sum() == 0: @@ -325,40 +276,9 @@ def _predict_masks( ) masks = upscaled_masks > self.decoder_model.mask_threshold - _, masks = self._decide_masks(masks, logits, scores) + _, masks = _decide_masks(masks, logits, scores) return {"upscaled_masks": masks} - def _decide_masks( - self, - masks: np.ndarray, - logits: np.ndarray, - scores: np.ndarray, - is_single: bool = False, - ) -> tuple[np.ndarray, ...] 
| tuple[None, np.ndarray]: - """Post-process logits for resized masks according to best index based on scores.""" - if is_single: - best_idx = 0 - else: - # skip the first index components - scores, masks, logits = (x[:, 1:] for x in (scores, masks, logits)) - - # filter zero masks - while ( - len(scores[0]) > 0 - and masks[0, (best_idx := np.argmax(scores[0]))].sum() == 0 - ): - scores, masks, logits = ( - np.concatenate((x[:, :best_idx], x[:, best_idx + 1 :]), axis=1) - for x in (scores, masks, logits) - ) - - if len(scores[0]) == 0: - # all predicted masks were zero masks, ignore them. - return None, np.zeros(masks.shape[-2:]) - - best_idx = np.argmax(scores[0]) - return logits[:, [best_idx]], masks[0, best_idx] - def infer( self, image: np.ndarray, @@ -385,7 +305,7 @@ def infer( image_embeddings = self.encoder_model(image) - total_points_scores, total_bg_coords = self._get_prompt_candidates( + total_points_scores, total_bg_coords = _get_prompt_candidates( image_embeddings=image_embeddings, reference_feats=reference_features, used_indices=used_indices, @@ -441,206 +361,290 @@ def infer( used_points[label].append(points_score) # check overlapping area between different label masks - self._inspect_overlapping_areas(predicted_masks, used_points) + _inspect_overlapping_areas(predicted_masks, used_points) return (predicted_masks, used_points) - def _get_prompt_candidates( - self, - image_embeddings: np.ndarray, - reference_feats: np.ndarray, - used_indices: np.ndarray, - original_shape: np.ndarray, - threshold: float = 0.0, - num_bg_points: int = 1, - default_threshold_target: float = 0.65, - image_size: int = 1024, - downsizing: int = 64, - ) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: - """Get prompt candidates.""" - target_feat = image_embeddings.squeeze() - c_feat, h_feat, w_feat = target_feat.shape - target_feat = target_feat / np.linalg.norm(target_feat, axis=0, keepdims=True) - target_feat = target_feat.reshape(c_feat, h_feat * w_feat) - - total_points_scores: dict[int, np.ndarray] = {} - total_bg_coords: dict[int, np.ndarray] = {} - for label in used_indices: - sim = reference_feats[label] @ target_feat - sim = sim.reshape(h_feat, w_feat) - sim = self._resize_to_original_shape(sim, image_size, original_shape) - - threshold = (threshold == 0) * default_threshold_target + threshold - points_scores, bg_coords = self._point_selection( - mask_sim=sim, - original_shape=original_shape, - threshold=threshold, - num_bg_points=num_bg_points, - image_size=image_size, - downsizing=downsizing, - ) - if points_scores is not None: - total_points_scores[label] = points_scores - total_bg_coords[label] = bg_coords - return total_points_scores, total_bg_coords +def _generate_masked_features( + feats: np.ndarray, + masks: np.ndarray, + threshold_mask: float, + image_size: int = 1024, +) -> np.ndarray | None: + """Generate masked features. + + Args: + feats (np.ndarray): Raw reference features. It will be filtered with masks. + masks (np.ndarray): Reference masks used to filter features. + threshold_mask (float): Threshold to control masked region. + image_size (int): Input image size. + + Returns: + (np.ndarray): Masked features. 
+ """ + target_shape = image_size / max(masks.shape) * np.array(masks.shape) + target_shape = target_shape[::-1].astype(np.int32) + + # Post-process masks + masks = cv2.resize(masks, target_shape, interpolation=cv2.INTER_LINEAR) + masks = _pad_to_square(masks, image_size) + masks = cv2.resize(masks, feats.shape[:2][::-1], interpolation=cv2.INTER_LINEAR) + + # Target feature extraction + if (masks > threshold_mask).sum() == 0: + # (for stability) there is no area to be extracted + return None + + masked_feat = feats[masks > threshold_mask] + masked_feat = masked_feat.mean(0)[None] + return masked_feat / np.linalg.norm(masked_feat, axis=-1, keepdims=True) + + +def _pad_to_square(x: np.ndarray, image_size: int = 1024) -> np.ndarray: + """Pad to a square input. + + Args: + x (np.ndarray): Mask to be padded. + + Returns: + (np.ndarray): Padded mask. + """ + h, w = x.shape[-2:] + padh = image_size - h + padw = image_size - w + return np.pad(x, ((0, padh), (0, padw)), constant_values=0.0) + + +def _decide_masks( + masks: np.ndarray, + logits: np.ndarray, + scores: np.ndarray, + is_single: bool = False, +) -> tuple[np.ndarray, ...] | tuple[None, np.ndarray]: + """Post-process logits for resized masks according to best index based on scores.""" + if is_single: + best_idx = 0 + else: + # skip the first index components + scores, masks, logits = (x[:, 1:] for x in (scores, masks, logits)) + + # filter zero masks + while ( + len(scores[0]) > 0 + and masks[0, (best_idx := np.argmax(scores[0]))].sum() == 0 + ): + scores, masks, logits = ( + np.concatenate((x[:, :best_idx], x[:, best_idx + 1 :]), axis=1) + for x in (scores, masks, logits) + ) - def _point_selection( - self, - mask_sim: np.ndarray, - original_shape: np.ndarray, - threshold: float = 0.0, - num_bg_points: int = 1, - image_size: int = 1024, - downsizing: int = 64, - ) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]: - """Select point used as point prompts.""" - _, w_sim = mask_sim.shape - - # Top-first point selection - point_coords = np.where(mask_sim > threshold) - fg_coords_scores = np.stack( - point_coords[::-1] + (mask_sim[point_coords],), axis=0 - ).T - - ## skip if there is no point coords - if len(fg_coords_scores) == 0: - return None, None - - ratio = image_size / original_shape.max() - width = (original_shape[1] * ratio).astype(np.int64) - n_w = width // downsizing - - ## get grid numbers - idx_grid = ( - fg_coords_scores[:, 1] * ratio // downsizing * n_w - + fg_coords_scores[:, 0] * ratio // downsizing + if len(scores[0]) == 0: + # all predicted masks were zero masks, ignore them. 
+ return None, np.zeros(masks.shape[-2:]) + + best_idx = np.argmax(scores[0]) + return logits[:, [best_idx]], masks[0, best_idx] + + +def _get_prompt_candidates( + image_embeddings: np.ndarray, + reference_feats: np.ndarray, + used_indices: np.ndarray, + original_shape: np.ndarray, + threshold: float = 0.0, + num_bg_points: int = 1, + default_threshold_target: float = 0.65, + image_size: int = 1024, + downsizing: int = 64, +) -> tuple[dict[int, np.ndarray], dict[int, np.ndarray]]: + """Get prompt candidates.""" + target_feat = image_embeddings.squeeze() + c_feat, h_feat, w_feat = target_feat.shape + target_feat = target_feat / np.linalg.norm(target_feat, axis=0, keepdims=True) + target_feat = target_feat.reshape(c_feat, h_feat * w_feat) + + total_points_scores: dict[int, np.ndarray] = {} + total_bg_coords: dict[int, np.ndarray] = {} + for label in used_indices: + sim = reference_feats[label] @ target_feat + sim = sim.reshape(h_feat, w_feat) + sim = _resize_to_original_shape(sim, image_size, original_shape) + + threshold = (threshold == 0) * default_threshold_target + threshold + points_scores, bg_coords = _point_selection( + mask_sim=sim, + original_shape=original_shape, + threshold=threshold, + num_bg_points=num_bg_points, + image_size=image_size, + downsizing=downsizing, ) - idx_grid_unique = np.unique(idx_grid.astype(np.int64)) - ## get matched indices - matched_matrix = ( - np.expand_dims(idx_grid, axis=-1) == idx_grid_unique - ) # (totalN, uniqueN) + if points_scores is not None: + total_points_scores[label] = points_scores + total_bg_coords[label] = bg_coords + return total_points_scores, total_bg_coords + + +def _point_selection( + mask_sim: np.ndarray, + original_shape: np.ndarray, + threshold: float = 0.0, + num_bg_points: int = 1, + image_size: int = 1024, + downsizing: int = 64, +) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]: + """Select point used as point prompts.""" + _, w_sim = mask_sim.shape + + # Top-first point selection + point_coords = np.where(mask_sim > threshold) + fg_coords_scores = np.stack( + point_coords[::-1] + (mask_sim[point_coords],), axis=0 + ).T + + ## skip if there is no point coords + if len(fg_coords_scores) == 0: + return None, None + + ratio = image_size / original_shape.max() + width = (original_shape[1] * ratio).astype(np.int64) + n_w = width // downsizing + + ## get grid numbers + idx_grid = ( + fg_coords_scores[:, 1] * ratio // downsizing * n_w + + fg_coords_scores[:, 0] * ratio // downsizing + ) + idx_grid_unique = np.unique(idx_grid.astype(np.int64)) + + ## get matched indices + matched_matrix = ( + np.expand_dims(idx_grid, axis=-1) == idx_grid_unique + ) # (totalN, uniqueN) + + ## sample fg_coords_scores matched by matched_matrix + matched_grid = np.expand_dims(fg_coords_scores, axis=1) * np.expand_dims( + matched_matrix, axis=-1 + ) + + ## sample the highest score one of the samples that are in the same grid + matched_indices = _topk_numpy( + matched_grid[..., -1], k=1, axis=0, largest=True + )[1][0].astype(np.int64) + points_scores = matched_grid[matched_indices].diagonal().T + + ## sort by the highest score + sorted_points_scores_indices = np.flip( + np.argsort(points_scores[:, -1]), axis=-1 + ).astype(np.int64) + points_scores = points_scores[sorted_points_scores_indices] + + # Top-last point selection + bg_indices = _topk_numpy(mask_sim.flatten(), num_bg_points, largest=False)[ + 1 + ] + bg_x = np.expand_dims(bg_indices // w_sim, axis=0) + bg_y = bg_indices - bg_x * w_sim + bg_coords = np.concatenate((bg_y, bg_x), 
axis=0).transpose(1, 0) + bg_coords = bg_coords.astype(np.float32) + + return points_scores, bg_coords + + +def _resize_to_original_shape( + masks: np.ndarray, image_size: int, original_shape: np.ndarray +) -> np.ndarray: + """Resize feature size to original shape.""" + # resize feature size to input size + masks = cv2.resize( + masks, (image_size, image_size), interpolation=cv2.INTER_LINEAR + ) + + # remove pad + prepadded_size = _get_prepadded_size(original_shape, image_size) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] + + # resize unpadded one to original shape + original_shape = original_shape.astype(np.int64) + h, w = original_shape[0], original_shape[1] + return cv2.resize(masks, (w, h), interpolation=cv2.INTER_LINEAR) + + +def _get_prepadded_size(original_shape: np.ndarray, image_size: int) -> np.ndarray: + """Get pre-padded size.""" + scale = image_size / np.max(original_shape) + transformed_size = scale * original_shape + return np.floor(transformed_size + 0.5).astype(np.int64) + + +def _topk_numpy( + x: np.ndarray, k: int, axis: int = -1, largest: bool = True +) -> tuple[np.ndarray, np.ndarray]: + """Top-k function for numpy same with torch.topk.""" + if largest: + k = -k + indices = range(k, 0) + else: + indices = range(k) + partitioned_ind = np.argpartition(x, k, axis=axis).take( + indices=indices, axis=axis + ) + partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis) + sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis) + if largest: + sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis) + ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis) + scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis) + return scores, ind + + +def _inspect_overlapping_areas( + predicted_masks: dict[int, list[np.ndarray]], + used_points: dict[int, list[np.ndarray]], + threshold_iou: float = 0.8, +) -> None: + def _calculate_mask_iou( + mask1: np.ndarray, mask2: np.ndarray + ) -> tuple[float, np.ndarray | None]: + assert mask1.ndim == 2 # noqa: S101 + assert mask2.ndim == 2 # noqa: S101 + # Avoid division by zero + if (union := np.logical_or(mask1, mask2).sum().item()) == 0: + return 0.0, None + intersection = np.logical_and(mask1, mask2) + return intersection.sum().item() / union, intersection + + for (label, masks), (other_label, other_masks) in product( + predicted_masks.items(), predicted_masks.items() + ): + if other_label <= label: + continue - ## sample fg_coords_scores matched by matched_matrix - matched_grid = np.expand_dims(fg_coords_scores, axis=1) * np.expand_dims( - matched_matrix, axis=-1 - ) + overlapped_label = [] + overlapped_other_label = [] + for (im, mask), (jm, other_mask) in product( + enumerate(masks), enumerate(other_masks) + ): + _mask_iou, _intersection = _calculate_mask_iou(mask, other_mask) + if _mask_iou > threshold_iou: + if used_points[label][im][2] > used_points[other_label][jm][2]: + overlapped_other_label.append(jm) + else: + overlapped_label.append(im) + elif _mask_iou > 0: + # refine the slightly overlapping region + overlapped_coords = np.where(_intersection) + if used_points[label][im][2] > used_points[other_label][jm][2]: + other_mask[overlapped_coords] = 0.0 + else: + mask[overlapped_coords] = 0.0 - ## sample the highest score one of the samples that are in the same grid - matched_indices = self._topk_numpy( - matched_grid[..., -1], k=1, axis=0, largest=True - )[1][0].astype(np.int64) - points_scores = matched_grid[matched_indices].diagonal().T - - ## sort by the highest 
score - sorted_points_scores_indices = np.flip( - np.argsort(points_scores[:, -1]), axis=-1 - ).astype(np.int64) - points_scores = points_scores[sorted_points_scores_indices] - - # Top-last point selection - bg_indices = self._topk_numpy(mask_sim.flatten(), num_bg_points, largest=False)[ - 1 - ] - bg_x = np.expand_dims(bg_indices // w_sim, axis=0) - bg_y = bg_indices - bg_x * w_sim - bg_coords = np.concatenate((bg_y, bg_x), axis=0).transpose(1, 0) - bg_coords = bg_coords.astype(np.float32) - - return points_scores, bg_coords - - def _resize_to_original_shape( - self, masks: np.ndarray, image_size: int, original_shape: np.ndarray - ) -> np.ndarray: - """Resize feature size to original shape.""" - # resize feature size to input size - masks = cv2.resize( - masks, (image_size, image_size), interpolation=cv2.INTER_LINEAR - ) + for im in sorted(set(overlapped_label), reverse=True): + masks.pop(im) + used_points[label].pop(im) - # remove pad - prepadded_size = self._get_prepadded_size(original_shape, image_size) - masks = masks[..., : prepadded_size[0], : prepadded_size[1]] - - # resize unpadded one to original shape - original_shape = original_shape.astype(np.int64) - h, w = original_shape[0], original_shape[1] - return cv2.resize(masks, (w, h), interpolation=cv2.INTER_LINEAR) - - def _get_prepadded_size(self, original_shape: int, image_size: int) -> np.ndarray: - """Get pre-padded size.""" - scale = image_size / np.max(original_shape) - transformed_size = scale * original_shape - return np.floor(transformed_size + 0.5).astype(np.int64) - - def _topk_numpy( - self, x: np.ndarray, k: int, axis: int = -1, largest: bool = True - ) -> tuple[np.ndarray, np.ndarray]: - """Top-k function for numpy same with torch.topk.""" - if largest: - k = -k - indices = range(k, 0) - else: - indices = range(k) - partitioned_ind = np.argpartition(x, k, axis=axis).take( - indices=indices, axis=axis - ) - partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis) - sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis) - if largest: - sorted_trunc_ind = np.flip(sorted_trunc_ind, axis=axis) - ind = np.take_along_axis(partitioned_ind, sorted_trunc_ind, axis=axis) - scores = np.take_along_axis(partitioned_scores, sorted_trunc_ind, axis=axis) - return scores, ind - - def _inspect_overlapping_areas( - self, - predicted_masks: dict[int, list[np.ndarray]], - used_points: dict[int, list[np.ndarray]], - threshold_iou: float = 0.8, - ) -> None: - def _calculate_mask_iou( - mask1: np.ndarray, mask2: np.ndarray - ) -> tuple[float, np.ndarray | None]: - assert mask1.ndim == 2 # noqa: S101 - assert mask2.ndim == 2 # noqa: S101 - # Avoid division by zero - if (union := np.logical_or(mask1, mask2).sum().item()) == 0: - return 0.0, None - intersection = np.logical_and(mask1, mask2) - return intersection.sum().item() / union, intersection - - for (label, masks), (other_label, other_masks) in product( - predicted_masks.items(), predicted_masks.items() - ): - if other_label <= label: - continue - - overlapped_label = [] - overlapped_other_label = [] - for (im, mask), (jm, other_mask) in product( - enumerate(masks), enumerate(other_masks) - ): - _mask_iou, _intersection = _calculate_mask_iou(mask, other_mask) - if _mask_iou > threshold_iou: - if used_points[label][im][2] > used_points[other_label][jm][2]: - overlapped_other_label.append(jm) - else: - overlapped_label.append(im) - elif _mask_iou > 0: - # refine the slightly overlapping region - overlapped_coords = np.where(_intersection) - if used_points[label][im][2] > 
used_points[other_label][jm][2]: - other_mask[overlapped_coords] = 0.0 - else: - mask[overlapped_coords] = 0.0 - - for im in sorted(set(overlapped_label), reverse=True): - masks.pop(im) - used_points[label].pop(im) - - for jm in sorted(set(overlapped_other_label), reverse=True): - other_masks.pop(jm) - used_points[other_label].pop(jm) + for jm in sorted(set(overlapped_other_label), reverse=True): + other_masks.pop(jm) + used_points[other_label].pop(jm) From 50d2a95834539743a34f4dd847b9a35a6fb30f9a Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 18 Jun 2024 09:11:35 +0900 Subject: [PATCH 09/31] Fix handling of reference features --- .../python/model_api/models/visual_prompting.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 955a0e12..8bbf5ddb 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -12,7 +12,6 @@ """ from collections import defaultdict -from copy import deepcopy from itertools import product from typing import Any @@ -158,7 +157,7 @@ def learn( ref_mask = np.clip(ref_mask, 0, 1) ref_feat: np.ndarray | None = None - cur_default_threshold_reference = deepcopy(self.default_threshold_reference) + cur_default_threshold_reference = self.default_threshold_reference while ref_feat is None: # log.info(f"[*] default_threshold_reference : {cur_default_threshold_reference:.4f}") ref_feat = _generate_masked_features( @@ -169,20 +168,20 @@ def learn( ) cur_default_threshold_reference -= 0.05 - self.reference_feats[label] = ref_feat + self.reference_features[label] = ref_feat self.used_indices: np.ndarray = np.concatenate((self.used_indices, [label])) ref_masks[label] = ref_mask self.used_indices = np.unique(self.used_indices) return { - "reference_feats": self.reference_feats, + "reference_features": self.reference_features, "used_indices": self.used_indices, }, ref_masks def reset_reference_info(self) -> None: """Initialize reference information.""" - self.reference_feats = np.zeros( + self.reference_features = np.zeros( (0, 1, self.decoder_model.embed_dim), dtype=np.float32 ) self.used_indices = np.array([], dtype=np.int64) @@ -203,10 +202,10 @@ def _gather_prompts_with_labels( def _expand_reference_info(self, new_largest_label: int) -> None: """Expand reference info dimensions if newly given processed prompts have more lables.""" - if new_largest_label > (cur_largest_label := len(self.reference_feats) - 1): + if new_largest_label > (cur_largest_label := len(self.reference_features) - 1): diff = new_largest_label - cur_largest_label - self.reference_feats = np.pad( - self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 + self.reference_features = np.pad( + self.reference_features, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 ) def _predict_masks( From 1c037359ed3defce0fa0b12fa75eddd412fab6c3 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 18 Jun 2024 10:06:13 +0900 Subject: [PATCH 10/31] Align is_cascade usage --- model_api/python/model_api/models/visual_prompting.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 8bbf5ddb..79a5a5c8 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -95,7 +95,7 @@ 
def __init__( self.point_labels_box = np.array([[2, 3]], dtype=np.float32) self.has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] - self.is_cascade: bool = True + self.is_cascade: bool = False self.threshold: float = 0.0 self.num_bg_points: int = 1 self.default_threshold_target: float = 0.65 @@ -151,15 +151,13 @@ def learn( ) masks = prediction["upscaled_masks"] else: - # log.warning("annotation and polygon will be supported.") - continue + raise RuntimeError("Prompts other than points are not supported") ref_mask[masks] += 1 ref_mask = np.clip(ref_mask, 0, 1) ref_feat: np.ndarray | None = None cur_default_threshold_reference = self.default_threshold_reference while ref_feat is None: - # log.info(f"[*] default_threshold_reference : {cur_default_threshold_reference:.4f}") ref_feat = _generate_masked_features( feats=processed_embedding, masks=ref_mask, @@ -350,7 +348,7 @@ def infer( inputs_decoder["image_embeddings"] = image_embeddings prediction = self._predict_masks( - inputs_decoder, original_shape, self.is_cascade + inputs_decoder, original_shape, True ) prediction.update({"scores": points_score[-1]}) From 70af2bacbfea344e7e99586d1139ce03c6c17de1 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 19 Jun 2024 10:40:03 +0900 Subject: [PATCH 11/31] Update vpt public interfaces --- .../model_api/models/visual_prompting.py | 327 ++++++++++-------- 1 file changed, 189 insertions(+), 138 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 79a5a5c8..d481d68f 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -13,7 +13,7 @@ from collections import defaultdict from itertools import product -from typing import Any +from typing import Any, NamedTuple import cv2 import numpy as np @@ -21,6 +21,16 @@ from model_api.models.utils import VisualPromptingResult +class VisualPromptingFeatures(NamedTuple): + feature_vectors: np.ndarray + used_indices: np.ndarray + + +class Prompt(NamedTuple): + data: list[np.ndarray] | np.ndarray + labels: list[np.ndarray | int] | np.ndarray + + class SAMVisualPrompter: def __init__( self, @@ -33,19 +43,22 @@ def __init__( def infer( self, image: np.ndarray, - boxes: np.ndarray | None, - points: np.ndarray | None, - labels: dict[str, np.ndarray] | None, + boxes: Prompt | None = None, + points: Prompt | None = None, ) -> VisualPromptingResult: + if boxes is None and points is None: + raise RuntimeError("boxes or points prompts are required for inference") + outputs: list[dict[str, Any]] = [] processed_image, meta = self.encoder_model.preprocess(image) image_embeddings = self.encoder_model.infer_sync(processed_image) processed_prompts = self.decoder_model.preprocess( { - "bboxes": boxes, - "points": points, - "labels": labels, + "bboxes": boxes.data if boxes else None, + "points": points.data if points else None, + "labels": {"bboxes": boxes.labels if boxes else None, + "points": points.labels if points else None}, "orig_size": meta["original_shape"][:2], }, ) @@ -73,11 +86,10 @@ def infer( def __call__( self, image: np.ndarray, - boxes: np.ndarray | None, - points: np.ndarray | None, - labels: dict[str, np.ndarray] | None, + boxes: Prompt | None = None, + points: Prompt | None = None, ) -> VisualPromptingResult: - return self.infer(image, boxes, points, labels) + return self.infer(image, boxes, points) class SAMLearnableVisualPrompter: @@ -85,42 +97,75 @@ def __init__( self, encoder_model: 
SAMImageEncoder, decoder_model: SAMDecoder, - reference_features: np.ndarray | None = None, + reference_features: VisualPromptingFeatures | None = None, ): self.encoder_model = encoder_model self.decoder_model = decoder_model - self.reference_features = reference_features - self.used_indices = None - self.point_labels_box = np.array([[2, 3]], dtype=np.float32) - self.has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] + if reference_features is not None: + self._reference_features = reference_features.feature_vectors + self._used_indices = reference_features.used_indices + else: + self._reference_features = None + self._used_indices = None - self.is_cascade: bool = False - self.threshold: float = 0.0 - self.num_bg_points: int = 1 - self.default_threshold_target: float = 0.65 - self.image_size: int = self.encoder_model.image_size - self.downsizing: int = 64 - self.default_threshold_reference: float = 0.3 + self._point_labels_box = np.array([[2, 3]], dtype=np.float32) + self._has_mask_inputs = [np.array([[0.0]]), np.array([[1.0]])] - if self.reference_features is None: - self.reset_reference_info() + self._is_cascade: bool = False + self._threshold: float = 0.0 + self._num_bg_points: int = 1 + self._default_threshold_target: float = 0.65 + self._image_size: int = self.encoder_model.image_size + self._downsizing: int = 64 + self._default_threshold_reference: float = 0.3 def has_reference_features(self) -> bool: - return self.reference_features is not None + return self._reference_features is not None and self._used_indices is not None + + @property + def reference_features(self) -> VisualPromptingFeatures: + if self.has_reference_features(): + return VisualPromptingFeatures(np.copy(self._reference_features), np.copy(self._used_indices)) + + raise RuntimeError("Reference features are not generated") def learn( self, image: np.ndarray, - boxes: np.ndarray | None, - points: np.ndarray | None, - labels: dict[str, np.ndarray] | None, - ): + boxes: Prompt | None = None, + points: Prompt | None = None, + reset_features: bool = False + ) -> tuple[VisualPromptingFeatures, np.ndarray]: + """ + Executes `learn` stage of SAM ZSL pipeline. + + Reference features are updated according to newly arrived prompts. This method should not be run on different images + without resetting reference features. Consequent runs on the same image with preserving reference features make sense if new or refined prompts are passed. + + Args: + image (np.ndarray): HWC-shaped image + boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None. + points (Prompt | None, optional): Prompt containing bounding boxes (in XY format) and their labels (ints, one per point). Defaults to None. + reset_features (bool, optional): Forces learning from scratch. Defaults to False. + + Returns: + tuple[VisualPromptingFeatures, np.ndarray]: return values are the updated VPT reference features and reference masks. + The shape of the reference mask is N_labels x H x W, where H and W are the same as in the input image. 
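+
+        Example (a minimal sketch; encoder and decoder are assumed to be already created
+        SAMImageEncoder and SAMDecoder instances, and image is an HWC numpy array):
+
+            prompter = SAMLearnableVisualPrompter(encoder, decoder)
+            point_prompt = Prompt(np.array([[320, 240]]), [0])
+            ref_features, ref_masks = prompter.learn(image, points=point_prompt)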
+ """ + + if boxes is None and points is None: + raise RuntimeError("boxes or points prompts are required for learning") + + if reset_features or not self.has_reference_features(): + self.reset_reference_info() + processed_prompts = self.decoder_model.preprocess( { - "bboxes": boxes, - "points": points, - "labels": labels, + "bboxes": boxes.data if boxes else None, + "points": points.data if points else None, + "labels": {"bboxes": boxes.labels if boxes else None, + "points": points.labels if points else None}, "orig_size": image.shape[:2], }, ) @@ -147,7 +192,7 @@ def learn( # bboxes and points inputs_decoder["image_embeddings"] = image_embeddings prediction = self._predict_masks( - inputs_decoder, original_shape, is_cascade=self.is_cascade + inputs_decoder, original_shape, is_cascade=self._is_cascade ) masks = prediction["upscaled_masks"] else: @@ -156,7 +201,7 @@ def learn( ref_mask = np.clip(ref_mask, 0, 1) ref_feat: np.ndarray | None = None - cur_default_threshold_reference = self.default_threshold_reference + cur_default_threshold_reference = self._default_threshold_reference while ref_feat is None: ref_feat = _generate_masked_features( feats=processed_embedding, @@ -166,23 +211,112 @@ def learn( ) cur_default_threshold_reference -= 0.05 - self.reference_features[label] = ref_feat - self.used_indices: np.ndarray = np.concatenate((self.used_indices, [label])) + self._reference_features[label] = ref_feat + self._used_indices: np.ndarray = np.concatenate((self._used_indices, [label])) ref_masks[label] = ref_mask - self.used_indices = np.unique(self.used_indices) + self._used_indices = np.unique(self._used_indices) + + return self.reference_features, ref_masks + + def __call__( + self, + image: np.ndarray, + reference_features: VisualPromptingFeatures | None = None, + ): + return self.infer(image, reference_features) + + def infer( + self, + image: np.ndarray, + reference_features: VisualPromptingFeatures | None = None, + ): + if reference_features is None: + if self._reference_features is None: + raise RuntimeError( + "Reference features are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" + ) + else: + reference_feats = self._reference_features + + if self._used_indices is None: + raise RuntimeError( + "Used indices are not defined. 
This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" + ) + else: + used_idx = self._used_indices + else: + reference_feats, used_idx = reference_features + + original_shape = np.array(image.shape[:2]) + image_embeddings = self.encoder_model(image) + + total_points_scores, total_bg_coords = _get_prompt_candidates( + image_embeddings=image_embeddings, + reference_feats=reference_feats, + used_indices=used_idx, + original_shape=original_shape, + threshold=self._threshold, + num_bg_points=self._num_bg_points, + default_threshold_target=self._default_threshold_target, + image_size=self._image_size, + downsizing=self._downsizing, + ) + + predicted_masks: defaultdict[int, list] = defaultdict(list) + used_points: defaultdict[int, list] = defaultdict(list) + for label in total_points_scores: + points_scores = total_points_scores[label] + bg_coords = total_bg_coords[label] + for points_score in points_scores: + if points_score[-1] in [-1.0, 0.0]: + continue + + x, y = points_score[:2] + is_done = False + for pm in predicted_masks.get(label, []): + # check if that point is already assigned + if pm[int(y), int(x)] > 0: + is_done = True + break + if is_done: + continue + + point_coords = np.concatenate( + (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32 + ) + point_coords = self.decoder_model.apply_coords( + point_coords, original_shape + ) + point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) + inputs_decoder = { + "point_coords": point_coords[None], + "point_labels": point_labels[None], + "orig_size": original_shape[None], + } + inputs_decoder["image_embeddings"] = image_embeddings + + prediction = self._predict_masks( + inputs_decoder, original_shape, True + ) + prediction.update({"scores": points_score[-1]}) + + predicted_masks[label].append( + prediction[self.decoder_model.output_blob_name] + ) + used_points[label].append(points_score) + + # check overlapping area between different label masks + _inspect_overlapping_areas(predicted_masks, used_points) - return { - "reference_features": self.reference_features, - "used_indices": self.used_indices, - }, ref_masks + return (predicted_masks, used_points) def reset_reference_info(self) -> None: """Initialize reference information.""" - self.reference_features = np.zeros( + self._reference_features = np.zeros( (0, 1, self.decoder_model.embed_dim), dtype=np.float32 ) - self.used_indices = np.array([], dtype=np.int64) + self._used_indices = np.array([], dtype=np.int64) def _gather_prompts_with_labels( self, @@ -199,11 +333,14 @@ def _gather_prompts_with_labels( return dict(sorted(processed_prompts.items(), key=lambda x: x)) def _expand_reference_info(self, new_largest_label: int) -> None: - """Expand reference info dimensions if newly given processed prompts have more lables.""" - if new_largest_label > (cur_largest_label := len(self.reference_features) - 1): + """Expand reference info dimensions if newly given processed prompts have more labels.""" + if self._reference_features is None: + raise RuntimeError("Can not expand non existing reference info") + + if new_largest_label > (cur_largest_label := len(self._reference_features) - 1): diff = new_largest_label - cur_largest_label - self.reference_features = np.pad( - self.reference_features, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 + self._reference_features = np.pad( + self._reference_features, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 ) def _predict_masks( @@ -224,7 +361,7 @@ def _predict_masks( (1, 1, 
*(x * 4 for x in inputs["image_embeddings"].shape[2:])), dtype=np.float32, ) - has_mask_input = self.has_mask_inputs[0] + has_mask_input = self._has_mask_inputs[0] elif i == 1: # Cascaded Post-refinement-1 @@ -234,7 +371,7 @@ def _predict_masks( if masks.sum() == 0: return {"upscaled_masks": masks} - has_mask_input = self.has_mask_inputs[1] + has_mask_input = self._has_mask_inputs[1] elif i == 2: # Cascaded Post-refinement-2 @@ -244,7 +381,7 @@ def _predict_masks( if masks.sum() == 0: return {"upscaled_masks": masks} - has_mask_input = self.has_mask_inputs[1] + has_mask_input = self._has_mask_inputs[1] y, x = np.nonzero(masks) box_coords = self.decoder_model.apply_coords( np.array( @@ -259,7 +396,7 @@ def _predict_masks( (inputs["point_coords"], box_coords), axis=1 ), "point_labels": np.concatenate( - (inputs["point_labels"], self.point_labels_box), axis=1 + (inputs["point_labels"], self._point_labels_box), axis=1 ), }, ) @@ -276,92 +413,6 @@ def _predict_masks( _, masks = _decide_masks(masks, logits, scores) return {"upscaled_masks": masks} - def infer( - self, - image: np.ndarray, - reference_features: np.ndarray | None, - used_indices: np.ndarray | None, - ): - if reference_features is None: - if self.reference_features is None: - raise RuntimeError( - "Reference features are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" - ) - else: - reference_features = self.reference_features - - if used_indices is None: - if self.used_indices is None: - raise RuntimeError( - "Used indices are not defined. This parameter can be passed via SAMLearnableVisualPrompter constructor, or as an argument of infer() method" - ) - else: - used_indices = self.used_indices - - original_shape = np.array(image.shape[:2]) - - image_embeddings = self.encoder_model(image) - - total_points_scores, total_bg_coords = _get_prompt_candidates( - image_embeddings=image_embeddings, - reference_feats=reference_features, - used_indices=used_indices, - original_shape=original_shape, - threshold=self.threshold, - num_bg_points=self.num_bg_points, - default_threshold_target=self.default_threshold_target, - image_size=self.image_size, - downsizing=self.downsizing, - ) - - predicted_masks: defaultdict[int, list] = defaultdict(list) - used_points: defaultdict[int, list] = defaultdict(list) - for label in total_points_scores: - points_scores = total_points_scores[label] - bg_coords = total_bg_coords[label] - for points_score in points_scores: - if points_score[-1] in [-1.0, 0.0]: - continue - - x, y = points_score[:2] - is_done = False - for pm in predicted_masks.get(label, []): - # check if that point is already assigned - if pm[int(y), int(x)] > 0: - is_done = True - break - if is_done: - continue - - point_coords = np.concatenate( - (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32 - ) - point_coords = self.decoder_model.apply_coords( - point_coords, original_shape - ) - point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) - inputs_decoder = { - "point_coords": point_coords[None], - "point_labels": point_labels[None], - "orig_size": original_shape[None], - } - inputs_decoder["image_embeddings"] = image_embeddings - - prediction = self._predict_masks( - inputs_decoder, original_shape, True - ) - prediction.update({"scores": points_score[-1]}) - - predicted_masks[label].append( - prediction[self.decoder_model.output_blob_name] - ) - used_points[label].append(points_score) - - # check overlapping area between different label masks - 
_inspect_overlapping_areas(predicted_masks, used_points) - - return (predicted_masks, used_points) - def _generate_masked_features( feats: np.ndarray, From 29146e81b3290b6d26f1bb9856c3ffd896f77dfb Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 20 Jun 2024 07:40:49 +0900 Subject: [PATCH 12/31] Add some docs --- .../model_api/models/visual_prompting.py | 109 ++++++++++++++---- 1 file changed, 84 insertions(+), 25 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index d481d68f..a98b1e3a 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -31,14 +31,26 @@ class Prompt(NamedTuple): labels: list[np.ndarray | int] | np.ndarray +class PredictedMask(NamedTuple): + mask: list[np.ndarray] + points: list[np.ndarray] | np.ndarray + + class SAMVisualPrompter: + """ + A wrapper that implements SAM Visual Prompter. + + Segmentation results can be obtained by calling infer() method + with corresponding parameters. + """ + def __init__( self, encoder_model: SAMImageEncoder, decoder_model: SAMDecoder, ): - self.encoder_model = encoder_model - self.decoder_model = decoder_model + self.encoder = encoder_model + self.decoder = decoder_model def infer( self, @@ -46,14 +58,28 @@ def infer( boxes: Prompt | None = None, points: Prompt | None = None, ) -> VisualPromptingResult: + """ + _summary_ + + Args: + image (np.ndarray): HWC-shaped image + boxes (Prompt | None, optional): _description_. Defaults to None. + points (Prompt | None, optional): _description_. Defaults to None. + + Raises: + RuntimeError: _description_ + + Returns: + VisualPromptingResult: _description_ + """ if boxes is None and points is None: raise RuntimeError("boxes or points prompts are required for inference") outputs: list[dict[str, Any]] = [] - processed_image, meta = self.encoder_model.preprocess(image) - image_embeddings = self.encoder_model.infer_sync(processed_image) - processed_prompts = self.decoder_model.preprocess( + processed_image, meta = self.encoder.preprocess(image) + image_embeddings = self.encoder.infer_sync(processed_image) + processed_prompts = self.decoder.preprocess( { "bboxes": boxes.data if boxes else None, "points": points.data if points else None, @@ -67,10 +93,10 @@ def infer( label = prompt.pop("label") prompt.update(**image_embeddings) - prediction = self.decoder_model.infer_sync(prompt) + prediction = self.decoder.infer_sync(prompt) prediction["scores"] = prediction["iou_predictions"] prediction["labels"] = label - processed_prediction = self.decoder_model.postprocess(prediction, meta) + processed_prediction = self.decoder.postprocess(prediction, meta) outputs.append(processed_prediction) return VisualPromptingResult( @@ -93,14 +119,19 @@ def __call__( class SAMLearnableVisualPrompter: + """ + A wrapper that provides ZSL Visual Prompting workflow. + To obtain segmentation results, one should run learn() first to obtain the reference features, + or use previously generated ones. 
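+
+    A minimal usage sketch (assuming encoder and decoder are initialized SAMImageEncoder
+    and SAMDecoder wrappers, and image/new_image are HWC numpy arrays):
+
+        prompter = SAMLearnableVisualPrompter(encoder, decoder)
+        prompter.learn(image, points=Prompt(np.array([[320, 240]]), [0]))
+        masks_per_label = prompter(new_image)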
+ """ def __init__( self, encoder_model: SAMImageEncoder, decoder_model: SAMDecoder, reference_features: VisualPromptingFeatures | None = None, ): - self.encoder_model = encoder_model - self.decoder_model = decoder_model + self.encoder = encoder_model + self.decoder = decoder_model if reference_features is not None: self._reference_features = reference_features.feature_vectors @@ -116,15 +147,22 @@ def __init__( self._threshold: float = 0.0 self._num_bg_points: int = 1 self._default_threshold_target: float = 0.65 - self._image_size: int = self.encoder_model.image_size + self._image_size: int = self.encoder.image_size self._downsizing: int = 64 self._default_threshold_reference: float = 0.3 def has_reference_features(self) -> bool: + """ + Checks if reference features are stored in the object state. + """ return self._reference_features is not None and self._used_indices is not None @property def reference_features(self) -> VisualPromptingFeatures: + """ + Property represents reference features. An exception is thrown if called when + the features are not presented in the internal object state. + """ if self.has_reference_features(): return VisualPromptingFeatures(np.copy(self._reference_features), np.copy(self._used_indices)) @@ -140,8 +178,8 @@ def learn( """ Executes `learn` stage of SAM ZSL pipeline. - Reference features are updated according to newly arrived prompts. This method should not be run on different images - without resetting reference features. Consequent runs on the same image with preserving reference features make sense if new or refined prompts are passed. + Reference features are updated according to newly arrived prompts. Features corresponding to the same labels are overridden during + consequent learn() calls. Args: image (np.ndarray): HWC-shaped image @@ -160,7 +198,7 @@ def learn( if reset_features or not self.has_reference_features(): self.reset_reference_info() - processed_prompts = self.decoder_model.preprocess( + processed_prompts = self.decoder.preprocess( { "bboxes": boxes.data if boxes else None, "points": points.data if points else None, @@ -177,7 +215,7 @@ def learn( original_shape = np.array(image.shape[:2]) # forward image encoder - image_embeddings = self.encoder_model(image) + image_embeddings = self.encoder(image) processed_embedding = image_embeddings.squeeze().transpose(1, 2, 0) # get reference masks @@ -207,7 +245,7 @@ def learn( feats=processed_embedding, masks=ref_mask, threshold_mask=cur_default_threshold_reference, - image_size=self.encoder_model.image_size, + image_size=self.encoder.image_size, ) cur_default_threshold_reference -= 0.05 @@ -223,14 +261,31 @@ def __call__( self, image: np.ndarray, reference_features: VisualPromptingFeatures | None = None, - ): + ) -> dict[int, PredictedMask]: + """A wrapper of the SAMLearnableVisualPrompter.infer() method""" return self.infer(image, reference_features) def infer( self, image: np.ndarray, reference_features: VisualPromptingFeatures | None = None, - ): + ) -> dict[int, PredictedMask]: + """ + Obtains masks by already prepared reference features. + + Reference features can be obtained with SAMLearnableVisualPrompter.learn() and passed as an argument. + If the features are not passed, instance internal state will be used as a source of the features. + + Args: + image (np.ndarray): HWC-shaped image + reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during previous learn() calls. 
+ If not passed, object internal state is used, which reflects the last learn() call. Defaults to None. + + Returns: + dict[int, PredictedMask]: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of + related prompts. Each binary mask corresponds to one prompt point. Class mask can be obtained by applying OR operation to all + mask corresponding to one label. + """ if reference_features is None: if self._reference_features is None: raise RuntimeError( @@ -249,7 +304,7 @@ def infer( reference_feats, used_idx = reference_features original_shape = np.array(image.shape[:2]) - image_embeddings = self.encoder_model(image) + image_embeddings = self.encoder(image) total_points_scores, total_bg_coords = _get_prompt_candidates( image_embeddings=image_embeddings, @@ -285,7 +340,7 @@ def infer( point_coords = np.concatenate( (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32 ) - point_coords = self.decoder_model.apply_coords( + point_coords = self.decoder.apply_coords( point_coords, original_shape ) point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) @@ -302,19 +357,23 @@ def infer( prediction.update({"scores": points_score[-1]}) predicted_masks[label].append( - prediction[self.decoder_model.output_blob_name] + prediction[self.decoder.output_blob_name] ) used_points[label].append(points_score) # check overlapping area between different label masks _inspect_overlapping_areas(predicted_masks, used_points) - return (predicted_masks, used_points) + prediction = {} + for k in used_points: + prediction[k] = PredictedMask(predicted_masks[k], used_points[k]) + + return prediction def reset_reference_info(self) -> None: """Initialize reference information.""" self._reference_features = np.zeros( - (0, 1, self.decoder_model.embed_dim), dtype=np.float32 + (0, 1, self.decoder.embed_dim), dtype=np.float32 ) self._used_indices = np.array([], dtype=np.int64) @@ -383,7 +442,7 @@ def _predict_masks( has_mask_input = self._has_mask_inputs[1] y, x = np.nonzero(masks) - box_coords = self.decoder_model.apply_coords( + box_coords = self.decoder.apply_coords( np.array( [[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32 ), @@ -402,13 +461,13 @@ def _predict_masks( ) inputs.update({"mask_input": mask_input, "has_mask_input": has_mask_input}) - prediction = self.decoder_model.infer_sync(inputs) + prediction = self.decoder.infer_sync(inputs) upscaled_masks, scores, logits = ( prediction["upscaled_masks"], prediction["iou_predictions"], prediction["low_res_masks"], ) - masks = upscaled_masks > self.decoder_model.mask_threshold + masks = upscaled_masks > self.decoder.mask_threshold _, masks = _decide_masks(masks, logits, scores) return {"upscaled_masks": masks} From 444632b0f5bf3dbe2e52cbb85944d05dabd2af27 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 20 Jun 2024 08:32:04 +0900 Subject: [PATCH 13/31] Update docs --- .../model_api/models/visual_prompting.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index a98b1e3a..ba486b5c 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -59,18 +59,17 @@ def infer( points: Prompt | None = None, ) -> VisualPromptingResult: """ - _summary_ + Obtains segmentation masks using given prompts. 
Args: image (np.ndarray): HWC-shaped image - boxes (Prompt | None, optional): _description_. Defaults to None. - points (Prompt | None, optional): _description_. Defaults to None. - - Raises: - RuntimeError: _description_ + boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) + and their labels (ints, one per box). Defaults to None. + points (Prompt | None, optional): Prompt containing points (in XY format) + and their labels (ints, one per point). Defaults to None. Returns: - VisualPromptingResult: _description_ + VisualPromptingResult: result object containing predicted masks and aux information. """ if boxes is None and points is None: raise RuntimeError("boxes or points prompts are required for inference") @@ -115,6 +114,7 @@ def __call__( boxes: Prompt | None = None, points: Prompt | None = None, ) -> VisualPromptingResult: + """A wrapper of the SAMVisualPrompter.infer() method""" return self.infer(image, boxes, points) @@ -178,13 +178,16 @@ def learn( """ Executes `learn` stage of SAM ZSL pipeline. - Reference features are updated according to newly arrived prompts. Features corresponding to the same labels are overridden during + Reference features are updated according to newly arrived prompts. + Features corresponding to the same labels are overridden during consequent learn() calls. Args: image (np.ndarray): HWC-shaped image - boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None. - points (Prompt | None, optional): Prompt containing bounding boxes (in XY format) and their labels (ints, one per point). Defaults to None. + boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) + and their labels (ints, one per box). Defaults to None. + points (Prompt | None, optional): Prompt containing points (in XY format) + and their labels (ints, one per point). Defaults to None. reset_features (bool, optional): Forces learning from scratch. Defaults to False. Returns: From a3d118052528af95a30a89b1cedd0f4ba526920e Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 20 Jun 2024 08:32:59 +0900 Subject: [PATCH 14/31] Fix black --- .../model_api/models/visual_prompting.py | 61 +++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index ba486b5c..28208377 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -82,8 +82,10 @@ def infer( { "bboxes": boxes.data if boxes else None, "points": points.data if points else None, - "labels": {"bboxes": boxes.labels if boxes else None, - "points": points.labels if points else None}, + "labels": { + "bboxes": boxes.labels if boxes else None, + "points": points.labels if points else None, + }, "orig_size": meta["original_shape"][:2], }, ) @@ -124,6 +126,7 @@ class SAMLearnableVisualPrompter: To obtain segmentation results, one should run learn() first to obtain the reference features, or use previously generated ones. """ + def __init__( self, encoder_model: SAMImageEncoder, @@ -164,7 +167,9 @@ def reference_features(self) -> VisualPromptingFeatures: the features are not presented in the internal object state. 
""" if self.has_reference_features(): - return VisualPromptingFeatures(np.copy(self._reference_features), np.copy(self._used_indices)) + return VisualPromptingFeatures( + np.copy(self._reference_features), np.copy(self._used_indices) + ) raise RuntimeError("Reference features are not generated") @@ -173,7 +178,7 @@ def learn( image: np.ndarray, boxes: Prompt | None = None, points: Prompt | None = None, - reset_features: bool = False + reset_features: bool = False, ) -> tuple[VisualPromptingFeatures, np.ndarray]: """ Executes `learn` stage of SAM ZSL pipeline. @@ -205,8 +210,10 @@ def learn( { "bboxes": boxes.data if boxes else None, "points": points.data if points else None, - "labels": {"bboxes": boxes.labels if boxes else None, - "points": points.labels if points else None}, + "labels": { + "bboxes": boxes.labels if boxes else None, + "points": points.labels if points else None, + }, "orig_size": image.shape[:2], }, ) @@ -253,7 +260,9 @@ def learn( cur_default_threshold_reference -= 0.05 self._reference_features[label] = ref_feat - self._used_indices: np.ndarray = np.concatenate((self._used_indices, [label])) + self._used_indices: np.ndarray = np.concatenate( + (self._used_indices, [label]) + ) ref_masks[label] = ref_mask self._used_indices = np.unique(self._used_indices) @@ -343,9 +352,7 @@ def infer( point_coords = np.concatenate( (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32 ) - point_coords = self.decoder.apply_coords( - point_coords, original_shape - ) + point_coords = self.decoder.apply_coords(point_coords, original_shape) point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32) inputs_decoder = { "point_coords": point_coords[None], @@ -354,14 +361,10 @@ def infer( } inputs_decoder["image_embeddings"] = image_embeddings - prediction = self._predict_masks( - inputs_decoder, original_shape, True - ) + prediction = self._predict_masks(inputs_decoder, original_shape, True) prediction.update({"scores": points_score[-1]}) - predicted_masks[label].append( - prediction[self.decoder.output_blob_name] - ) + predicted_masks[label].append(prediction[self.decoder.output_blob_name]) used_points[label].append(points_score) # check overlapping area between different label masks @@ -402,7 +405,9 @@ def _expand_reference_info(self, new_largest_label: int) -> None: if new_largest_label > (cur_largest_label := len(self._reference_features) - 1): diff = new_largest_label - cur_largest_label self._reference_features = np.pad( - self._reference_features, ((0, diff), (0, 0), (0, 0)), constant_values=0.0 + self._reference_features, + ((0, diff), (0, 0), (0, 0)), + constant_values=0.0, ) def _predict_masks( @@ -437,9 +442,7 @@ def _predict_masks( elif i == 2: # Cascaded Post-refinement-2 - mask_input, masks = _decide_masks( - masks, logits, scores - ) # noqa: F821 + mask_input, masks = _decide_masks(masks, logits, scores) # noqa: F821 if masks.sum() == 0: return {"upscaled_masks": masks} @@ -640,9 +643,9 @@ def _point_selection( ) ## sample the highest score one of the samples that are in the same grid - matched_indices = _topk_numpy( - matched_grid[..., -1], k=1, axis=0, largest=True - )[1][0].astype(np.int64) + matched_indices = _topk_numpy(matched_grid[..., -1], k=1, axis=0, largest=True)[1][ + 0 + ].astype(np.int64) points_scores = matched_grid[matched_indices].diagonal().T ## sort by the highest score @@ -652,9 +655,7 @@ def _point_selection( points_scores = points_scores[sorted_points_scores_indices] # Top-last point selection - bg_indices = _topk_numpy(mask_sim.flatten(), 
num_bg_points, largest=False)[ - 1 - ] + bg_indices = _topk_numpy(mask_sim.flatten(), num_bg_points, largest=False)[1] bg_x = np.expand_dims(bg_indices // w_sim, axis=0) bg_y = bg_indices - bg_x * w_sim bg_coords = np.concatenate((bg_y, bg_x), axis=0).transpose(1, 0) @@ -668,9 +669,7 @@ def _resize_to_original_shape( ) -> np.ndarray: """Resize feature size to original shape.""" # resize feature size to input size - masks = cv2.resize( - masks, (image_size, image_size), interpolation=cv2.INTER_LINEAR - ) + masks = cv2.resize(masks, (image_size, image_size), interpolation=cv2.INTER_LINEAR) # remove pad prepadded_size = _get_prepadded_size(original_shape, image_size) @@ -698,9 +697,7 @@ def _topk_numpy( indices = range(k, 0) else: indices = range(k) - partitioned_ind = np.argpartition(x, k, axis=axis).take( - indices=indices, axis=axis - ) + partitioned_ind = np.argpartition(x, k, axis=axis).take(indices=indices, axis=axis) partitioned_scores = np.take_along_axis(x, partitioned_ind, axis=axis) sorted_trunc_ind = np.argsort(partitioned_scores, axis=axis) if largest: From d8500fb5a24d9f1d3a4652ada30537d0ab26b98f Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Thu, 20 Jun 2024 10:35:17 +0900 Subject: [PATCH 15/31] Add SAM to testdata --- model_api/python/model_api/models/utils.py | 37 ++++++++++++++++++---- tests/python/accuracy/prepare_data.py | 2 ++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index a8e649b6..cce6e937 100644 --- a/model_api/python/model_api/models/utils.py +++ b/model_api/python/model_api/models/utils.py @@ -137,13 +137,13 @@ def __str__(self): class VisualPromptingResult(NamedTuple): - upscaled_masks: list[np.ndarray] | None = None - low_res_masks: list[np.ndarray] | None = None - iou_predictions: list[np.ndarray] | None = None - scores: list[np.ndarray] | None = None - labels: list[np.ndarray] | None = None - hard_predictions: list[np.ndarray] | None = None - soft_predictions: list[np.ndarray] | None = None + upscaled_masks: List[np.ndarray] | None = None + low_res_masks: List[np.ndarray] | None = None + iou_predictions: List[np.ndarray] | None = None + scores: List[np.ndarray] | None = None + labels: List[np.ndarray] | None = None + hard_predictions: List[np.ndarray] | None = None + soft_predictions: List[np.ndarray] | None = None def _compute_min_max(self, tensor: np.ndarray) -> tuple[np.ndarray, np.ndarray]: return tensor.min(), tensor.max() @@ -161,6 +161,29 @@ def __str__(self) -> str: ) +class PredictedMask(NamedTuple): + mask: list[np.ndarray] + points: list[np.ndarray] | np.ndarray + + def __str__(self) -> str: + obj_str = "" + obj_str += f"mask sum: {np.sum(sum(self.mask))}; " + + if isinstance(self.points, list): + for point in self.points: + obj_str += "[" + obj_str += ", ".join(str(round(c, 2)) for c in point) + obj_str += "] " + else: + for i in range(self.points.shape[0]): + point = self.points[i] + obj_str += "[" + obj_str += ", ".join(str(round(c, 2)) for c in point) + obj_str += "] " + + return obj_str + + def add_rotated_rects(segmented_objects): objects_with_rects = [] for segmented_object in segmented_objects: diff --git a/tests/python/accuracy/prepare_data.py b/tests/python/accuracy/prepare_data.py index e11ced03..071f5b0b 100644 --- a/tests/python/accuracy/prepare_data.py +++ b/tests/python/accuracy/prepare_data.py @@ -113,6 +113,8 @@ async def main(): client, otx_models_dir, "cls_efficient_b0_shuffled_outputs" ), download_otx_model(client, 
otx_models_dir, "action_cls_xd3_kinetic"), + download_otx_model(client, otx_models_dir, "sam_vit_b_zsl_encoder"), + download_otx_model(client, otx_models_dir, "sam_vit_b_zsl_decoder"), ) From b4b149dec4a27b29c64f499f976648d67ee99299 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 21 Jun 2024 15:57:01 +0900 Subject: [PATCH 16/31] Update result objects --- model_api/python/model_api/models/__init__.py | 10 +++++ model_api/python/model_api/models/utils.py | 8 +++- .../model_api/models/visual_prompting.py | 15 +++---- tests/python/accuracy/test_accuracy.py | 39 ++++++++++++++++++- 4 files changed, 60 insertions(+), 12 deletions(-) diff --git a/model_api/python/model_api/models/__init__.py b/model_api/python/model_api/models/__init__.py index ed6109bc..50994405 100644 --- a/model_api/python/model_api/models/__init__.py +++ b/model_api/python/model_api/models/__init__.py @@ -39,6 +39,7 @@ from .retinaface import RetinaFace, RetinaFacePyTorch from .sam_models import SAMDecoder, SAMImageEncoder from .segmentation import SalientObjectDetectionModel, SegmentationModel +from .visual_prompting import SAMVisualPrompter, SAMLearnableVisualPrompter, Prompt from .ssd import SSD from .ultra_lightweight_face_detection import UltraLightweightFaceDetection from .utils import ( @@ -50,6 +51,9 @@ DetectionWithLandmarks, ImageResultWithSoftPrediction, InstanceSegmentationResult, + VisualPromptingResult, + PredictedMask, + ZSLVisualPromptingResult, OutputTransform, SegmentedObject, SegmentedObjectWithRects, @@ -98,6 +102,11 @@ "ImageModel", "ImageResultWithSoftPrediction", "InstanceSegmentationResult", + "VisualPromptingResult", + "ZSLVisualPromptingResult", + "PredictedMask", + "SAMVisualPrompter", + "SAMLearnableVisualPrompter", "MaskRCNNModel", "Model", "MonoDepthModel", @@ -124,6 +133,7 @@ "SAMDecoder", "SAMImageEncoder", "ClassificationResult", + "Prompt", "Detection", "DetectionResult", "DetectionWithLandmarks", diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index cce6e937..e06459c5 100644 --- a/model_api/python/model_api/models/utils.py +++ b/model_api/python/model_api/models/utils.py @@ -181,7 +181,13 @@ def __str__(self) -> str: obj_str += ", ".join(str(round(c, 2)) for c in point) obj_str += "] " - return obj_str + return obj_str.strip() + + +class ZSLVisualPromptingResult(NamedTuple): + data: dict[int, PredictedMask] + def __str__(self) -> str: + return ", ".join(str(self.data[k]) for k in self.data) def add_rotated_rects(segmented_objects): diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 28208377..0712d9ee 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -18,7 +18,7 @@ import cv2 import numpy as np from model_api.models import SAMDecoder, SAMImageEncoder -from model_api.models.utils import VisualPromptingResult +from model_api.models.utils import PredictedMask, VisualPromptingResult, ZSLVisualPromptingResult class VisualPromptingFeatures(NamedTuple): @@ -31,11 +31,6 @@ class Prompt(NamedTuple): labels: list[np.ndarray | int] | np.ndarray -class PredictedMask(NamedTuple): - mask: list[np.ndarray] - points: list[np.ndarray] | np.ndarray - - class SAMVisualPrompter: """ A wrapper that implements SAM Visual Prompter. 
@@ -273,7 +268,7 @@ def __call__( self, image: np.ndarray, reference_features: VisualPromptingFeatures | None = None, - ) -> dict[int, PredictedMask]: + ) -> ZSLVisualPromptingResult: """A wrapper of the SAMLearnableVisualPrompter.infer() method""" return self.infer(image, reference_features) @@ -281,7 +276,7 @@ def infer( self, image: np.ndarray, reference_features: VisualPromptingFeatures | None = None, - ) -> dict[int, PredictedMask]: + ) -> ZSLVisualPromptingResult: """ Obtains masks by already prepared reference features. @@ -294,7 +289,7 @@ def infer( If not passed, object internal state is used, which reflects the last learn() call. Defaults to None. Returns: - dict[int, PredictedMask]: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of + ZSLVisualPromptingResult: Mapping label -> predicted mask. Each mask object contains a list of binary masks, and a list of related prompts. Each binary mask corresponds to one prompt point. Class mask can be obtained by applying OR operation to all mask corresponding to one label. """ @@ -374,7 +369,7 @@ def infer( for k in used_points: prediction[k] = PredictedMask(predicted_masks[k], used_points[k]) - return prediction + return ZSLVisualPromptingResult(prediction) def reset_reference_info(self) -> None: """Initialize reference information.""" diff --git a/tests/python/accuracy/test_accuracy.py b/tests/python/accuracy/test_accuracy.py index bb87dc85..dbeec0f1 100644 --- a/tests/python/accuracy/test_accuracy.py +++ b/tests/python/accuracy/test_accuracy.py @@ -17,11 +17,19 @@ ClassificationResult, DetectionModel, DetectionResult, + VisualPromptingResult, + PredictedMask, ImageModel, ImageResultWithSoftPrediction, InstanceSegmentationResult, MaskRCNNModel, SegmentationModel, + SAMImageEncoder, + SAMDecoder, + SAMVisualPrompter, + SAMLearnableVisualPrompter, + ZSLVisualPromptingResult, + Prompt, add_rotated_rects, get_contours, ) @@ -108,6 +116,18 @@ def test_image_models(data, dump, result, model_data): ) else: model = eval(model_data["tiler"])(model, configuration={}) + elif "prompter" in model_data: + encoder_adapter = OpenvinoAdapter( + create_core(), f"{data}/{model_data['encoder']}", device="CPU" + ) + + encoder_model = eval(model_data["encoder_type"])( + encoder_adapter, configuration={}, preload=True + ) + model = eval(model_data["prompter"])( + encoder_model, + model + ) if dump: result.append(model_data) @@ -122,7 +142,14 @@ def test_image_models(data, dump, result, model_data): image = cv2.resize(image, eval(model_data["input_res"])) if isinstance(model, ActionClassificationModel): image = np.stack([image for _ in range(8)]) - outputs = model(image) + if "prompter" in model_data: + if model_data["prompter"] == "SAMLearnableVisualPrompter": + model.learn(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0])) + outputs = model(image) + else: + outputs = model(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0])) + else: + outputs = model(image) if isinstance(outputs, ClassificationResult): assert 1 == len(test_data["reference"]) output_str = str(outputs) @@ -175,6 +202,14 @@ def test_image_models(data, dump, result, model_data): output_str = str(outputs) assert test_data["reference"][0] == output_str image_result = [output_str] + elif isinstance(outputs, ZSLVisualPromptingResult): + output_str = str(outputs) + assert test_data["reference"][0] == output_str + image_result = [output_str] + elif isinstance(outputs, 
VisualPromptingResult): + output_str = str(outputs) + assert test_data["reference"][0] == output_str + image_result = [output_str] else: assert False if dump: @@ -189,6 +224,8 @@ def test_image_models(data, dump, result, model_data): if not model_data.get("force_ort", False): if "tiler" in model_data: model.get_model().save(data + "/serialized/" + save_name) + elif "prompter" in model_data: + pass else: model.save(data + "/serialized/" + save_name) if model_data.get("check_extra_rt_info", False): From ba41c967888b1c0e81496c8a5b46b412b0d05055 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 21 Jun 2024 15:57:59 +0900 Subject: [PATCH 17/31] Add tests --- tests/python/accuracy/public_scope.json | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/python/accuracy/public_scope.json b/tests/python/accuracy/public_scope.json index 6f1dca5b..dcd49078 100644 --- a/tests/python/accuracy/public_scope.json +++ b/tests/python/accuracy/public_scope.json @@ -406,5 +406,31 @@ "reference": ["38 (WritingOnBoard): 0.096, [0], [0], [0]"] } ] + }, + { + "name": "otx_models/sam_vit_b_zsl_decoder.xml", + "type": "SAMDecoder", + "prompter": "SAMLearnableVisualPrompter", + "encoder": "otx_models/sam_vit_b_zsl_encoder.xml", + "encoder_type": "SAMImageEncoder", + "test_data": [ + { + "image": "coco128/images/train2017/000000000471.jpg", + "reference": ["mask sum: 108565; [385.0, 315.0, 0.93] [335.0, 414.0, 0.76] [44.0, 205.0, 0.66] [605.0, 224.0, 0.65]"] + } + ] + }, + { + "name": "otx_models/sam_vit_b_zsl_decoder.xml", + "type": "SAMDecoder", + "prompter": "SAMVisualPrompter", + "encoder": "otx_models/sam_vit_b_zsl_encoder.xml", + "encoder_type": "SAMImageEncoder", + "test_data": [ + { + "image": "coco128/images/train2017/000000000471.jpg", + "reference": ["38 (WritingOnBoard): 0.096, [0], [0], [0]"] + } + ] } ] From 9f3b7138f3fddfdeccea9ee4a7152a29fdd543ac Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 05:36:46 +0900 Subject: [PATCH 18/31] Update decoder postprocessing --- model_api/python/model_api/models/sam_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index acc4f549..66676f43 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -192,9 +192,9 @@ def postprocess( """ probability = np.clip(outputs["scores"], 0.0, 1.0) hard_prediction = ( - outputs[self.output_blob_name].squeeze(1) > self.mask_threshold + outputs[self.output_blob_name].squeeze(0) > self.mask_threshold ).astype(np.uint8) - soft_prediction = hard_prediction * probability + soft_prediction = hard_prediction * probability.reshape(-1, 1, 1) outputs["hard_prediction"] = hard_prediction outputs["soft_prediction"] = soft_prediction From c1d9deb005bce5839f3c0906b16240af075e50ed Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 05:42:56 +0900 Subject: [PATCH 19/31] Update SAM ref results --- tests/python/accuracy/public_scope.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/accuracy/public_scope.json b/tests/python/accuracy/public_scope.json index dcd49078..e0f05ad5 100644 --- a/tests/python/accuracy/public_scope.json +++ b/tests/python/accuracy/public_scope.json @@ -429,7 +429,7 @@ "test_data": [ { "image": "coco128/images/train2017/000000000471.jpg", - "reference": ["38 (WritingOnBoard): 0.096, [0], [0], [0]"] + "reference": 
["upscaled_masks min:-25.906675338745117 max:11.185405731201172;hard_predictions shape:(4, 427, 640);"] } ] } From 2938bd2ddb73508e335b3ac8858d3f8fc2fd53be Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 05:43:54 +0900 Subject: [PATCH 20/31] Fix linters --- model_api/python/model_api/models/__init__.py | 8 ++++---- model_api/python/model_api/models/utils.py | 1 + model_api/python/model_api/models/visual_prompting.py | 6 +++++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/model_api/python/model_api/models/__init__.py b/model_api/python/model_api/models/__init__.py index 50994405..6dbc05c5 100644 --- a/model_api/python/model_api/models/__init__.py +++ b/model_api/python/model_api/models/__init__.py @@ -39,7 +39,6 @@ from .retinaface import RetinaFace, RetinaFacePyTorch from .sam_models import SAMDecoder, SAMImageEncoder from .segmentation import SalientObjectDetectionModel, SegmentationModel -from .visual_prompting import SAMVisualPrompter, SAMLearnableVisualPrompter, Prompt from .ssd import SSD from .ultra_lightweight_face_detection import UltraLightweightFaceDetection from .utils import ( @@ -51,15 +50,16 @@ DetectionWithLandmarks, ImageResultWithSoftPrediction, InstanceSegmentationResult, - VisualPromptingResult, - PredictedMask, - ZSLVisualPromptingResult, OutputTransform, + PredictedMask, SegmentedObject, SegmentedObjectWithRects, + VisualPromptingResult, + ZSLVisualPromptingResult, add_rotated_rects, get_contours, ) +from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8 classification_models = [ diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index e06459c5..87cdc8a0 100644 --- a/model_api/python/model_api/models/utils.py +++ b/model_api/python/model_api/models/utils.py @@ -186,6 +186,7 @@ def __str__(self) -> str: class ZSLVisualPromptingResult(NamedTuple): data: dict[int, PredictedMask] + def __str__(self) -> str: return ", ".join(str(self.data[k]) for k in self.data) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 0712d9ee..3599380c 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -18,7 +18,11 @@ import cv2 import numpy as np from model_api.models import SAMDecoder, SAMImageEncoder -from model_api.models.utils import PredictedMask, VisualPromptingResult, ZSLVisualPromptingResult +from model_api.models.utils import ( + PredictedMask, + VisualPromptingResult, + ZSLVisualPromptingResult, +) class VisualPromptingFeatures(NamedTuple): From 48441537b72bfc986511d0424b95ddb8b85b7f2d Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 05:47:39 +0900 Subject: [PATCH 21/31] Skip SAM in cpp tests --- tests/cpp/accuracy/test_accuracy.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cpp/accuracy/test_accuracy.cpp b/tests/cpp/accuracy/test_accuracy.cpp index 00119649..a5cfa561 100644 --- a/tests/cpp/accuracy/test_accuracy.cpp +++ b/tests/cpp/accuracy/test_accuracy.cpp @@ -131,7 +131,10 @@ TEST_P(ModelParameterizedTest, AccuracyTest) GTEST_SKIP() << "ONNX models are not supported in C++ implementation"; } if (name.find("action_cls_xd3_kinetic") != std::string::npos) { - GTEST_SKIP() << "ActionClassificationModel are not supported in C++ implementation"; + GTEST_SKIP() << 
"ActionClassificationModel is not supported in C++ implementation"; + } + if (name.find("sam_vit_b") != std::string::npos) { + GTEST_SKIP() << "SAM-based models are not supported in C++ implementation"; } if (name.substr(name.size() - 4) == ".xml") { From 82dd778b43d6ef8c517b110e8fac549294e865c2 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 06:12:06 +0900 Subject: [PATCH 22/31] Workaround unsupported type annotation --- model_api/python/model_api/models/sam_models.py | 2 ++ model_api/python/model_api/models/visual_prompting.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index 66676f43..d37a3c18 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -14,6 +14,8 @@ limitations under the License. """ +from __future__ import annotations # TODO: remove when Python3.9 support is dropped + from copy import deepcopy from typing import Any, Dict diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 3599380c..a5001cf6 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -11,6 +11,8 @@ limitations under the License. """ +from __future__ import annotations # TODO: remove when Python3.9 support is dropped + from collections import defaultdict from itertools import product from typing import Any, NamedTuple From 5a9a497ce9f1f359fec0be9e19c5aa3f8eb1c751 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 06:18:39 +0900 Subject: [PATCH 23/31] Fix black --- tests/python/accuracy/test_accuracy.py | 37 +++++++++++++++++--------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/python/accuracy/test_accuracy.py b/tests/python/accuracy/test_accuracy.py index dbeec0f1..3345edcd 100644 --- a/tests/python/accuracy/test_accuracy.py +++ b/tests/python/accuracy/test_accuracy.py @@ -17,19 +17,19 @@ ClassificationResult, DetectionModel, DetectionResult, - VisualPromptingResult, - PredictedMask, ImageModel, ImageResultWithSoftPrediction, InstanceSegmentationResult, MaskRCNNModel, - SegmentationModel, - SAMImageEncoder, + PredictedMask, + Prompt, SAMDecoder, - SAMVisualPrompter, + SAMImageEncoder, SAMLearnableVisualPrompter, + SAMVisualPrompter, + SegmentationModel, + VisualPromptingResult, ZSLVisualPromptingResult, - Prompt, add_rotated_rects, get_contours, ) @@ -124,10 +124,7 @@ def test_image_models(data, dump, result, model_data): encoder_model = eval(model_data["encoder_type"])( encoder_adapter, configuration={}, preload=True ) - model = eval(model_data["prompter"])( - encoder_model, - model - ) + model = eval(model_data["prompter"])(encoder_model, model) if dump: result.append(model_data) @@ -144,10 +141,26 @@ def test_image_models(data, dump, result, model_data): image = np.stack([image for _ in range(8)]) if "prompter" in model_data: if model_data["prompter"] == "SAMLearnableVisualPrompter": - model.learn(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0])) + model.learn( + image, + points=Prompt( + np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape( + 1, 2 + ), + [0], + ), + ) outputs = model(image) else: - outputs = model(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0])) + outputs = model( + image, + points=Prompt( + 
np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape( + 1, 2 + ), + [0], + ), + ) else: outputs = model(image) if isinstance(outputs, ClassificationResult): From a2ab4bcc14519a498af4655730cd942b08cc12b5 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 07:45:07 +0900 Subject: [PATCH 24/31] Replace prompt->list --- model_api/python/model_api/models/utils.py | 2 +- .../model_api/models/visual_prompting.py | 40 +++++++++---------- tests/python/accuracy/public_scope.json | 4 +- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index 87cdc8a0..1eeafb17 100644 --- a/model_api/python/model_api/models/utils.py +++ b/model_api/python/model_api/models/utils.py @@ -156,7 +156,7 @@ def __str__(self) -> str: ) return ( - f"upscaled_masks min:{upscaled_masks_min} max:{upscaled_masks_max};" + f"upscaled_masks min:{upscaled_masks_min:.3f} max:{upscaled_masks_max:.3f};" f"hard_predictions shape:{self.hard_predictions[0].shape};" ) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index a5001cf6..7374c050 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -33,8 +33,8 @@ class VisualPromptingFeatures(NamedTuple): class Prompt(NamedTuple): - data: list[np.ndarray] | np.ndarray - labels: list[np.ndarray | int] | np.ndarray + data: np.ndarray + label: int | np.ndarray class SAMVisualPrompter: @@ -56,17 +56,17 @@ def __init__( def infer( self, image: np.ndarray, - boxes: Prompt | None = None, - points: Prompt | None = None, + boxes: list[Prompt] | None = None, + points: list[Prompt] | None = None, ) -> VisualPromptingResult: """ Obtains segmentation masks using given prompts. Args: image (np.ndarray): HWC-shaped image - boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) + boxes (list[Prompt] | None, optional): Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None. - points (Prompt | None, optional): Prompt containing points (in XY format) + points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. 
Returns: @@ -81,11 +81,11 @@ def infer( image_embeddings = self.encoder.infer_sync(processed_image) processed_prompts = self.decoder.preprocess( { - "bboxes": boxes.data if boxes else None, - "points": points.data if points else None, + "bboxes": [box.data for box in boxes] if boxes else None, + "points": [point.data for point in points] if points else None, "labels": { - "bboxes": boxes.labels if boxes else None, - "points": points.labels if points else None, + "bboxes": [box.label for box in boxes] if boxes else None, + "points": [point.label for point in points] if points else None, }, "orig_size": meta["original_shape"][:2], }, @@ -114,8 +114,8 @@ def infer( def __call__( self, image: np.ndarray, - boxes: Prompt | None = None, - points: Prompt | None = None, + boxes: list[Prompt] | None = None, + points: list[Prompt] | None = None, ) -> VisualPromptingResult: """A wrapper of the SAMVisualPrompter.infer() method""" return self.infer(image, boxes, points) @@ -177,8 +177,8 @@ def reference_features(self) -> VisualPromptingFeatures: def learn( self, image: np.ndarray, - boxes: Prompt | None = None, - points: Prompt | None = None, + boxes: list[Prompt] | None = None, + points: list[Prompt] | None = None, reset_features: bool = False, ) -> tuple[VisualPromptingFeatures, np.ndarray]: """ @@ -190,9 +190,9 @@ def learn( Args: image (np.ndarray): HWC-shaped image - boxes (Prompt | None, optional): Prompt containing bounding boxes (in XYXY torchvision format) + boxes (list[Prompt] | None, optional): Prompts containing bounding boxes (in XYXY torchvision format) and their labels (ints, one per box). Defaults to None. - points (Prompt | None, optional): Prompt containing points (in XY format) + points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. reset_features (bool, optional): Forces learning from scratch. Defaults to False. 
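A note on the reworked prompt API: after this commit each Prompt carries a single coordinate array plus one integer label, and callers pass plain Python lists of prompts rather than one batched Prompt. A minimal usage sketch under that reading; the create_model factory, model paths, and image file below are assumptions for illustration, not part of the patch:

import cv2
import numpy as np

from model_api.models import (
    Prompt,
    SAMDecoder,
    SAMImageEncoder,
    SAMVisualPrompter,
)

# Hypothetical model/image paths; any exported SAM encoder/decoder pair would do.
encoder = SAMImageEncoder.create_model("sam_vit_b_encoder.xml")
decoder = SAMDecoder.create_model("sam_vit_b_decoder.xml")
prompter = SAMVisualPrompter(encoder, decoder)

image = cv2.imread("sample.jpg")  # HWC image, as the prompter expects

# One Prompt per primitive: a coordinate array and a single integer label.
boxes = [Prompt(np.array([50, 50, 200, 200]), 0)]  # XYXY box with label 0
points = [Prompt(np.array([320, 240]), 1)]         # XY point with label 1

result = prompter(image, boxes=boxes, points=points)
print(result)  # VisualPromptingResult: upscaled masks, scores, hard/soft predictions
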
@@ -209,11 +209,11 @@ def learn( processed_prompts = self.decoder.preprocess( { - "bboxes": boxes.data if boxes else None, - "points": points.data if points else None, + "bboxes": [box.data for box in boxes] if boxes else None, + "points": [point.data for point in points] if points else None, "labels": { - "bboxes": boxes.labels if boxes else None, - "points": points.labels if points else None, + "bboxes": [box.label for box in boxes] if boxes else None, + "points": [point.label for point in points] if points else None, }, "orig_size": image.shape[:2], }, diff --git a/tests/python/accuracy/public_scope.json b/tests/python/accuracy/public_scope.json index e0f05ad5..8756a431 100644 --- a/tests/python/accuracy/public_scope.json +++ b/tests/python/accuracy/public_scope.json @@ -66,7 +66,7 @@ "test_data": [ { "image": "coco128/images/train2017/000000000074.jpg", - "reference": ["0: 0.944, 1: 0.056, [426,640,2], [0], [0]; object: 0.505, 2, object: 0.518, 8, object: 0.512, 5, object: 0.506, 4, object: 0.526, 8, object: 0.529, 21, object: 0.513, 12, object: 0.536, 49, object: 0.505, 2, object: 0.512, 4, object: 0.547, 6, object: 0.511, 6, object: 0.503, 1, object: 0.539, 6, object: 0.543, 39, object: 0.529, 2, object: 0.516, 9, object: 0.565, 157, object: 0.524, 6, object: 0.528, 15, object: 0.521, 18, object: 0.503, 1, object: 0.537, 73, object: 0.513, 4, object: 0.524, 27, object: 0.513, 6, object: 0.538, 65, object: 0.501, 6, object: 0.504, 1, object: 0.507, 4, object: 0.502, 1, object: 0.518, 8, object: 0.530, 11, object: 0.502, 2, object: 0.516, 2, object: 0.506, 1, object: 0.567, 17, object: 0.502, 1, object: 0.512, 7, object: 0.538, 24, object: 0.507, 1, object: 0.534, 12, object: 0.537, 6, object: 0.519, 13, object: 0.505, 2, object: 0.517, 16, object: 0.505, 5, object: 0.506, 20, object: 0.508, 6, object: 0.519, 24, object: 0.507, 4, object: 0.506, 2, object: 0.511, 4, object: 0.556, 47, object: 0.510, 10, object: 0.500, 1, object: 0.504, 5, object: 0.501, 1, object: 0.510, 6, object: 0.549, 13, object: 0.509, 2, object: 0.510, 3, object: 0.514, 1, object: 0.529, 15, object: 0.551, 110, object: 0.504, 2, object: 0.503, 3, object: 0.518, 16, object: 0.511, 14, object: 0.502, 1, object: 0.523, 1, object: 0.533, 16, object: 0.568, 65, object: 0.582, 1793, "] + "reference": ["0: 0.944, 1: 0.056, [426,640,2], [0], [0]; object: 0.505, 2, object: 0.518, 8, object: 0.512, 5, object: 0.506, 4, object: 0.526, 8, object: 0.529, 21, object: 0.513, 12, object: 0.535, 49, object: 0.505, 2, object: 0.512, 4, object: 0.547, 6, object: 0.511, 6, object: 0.503, 1, object: 0.539, 6, object: 0.543, 39, object: 0.529, 2, object: 0.516, 9, object: 0.565, 157, object: 0.524, 6, object: 0.528, 15, object: 0.521, 18, object: 0.502, 1, object: 0.537, 73, object: 0.513, 4, object: 0.524, 27, object: 0.513, 6, object: 0.538, 65, object: 0.501, 6, object: 0.504, 1, object: 0.507, 4, object: 0.502, 1, object: 0.518, 8, object: 0.530, 11, object: 0.502, 2, object: 0.516, 2, object: 0.506, 1, object: 0.567, 17, object: 0.502, 1, object: 0.512, 7, object: 0.538, 24, object: 0.507, 1, object: 0.534, 12, object: 0.537, 6, object: 0.519, 13, object: 0.505, 2, object: 0.517, 16, object: 0.505, 5, object: 0.506, 20, object: 0.508, 6, object: 0.519, 24, object: 0.507, 4, object: 0.506, 2, object: 0.511, 4, object: 0.556, 47, object: 0.510, 10, object: 0.500, 1, object: 0.504, 5, object: 0.501, 1, object: 0.510, 6, object: 0.549, 13, object: 0.509, 2, object: 0.510, 3, object: 0.514, 1, object: 0.529, 15, object: 0.551, 110, 
object: 0.504, 2, object: 0.503, 3, object: 0.518, 16, object: 0.511, 14, object: 0.502, 1, object: 0.523, 1, object: 0.533, 16, object: 0.568, 66, object: 0.582, 1793, "] } ] }, @@ -429,7 +429,7 @@ "test_data": [ { "image": "coco128/images/train2017/000000000471.jpg", - "reference": ["upscaled_masks min:-25.906675338745117 max:11.185405731201172;hard_predictions shape:(4, 427, 640);"] + "reference": ["upscaled_masks min:-25.907 max:11.185;hard_predictions shape:(4, 427, 640);"] } ] } From 162dfc9aa922adf78c6184a9ca21ab38faf5020c Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Sat, 22 Jun 2024 08:05:36 +0900 Subject: [PATCH 25/31] Fix python tests --- tests/python/accuracy/test_accuracy.py | 30 +++++++++++--------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/tests/python/accuracy/test_accuracy.py b/tests/python/accuracy/test_accuracy.py index 3345edcd..b849f741 100644 --- a/tests/python/accuracy/test_accuracy.py +++ b/tests/python/accuracy/test_accuracy.py @@ -143,23 +143,23 @@ def test_image_models(data, dump, result, model_data): if model_data["prompter"] == "SAMLearnableVisualPrompter": model.learn( image, - points=Prompt( - np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape( - 1, 2 - ), - [0], - ), + points=[ + Prompt( + np.array([image.shape[0] / 2, image.shape[1] / 2]), + 0, + ) + ], ) outputs = model(image) else: outputs = model( image, - points=Prompt( - np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape( - 1, 2 - ), - [0], - ), + points=[ + Prompt( + np.array([image.shape[0] / 2, image.shape[1] / 2]), + 0, + ) + ], ) else: outputs = model(image) @@ -215,11 +215,7 @@ def test_image_models(data, dump, result, model_data): output_str = str(outputs) assert test_data["reference"][0] == output_str image_result = [output_str] - elif isinstance(outputs, ZSLVisualPromptingResult): - output_str = str(outputs) - assert test_data["reference"][0] == output_str - image_result = [output_str] - elif isinstance(outputs, VisualPromptingResult): + elif isinstance(outputs, (ZSLVisualPromptingResult, VisualPromptingResult)): output_str = str(outputs) assert test_data["reference"][0] == output_str image_result = [output_str] From b56ee043aff5651227a577a830e40882b831871f Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 25 Jun 2024 07:40:29 +0900 Subject: [PATCH 26/31] Restore bool mask output from mask decoder --- model_api/python/model_api/models/sam_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_api/python/model_api/models/sam_models.py b/model_api/python/model_api/models/sam_models.py index d37a3c18..42efe32a 100644 --- a/model_api/python/model_api/models/sam_models.py +++ b/model_api/python/model_api/models/sam_models.py @@ -195,7 +195,7 @@ def postprocess( probability = np.clip(outputs["scores"], 0.0, 1.0) hard_prediction = ( outputs[self.output_blob_name].squeeze(0) > self.mask_threshold - ).astype(np.uint8) + ) soft_prediction = hard_prediction * probability.reshape(-1, 1, 1) outputs["hard_prediction"] = hard_prediction From 5cfaca2db94b07502c575f19b13156b5b54067ed Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 25 Jun 2024 07:54:28 +0900 Subject: [PATCH 27/31] Improve usability of ZSL VPT result --- model_api/python/model_api/models/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/model_api/python/model_api/models/utils.py b/model_api/python/model_api/models/utils.py index 1eeafb17..0f5bf16a 100644 --- a/model_api/python/model_api/models/utils.py +++ 
b/model_api/python/model_api/models/utils.py @@ -190,6 +190,10 @@ class ZSLVisualPromptingResult(NamedTuple): def __str__(self) -> str: return ", ".join(str(self.data[k]) for k in self.data) + def get_mask(self, label: int) -> PredictedMask: + """Returns a mask belonging to a given label""" + return self.data[label] + def add_rotated_rects(segmented_objects): objects_with_rects = [] From 458f3a45bffeebeada61475018560cbbc4d48d87 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 25 Jun 2024 08:20:22 +0900 Subject: [PATCH 28/31] Add stubs for the future support of polygon prompts --- .../python/model_api/models/visual_prompting.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 7374c050..6ddbfc3a 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -58,6 +58,7 @@ def infer( image: np.ndarray, boxes: list[Prompt] | None = None, points: list[Prompt] | None = None, + polygons: list[Prompt] | None = None, ) -> VisualPromptingResult: """ Obtains segmentation masks using given prompts. @@ -68,12 +69,16 @@ def infer( and their labels (ints, one per box). Defaults to None. points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. + polygons: (list[Prompt] | None): Prompts containing polygons (a sequence of points in XY format) + and their labels (ints, one per polygon). Each polygon is represented as a mask prompt. Defaults to None. Returns: VisualPromptingResult: result object containing predicted masks and aux information. """ if boxes is None and points is None: raise RuntimeError("boxes or points prompts are required for inference") + if polygons is not None: + raise RuntimeError("Polygon support is not implemented yet") outputs: list[dict[str, Any]] = [] @@ -116,9 +121,10 @@ def __call__( image: np.ndarray, boxes: list[Prompt] | None = None, points: list[Prompt] | None = None, + polygons: list[Prompt] | None = None, ) -> VisualPromptingResult: """A wrapper of the SAMVisualPrompter.infer() method""" - return self.infer(image, boxes, points) + return self.infer(image, boxes, points, polygons) class SAMLearnableVisualPrompter: @@ -179,6 +185,7 @@ def learn( image: np.ndarray, boxes: list[Prompt] | None = None, points: list[Prompt] | None = None, + polygons: list[Prompt] | None = None, reset_features: bool = False, ) -> tuple[VisualPromptingFeatures, np.ndarray]: """ @@ -194,6 +201,8 @@ def learn( and their labels (ints, one per box). Defaults to None. points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. + polygons: (list[Prompt] | None): Prompts containing polygons (a sequence of points in XY format) + and their labels (ints, one per polygon). Each polygon is represented as a mask prompt. Defaults to None. reset_features (bool, optional): Forces learning from scratch. Defaults to False. 
Returns: @@ -203,6 +212,8 @@ def learn( if boxes is None and points is None: raise RuntimeError("boxes or points prompts are required for learning") + if polygons is not None: + raise RuntimeError("Polygon support is not implemented yet") if reset_features or not self.has_reference_features(): self.reset_reference_info() From 7c844f7f285df36de71e8eae8a1801aa7d3e24fe Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 26 Jun 2024 09:21:22 +0900 Subject: [PATCH 29/31] Add polygon prompts to ZSL --- .../model_api/models/visual_prompting.py | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 6ddbfc3a..7ed54887 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -58,7 +58,6 @@ def infer( image: np.ndarray, boxes: list[Prompt] | None = None, points: list[Prompt] | None = None, - polygons: list[Prompt] | None = None, ) -> VisualPromptingResult: """ Obtains segmentation masks using given prompts. @@ -69,16 +68,12 @@ def infer( and their labels (ints, one per box). Defaults to None. points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. - polygons: (list[Prompt] | None): Prompts containing polygons (a sequence of points in XY format) - and their labels (ints, one per polygon). Each polygon is represented as a mask prompt. Defaults to None. Returns: VisualPromptingResult: result object containing predicted masks and aux information. """ if boxes is None and points is None: raise RuntimeError("boxes or points prompts are required for inference") - if polygons is not None: - raise RuntimeError("Polygon support is not implemented yet") outputs: list[dict[str, Any]] = [] @@ -121,10 +116,9 @@ def __call__( image: np.ndarray, boxes: list[Prompt] | None = None, points: list[Prompt] | None = None, - polygons: list[Prompt] | None = None, ) -> VisualPromptingResult: """A wrapper of the SAMVisualPrompter.infer() method""" - return self.infer(image, boxes, points, polygons) + return self.infer(image, boxes, points) class SAMLearnableVisualPrompter: @@ -202,7 +196,8 @@ def learn( points (list[Prompt] | None, optional): Prompts containing points (in XY format) and their labels (ints, one per point). Defaults to None. polygons: (list[Prompt] | None): Prompts containing polygons (a sequence of points in XY format) - and their labels (ints, one per polygon). Each polygon is represented as a mask prompt. Defaults to None. + and their labels (ints, one per polygon). + Polygon prompts are used to mask out the source features without implying decoder usage. Defaults to None. reset_features (bool, optional): Forces learning from scratch. Defaults to False. Returns: @@ -210,10 +205,8 @@ def learn( The shape of the reference mask is N_labels x H x W, where H and W are the same as in the input image. 
""" - if boxes is None and points is None: - raise RuntimeError("boxes or points prompts are required for learning") - if polygons is not None: - raise RuntimeError("Polygon support is not implemented yet") + if boxes is None and points is None and polygons is None: + raise RuntimeError("boxes, polygons or points prompts are required for learning") if reset_features or not self.has_reference_features(): self.reset_reference_info() @@ -230,8 +223,13 @@ def learn( }, ) + if polygons is not None: + for poly in polygons: + processed_prompts.append({"polygon": poly.data, "label": poly.label}) + processed_prompts_w_labels = self._gather_prompts_with_labels(processed_prompts) largest_label: int = max([int(p) for p in processed_prompts_w_labels] + [0]) + self._expand_reference_info(largest_label) original_shape = np.array(image.shape[:2]) @@ -255,8 +253,10 @@ def learn( inputs_decoder, original_shape, is_cascade=self._is_cascade ) masks = prediction["upscaled_masks"] + elif "polygon" in inputs_decoder: + masks = _polygon_to_mask(inputs_decoder["polygon"], *original_shape) else: - raise RuntimeError("Prompts other than points are not supported") + raise RuntimeError("Unsupported type of prompt") ref_mask[masks] += 1 ref_mask = np.clip(ref_mask, 0, 1) @@ -491,6 +491,17 @@ def _predict_masks( return {"upscaled_masks": masks} +def _polygon_to_mask(polygon: np.ndarray | list[np.ndarray], height: int, width: int) -> np.ndarray: + """Converts a polygon represented as an array of 2D points into a mask""" + if isinstance(polygon, np.ndarray) and np.issubdtype(polygon.dtype, np.integer): + contour = polygon.reshape(-1, 2) + else: + contour = [[int(point[0]), int(point[1])] for point in polygon] + gt_mask = np.zeros((height, width), dtype=np.uint8) + gt_mask = cv2.drawContours(gt_mask, np.asarray([contour]), 0, 1, cv2.FILLED) + return gt_mask + + def _generate_masked_features( feats: np.ndarray, masks: np.ndarray, From a8a083a80bee09b29a0ef66fb17c7006eb54c40c Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 26 Jun 2024 09:21:43 +0900 Subject: [PATCH 30/31] Update tests --- tests/python/accuracy/public_scope.json | 2 +- tests/python/accuracy/test_accuracy.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/python/accuracy/public_scope.json b/tests/python/accuracy/public_scope.json index 8756a431..da0ad4bb 100644 --- a/tests/python/accuracy/public_scope.json +++ b/tests/python/accuracy/public_scope.json @@ -416,7 +416,7 @@ "test_data": [ { "image": "coco128/images/train2017/000000000471.jpg", - "reference": ["mask sum: 108565; [385.0, 315.0, 0.93] [335.0, 414.0, 0.76] [44.0, 205.0, 0.66] [605.0, 224.0, 0.65]"] + "reference": ["mask sum: 14991; [385.0, 315.0, 0.93] [44.0, 205.0, 0.66] [605.0, 224.0, 0.65], mask sum: 248221; [374.0, 365.0, 0.9] [335.0, 34.0, 0.9] [354.0, 135.0, 0.71]"] } ] }, diff --git a/tests/python/accuracy/test_accuracy.py b/tests/python/accuracy/test_accuracy.py index b849f741..b27cffb5 100644 --- a/tests/python/accuracy/test_accuracy.py +++ b/tests/python/accuracy/test_accuracy.py @@ -149,6 +149,18 @@ def test_image_models(data, dump, result, model_data): 0, ) ], + polygons=[ + Prompt( + np.array( + [ + [image.shape[0] / 4, image.shape[1] / 4], + [image.shape[0] / 4, image.shape[1] / 2], + [image.shape[0] / 2, image.shape[1] / 2], + ] + ), + 1, + ) + ], ) outputs = model(image) else: From 716ee65c1e3da2b2a429fab26b2e75d13621ae55 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Wed, 26 Jun 2024 09:23:13 +0900 Subject: [PATCH 31/31] Fix 
black --- model_api/python/model_api/models/visual_prompting.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model_api/python/model_api/models/visual_prompting.py b/model_api/python/model_api/models/visual_prompting.py index 7ed54887..1fd5cfa9 100644 --- a/model_api/python/model_api/models/visual_prompting.py +++ b/model_api/python/model_api/models/visual_prompting.py @@ -206,7 +206,9 @@ def learn( """ if boxes is None and points is None and polygons is None: - raise RuntimeError("boxes, polygons or points prompts are required for learning") + raise RuntimeError( + "boxes, polygons or points prompts are required for learning" + ) if reset_features or not self.has_reference_features(): self.reset_reference_info() @@ -491,7 +493,9 @@ def _predict_masks( return {"upscaled_masks": masks} -def _polygon_to_mask(polygon: np.ndarray | list[np.ndarray], height: int, width: int) -> np.ndarray: +def _polygon_to_mask( + polygon: np.ndarray | list[np.ndarray], height: int, width: int +) -> np.ndarray: """Converts a polygon represented as an array of 2D points into a mask""" if isinstance(polygon, np.ndarray) and np.issubdtype(polygon.dtype, np.integer): contour = polygon.reshape(-1, 2)
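The zero-shot flow that emerges from the last few patches (get_mask on the result, polygon prompts that are rasterized and applied directly to the encoder features instead of going through the decoder) can be sketched roughly as follows; model and image paths are again hypothetical and the create_model factory is the same assumption as in the earlier sketch:

import cv2
import numpy as np

from model_api.models import (
    Prompt,
    SAMDecoder,
    SAMImageEncoder,
    SAMLearnableVisualPrompter,
)

encoder = SAMImageEncoder.create_model("sam_vit_b_zsl_encoder.xml")
decoder = SAMDecoder.create_model("sam_vit_b_zsl_decoder.xml")
prompter = SAMLearnableVisualPrompter(encoder, decoder)

reference_image = cv2.imread("reference.jpg")
target_image = cv2.imread("target.jpg")

# Learn reference features from a point prompt (label 0) and a polygon prompt (label 1).
point = Prompt(np.array([320, 240]), 0)
polygon = Prompt(np.array([[100, 100], [100, 200], [200, 200]]), 1)
features, reference_masks = prompter.learn(
    reference_image, points=[point], polygons=[polygon]
)

# Zero-shot inference on a new image; get_mask() returns the PredictedMask for a label.
result = prompter(target_image)
print(result.get_mask(0))
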