Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add visual prompting wrappers #181

Merged
merged 35 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4f919b7
Add initial version of SAM wrappers
sovrasov Jun 7, 2024
c2a0361
Add initial implementation of ZSL
sovrasov Jun 8, 2024
39d6622
Update type annotation
sovrasov Jun 11, 2024
8612ca9
Fix isort
sovrasov Jun 11, 2024
fd74455
Fix imports
sovrasov Jun 11, 2024
8059565
Merge remote-tracking branch 'origin/master' into vis_prompt
sovrasov Jun 11, 2024
bf972cd
Fix labels concat on learn
sovrasov Jun 12, 2024
664c921
Merge remote-tracking branch 'origin/master' into vis_prompt
sovrasov Jun 12, 2024
2eeb73a
Merge remote-tracking branch 'origin/master' into vis_prompt
sovrasov Jun 12, 2024
1686503
Update image encoder
sovrasov Jun 13, 2024
57c8c94
Move aux methods out of the class vpt scope
sovrasov Jun 13, 2024
50d2a95
Fix handling of reference features
sovrasov Jun 18, 2024
1c03735
Align is_cascade usage
sovrasov Jun 18, 2024
70af2ba
Update vpt public interfaces
sovrasov Jun 19, 2024
29146e8
Add some docs
sovrasov Jun 19, 2024
444632b
Update docs
sovrasov Jun 19, 2024
a3d1180
Fix black
sovrasov Jun 19, 2024
d8500fb
Add SAM to testdata
sovrasov Jun 20, 2024
b4b149d
Update result objects
sovrasov Jun 21, 2024
ba41c96
Add tests
sovrasov Jun 21, 2024
9f3b713
Update decoder postprocessing
sovrasov Jun 21, 2024
c1d9deb
Update SAM ref results
sovrasov Jun 21, 2024
2938bd2
Fix linters
sovrasov Jun 21, 2024
4844153
Skip SAM in cpp tests
sovrasov Jun 21, 2024
82dd778
Workaround unsupported type annotation
sovrasov Jun 21, 2024
5a9a497
Fix black
sovrasov Jun 21, 2024
a2ab4bc
Replace prompt->list
sovrasov Jun 21, 2024
162dfc9
Fix python tests
sovrasov Jun 21, 2024
b56ee04
Restore bool mask output from mask decoder
sovrasov Jun 24, 2024
5cfaca2
Improve usability of ZSL VPT result
sovrasov Jun 24, 2024
458f3a4
Add stubs for the future support of polygon prompts
sovrasov Jun 24, 2024
7c844f7
Add polygon prompts to ZSL
sovrasov Jun 26, 2024
a8a083a
Update tests
sovrasov Jun 26, 2024
716ee65
Fix black
sovrasov Jun 26, 2024
d6c9ad0
Merge branch 'master' into vis_prompt
sovrasov Jun 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions model_api/python/model_api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .nanodet import NanoDet, NanoDetPlus
from .open_pose import OpenPose
from .retinaface import RetinaFace, RetinaFacePyTorch
from .sam_models import SAMDecoder, SAMImageEncoder
from .segmentation import SalientObjectDetectionModel, SegmentationModel
from .ssd import SSD
from .ultra_lightweight_face_detection import UltraLightweightFaceDetection
Expand All @@ -50,11 +51,15 @@
ImageResultWithSoftPrediction,
InstanceSegmentationResult,
OutputTransform,
PredictedMask,
SegmentedObject,
SegmentedObjectWithRects,
VisualPromptingResult,
ZSLVisualPromptingResult,
add_rotated_rects,
get_contours,
)
from .visual_prompting import Prompt, SAMLearnableVisualPrompter, SAMVisualPrompter
from .yolo import YOLO, YOLOF, YOLOX, YoloV3ONNX, YoloV4, YOLOv5, YOLOv8

classification_models = [
Expand Down Expand Up @@ -97,6 +102,11 @@
"ImageModel",
"ImageResultWithSoftPrediction",
"InstanceSegmentationResult",
"VisualPromptingResult",
"ZSLVisualPromptingResult",
"PredictedMask",
"SAMVisualPrompter",
"SAMLearnableVisualPrompter",
"MaskRCNNModel",
"Model",
"MonoDepthModel",
Expand All @@ -120,7 +130,10 @@
"YOLOv8",
"YOLOF",
"YOLOX",
"SAMDecoder",
"SAMImageEncoder",
"ClassificationResult",
"Prompt",
"Detection",
"DetectionResult",
"DetectionWithLandmarks",
Expand Down
204 changes: 204 additions & 0 deletions model_api/python/model_api/models/sam_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
Copyright (C) 2024 Intel Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from __future__ import annotations # TODO: remove when Python3.9 support is dropped

from copy import deepcopy
from typing import Any, Dict

import numpy as np
from model_api.adapters.inference_adapter import InferenceAdapter
from model_api.models.types import BooleanValue, NumericalValue

from .image_model import ImageModel
from .segmentation import SegmentationModel


class SAMImageEncoder(ImageModel):
"""Image Encoder for SAM: https://arxiv.org/abs/2304.02643"""

__model__ = "sam_image_encoder"

def __init__(
self,
inference_adapter: InferenceAdapter,
configuration: Dict[str, Any] = dict(),
preload: bool = False,
):
super().__init__(inference_adapter, configuration, preload)
self.output_name: str = list(self.outputs.keys())[0]

@classmethod
def parameters(cls) -> dict[str, Any]:
parameters = super().parameters()
parameters.update(
{
"image_size": NumericalValue(
value_type=int, default_value=1024, min=0, max=2048
),
},
)
return parameters

def preprocess(
self, inputs: np.ndarray
) -> tuple[dict[str, np.ndarray], dict[str, Any]]:
"""Update meta for image encoder."""
dict_inputs, meta = super().preprocess(inputs)
meta["resize_type"] = self.resize_type
return dict_inputs, meta

def postprocess(
self, outputs: dict[str, np.ndarray], meta: dict[str, Any]
) -> np.ndarray:
return outputs[self.output_name]


class SAMDecoder(SegmentationModel):
"""Image Decoder for SAM: https://arxiv.org/abs/2304.02643"""

__model__ = "sam_decoder"

def __init__(
self,
model_adapter: InferenceAdapter,
configuration: Dict[str, Any] = dict(),
preload: bool = False,
):
super().__init__(model_adapter, configuration, preload)

self.mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
self.has_mask_input = np.zeros((1, 1), dtype=np.float32)

@classmethod
def parameters(cls) -> dict[str, Any]:
parameters = super().parameters()
parameters.update(
{
"image_size": NumericalValue(
value_type=int, default_value=1024, min=0, max=2048
)
}
)
parameters.update(
{
"mask_threshold": NumericalValue(
value_type=float, default_value=0.0, min=0, max=1
)
}
)
parameters.update(
{
"embed_dim": NumericalValue(
value_type=int, default_value=256, min=0, max=512
)
}
)
parameters.update({"embedded_processing": BooleanValue(default_value=True)})
return parameters

def _get_outputs(self) -> str:
return "upscaled_masks"

def preprocess(self, inputs: dict[str, Any]) -> list[dict[str, Any]]:
"""Preprocess prompts."""
processed_prompts: list[dict[str, Any]] = []
for prompt_name in ["bboxes", "points"]:
if (prompts := inputs.get(prompt_name, None)) is None or (
labels := inputs["labels"].get(prompt_name, None)
) is None:
continue

for prompt, label in zip(prompts, labels):
if prompt_name == "bboxes":
point_coords = self.apply_coords(
prompt.reshape(-1, 2, 2), inputs["orig_size"]
)
point_labels = np.array([2, 3], dtype=np.float32).reshape(-1, 2)
else:
point_coords = self.apply_coords(
prompt.reshape(-1, 1, 2), inputs["orig_size"]
)
point_labels = np.array([1], dtype=np.float32).reshape(-1, 1)

processed_prompts.append(
{
"point_coords": point_coords,
"point_labels": point_labels,
"mask_input": self.mask_input,
"has_mask_input": self.has_mask_input,
"orig_size": np.array(
inputs["orig_size"], dtype=np.int64
).reshape(-1, 2),
"label": label,
},
)
return processed_prompts

def apply_coords(
self, coords: np.ndarray, orig_size: np.ndarray | list[int] | tuple[int, int]
) -> np.ndarray:
"""Process coords according to preprocessed image size using image meta."""
old_h, old_w = orig_size
new_h, new_w = self._get_preprocess_shape(old_h, old_w, self.image_size)
coords = deepcopy(coords).astype(np.float32)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords

def _get_preprocess_shape(
self, old_h: int, old_w: int, image_size: int
) -> tuple[int, int]:
"""Compute the output size given input size and target image size."""
scale = image_size / max(old_h, old_w)
new_h, new_w = old_h * scale, old_w * scale
new_w = int(new_w + 0.5)
new_h = int(new_h + 0.5)
return (new_h, new_w)

def _check_io_number(
self, number_of_inputs: int | tuple[int], number_of_outputs: int | tuple[int]
) -> None:
pass

def _get_inputs(self) -> tuple[list[str], list[str]]:
"""Get input layer name and shape."""
image_blob_names = list(self.inputs.keys())
image_info_blob_names: list = []
return image_blob_names, image_info_blob_names

def postprocess(
self, outputs: dict[str, np.ndarray], meta: dict[str, Any]
) -> dict[str, np.ndarray]:
"""Postprocess to convert soft prediction to hard prediction.

Args:
outputs (dict[str, np.ndarray]): The output of the model.
meta (dict[str, Any]): Contain label and original size.

Returns:
(dict[str, np.ndarray]): The postprocessed output of the model.
"""
probability = np.clip(outputs["scores"], 0.0, 1.0)
hard_prediction = (
outputs[self.output_blob_name].squeeze(0) > self.mask_threshold
)
soft_prediction = hard_prediction * probability.reshape(-1, 1, 1)

outputs["hard_prediction"] = hard_prediction
outputs["soft_prediction"] = soft_prediction

return outputs
59 changes: 59 additions & 0 deletions model_api/python/model_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,65 @@ def __str__(self):
return f"{obj_str}; {filled}; [{','.join(str(i) for i in self.feature_vector.shape)}]"


class VisualPromptingResult(NamedTuple):
upscaled_masks: List[np.ndarray] | None = None
low_res_masks: List[np.ndarray] | None = None
iou_predictions: List[np.ndarray] | None = None
scores: List[np.ndarray] | None = None
labels: List[np.ndarray] | None = None
hard_predictions: List[np.ndarray] | None = None
soft_predictions: List[np.ndarray] | None = None

def _compute_min_max(self, tensor: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
return tensor.min(), tensor.max()

def __str__(self) -> str:
assert self.hard_predictions is not None
assert self.upscaled_masks is not None
upscaled_masks_min, upscaled_masks_max = self._compute_min_max(
self.upscaled_masks[0]
)

return (
f"upscaled_masks min:{upscaled_masks_min:.3f} max:{upscaled_masks_max:.3f};"
f"hard_predictions shape:{self.hard_predictions[0].shape};"
)


class PredictedMask(NamedTuple):
mask: list[np.ndarray]
points: list[np.ndarray] | np.ndarray

def __str__(self) -> str:
obj_str = ""
obj_str += f"mask sum: {np.sum(sum(self.mask))}; "

if isinstance(self.points, list):
for point in self.points:
obj_str += "["
obj_str += ", ".join(str(round(c, 2)) for c in point)
obj_str += "] "
else:
for i in range(self.points.shape[0]):
point = self.points[i]
obj_str += "["
obj_str += ", ".join(str(round(c, 2)) for c in point)
obj_str += "] "

return obj_str.strip()


class ZSLVisualPromptingResult(NamedTuple):
data: dict[int, PredictedMask]

def __str__(self) -> str:
return ", ".join(str(self.data[k]) for k in self.data)
sovrasov marked this conversation as resolved.
Show resolved Hide resolved

def get_mask(self, label: int) -> PredictedMask:
"""Returns a mask belonging to a given label"""
return self.data[label]


def add_rotated_rects(segmented_objects):
objects_with_rects = []
for segmented_object in segmented_objects:
Expand Down
Loading
Loading