Commit

Add some docs
sovrasov committed Jun 19, 2024
1 parent 70af2ba commit 29146e8
Showing 1 changed file with 84 additions and 25 deletions.
109 changes: 84 additions & 25 deletions model_api/python/model_api/models/visual_prompting.py
@@ -31,29 +31,55 @@ class Prompt(NamedTuple):
     labels: list[np.ndarray | int] | np.ndarray


+class PredictedMask(NamedTuple):
+    mask: list[np.ndarray]
+    points: list[np.ndarray] | np.ndarray
+
+
 class SAMVisualPrompter:
+    """
+    A wrapper that implements the SAM Visual Prompter.
+    Segmentation results can be obtained by calling the infer() method
+    with the corresponding parameters.
+    """
+
     def __init__(
         self,
         encoder_model: SAMImageEncoder,
         decoder_model: SAMDecoder,
     ):
-        self.encoder_model = encoder_model
-        self.decoder_model = decoder_model
+        self.encoder = encoder_model
+        self.decoder = decoder_model

     def infer(
         self,
         image: np.ndarray,
         boxes: Prompt | None = None,
         points: Prompt | None = None,
     ) -> VisualPromptingResult:
+        """
+        Obtains segmentation masks for the given prompts.
+
+        Args:
+            image (np.ndarray): HWC-shaped image
+            boxes (Prompt | None, optional): Box prompts to run inference with. Defaults to None.
+            points (Prompt | None, optional): Point prompts to run inference with. Defaults to None.
+
+        Raises:
+            RuntimeError: If neither boxes nor points prompts are provided.
+
+        Returns:
+            VisualPromptingResult: The resulting segmentation.
+        """
         if boxes is None and points is None:
             raise RuntimeError("boxes or points prompts are required for inference")

         outputs: list[dict[str, Any]] = []

-        processed_image, meta = self.encoder_model.preprocess(image)
-        image_embeddings = self.encoder_model.infer_sync(processed_image)
-        processed_prompts = self.decoder_model.preprocess(
+        processed_image, meta = self.encoder.preprocess(image)
+        image_embeddings = self.encoder.infer_sync(processed_image)
+        processed_prompts = self.decoder.preprocess(
             {
                 "bboxes": boxes.data if boxes else None,
                 "points": points.data if points else None,
@@ -67,10 +93,10 @@ def infer(
             label = prompt.pop("label")
             prompt.update(**image_embeddings)

-            prediction = self.decoder_model.infer_sync(prompt)
+            prediction = self.decoder.infer_sync(prompt)
             prediction["scores"] = prediction["iou_predictions"]
             prediction["labels"] = label
-            processed_prediction = self.decoder_model.postprocess(prediction, meta)
+            processed_prediction = self.decoder.postprocess(prediction, meta)
             outputs.append(processed_prediction)

         return VisualPromptingResult(
@@ -93,14 +119,19 @@ def __call__(


 class SAMLearnableVisualPrompter:
+    """
+    A wrapper that provides the ZSL Visual Prompting workflow.
+    To obtain segmentation results, one should first run learn() to obtain the reference features,
+    or use previously generated ones.
+    """
     def __init__(
         self,
         encoder_model: SAMImageEncoder,
         decoder_model: SAMDecoder,
         reference_features: VisualPromptingFeatures | None = None,
     ):
-        self.encoder_model = encoder_model
-        self.decoder_model = decoder_model
+        self.encoder = encoder_model
+        self.decoder = decoder_model

         if reference_features is not None:
             self._reference_features = reference_features.feature_vectors
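
Taken together with learn() and infer() below, the intended ZSL flow might look like the following sketch. It reuses the encoder and decoder from the earlier example; the images and the point prompt are placeholders, and learn() is assumed to accept box/point prompts the same way infer() does:

```python
prompter = SAMLearnableVisualPrompter(encoder, decoder)

# 1) learn reference features from prompts on a reference image
prompter.learn(ref_image, points=Prompt(data=np.array([[100, 100]]), labels=[1]))

# 2) segment a new image using the stored reference features
masks_per_label = prompter(new_image)  # dict[int, PredictedMask]
```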
@@ -116,15 +147,22 @@ def __init__(
         self._threshold: float = 0.0
         self._num_bg_points: int = 1
         self._default_threshold_target: float = 0.65
-        self._image_size: int = self.encoder_model.image_size
+        self._image_size: int = self.encoder.image_size
         self._downsizing: int = 64
         self._default_threshold_reference: float = 0.3

     def has_reference_features(self) -> bool:
+        """
+        Checks if reference features are stored in the object state.
+        """
         return self._reference_features is not None and self._used_indices is not None

     @property
     def reference_features(self) -> VisualPromptingFeatures:
+        """
+        Returns a copy of the stored reference features. An exception is thrown
+        if the features are not present in the internal object state.
+        """
         if self.has_reference_features():
             return VisualPromptingFeatures(np.copy(self._reference_features), np.copy(self._used_indices))

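
Since the constructor accepts a reference_features argument and this property returns a copy of the internal state, features learned once can be handed to a fresh instance. A small sketch, continuing the example above (persisting the arrays between processes is left to the caller):

```python
saved = prompter.reference_features  # copies of the internal arrays

# later, or in another process: reuse the features without re-running learn()
restored = SAMLearnableVisualPrompter(encoder, decoder, reference_features=saved)
assert restored.has_reference_features()
```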

Expand All @@ -140,8 +178,8 @@ def learn(
"""
Executes `learn` stage of SAM ZSL pipeline.
Reference features are updated according to newly arrived prompts. This method should not be run on different images
without resetting reference features. Consequent runs on the same image with preserving reference features make sense if new or refined prompts are passed.
Reference features are updated according to newly arrived prompts. Features corresponding to the same labels are overridden during
consequent learn() calls.
Args:
image (np.ndarray): HWC-shaped image
@@ -160,7 +198,7 @@
         if reset_features or not self.has_reference_features():
             self.reset_reference_info()

-        processed_prompts = self.decoder_model.preprocess(
+        processed_prompts = self.decoder.preprocess(
             {
                 "bboxes": boxes.data if boxes else None,
                 "points": points.data if points else None,
@@ -177,7 +215,7 @@
         original_shape = np.array(image.shape[:2])

         # forward image encoder
-        image_embeddings = self.encoder_model(image)
+        image_embeddings = self.encoder(image)
         processed_embedding = image_embeddings.squeeze().transpose(1, 2, 0)

         # get reference masks
@@ -207,7 +245,7 @@ def learn(
                     feats=processed_embedding,
                     masks=ref_mask,
                     threshold_mask=cur_default_threshold_reference,
-                    image_size=self.encoder_model.image_size,
+                    image_size=self.encoder.image_size,
                 )
                 cur_default_threshold_reference -= 0.05

@@ -223,14 +261,31 @@ def __call__(
         self,
         image: np.ndarray,
         reference_features: VisualPromptingFeatures | None = None,
-    ):
+    ) -> dict[int, PredictedMask]:
+        """A wrapper around the SAMLearnableVisualPrompter.infer() method."""
         return self.infer(image, reference_features)

     def infer(
         self,
         image: np.ndarray,
         reference_features: VisualPromptingFeatures | None = None,
-    ):
+    ) -> dict[int, PredictedMask]:
+        """
+        Obtains masks using previously prepared reference features.
+
+        Reference features can be obtained with SAMLearnableVisualPrompter.learn() and passed as an argument.
+        If the features are not passed, the instance's internal state is used as the source of the features.
+
+        Args:
+            image (np.ndarray): HWC-shaped image
+            reference_features (VisualPromptingFeatures | None, optional): Reference features object obtained during previous learn() calls.
+                If not passed, the object's internal state is used, which reflects the last learn() call. Defaults to None.
+
+        Returns:
+            dict[int, PredictedMask]: A mapping from label to the predicted mask. Each PredictedMask contains a list of binary masks and
+                a list of related prompts. Each binary mask corresponds to one prompt point. A class mask can be obtained by applying
+                the OR operation to all masks corresponding to one label.
+        """
         if reference_features is None:
             if self._reference_features is None:
                 raise RuntimeError(
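
As the Returns section above notes, the per-point masks of one label can be folded into a single class mask with a logical OR. A minimal sketch, continuing the earlier example:

```python
results = prompter.infer(image)  # dict[int, PredictedMask]
for label, pred in results.items():
    # union of all per-point binary masks of this label
    class_mask = np.logical_or.reduce(pred.mask)
```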
@@ -249,7 +304,7 @@ def infer(
         reference_feats, used_idx = reference_features

         original_shape = np.array(image.shape[:2])
-        image_embeddings = self.encoder_model(image)
+        image_embeddings = self.encoder(image)

         total_points_scores, total_bg_coords = _get_prompt_candidates(
             image_embeddings=image_embeddings,
@@ -285,7 +340,7 @@ def infer(
                 point_coords = np.concatenate(
                     (np.array([[x, y]]), bg_coords), axis=0, dtype=np.float32
                 )
-                point_coords = self.decoder_model.apply_coords(
+                point_coords = self.decoder.apply_coords(
                     point_coords, original_shape
                 )
                 point_labels = np.array([1] + [0] * len(bg_coords), dtype=np.float32)
@@ -302,19 +357,23 @@ def infer(
                 prediction.update({"scores": points_score[-1]})

                 predicted_masks[label].append(
-                    prediction[self.decoder_model.output_blob_name]
+                    prediction[self.decoder.output_blob_name]
                 )
                 used_points[label].append(points_score)

         # check overlapping area between different label masks
         _inspect_overlapping_areas(predicted_masks, used_points)

-        return (predicted_masks, used_points)
+        prediction = {}
+        for k in used_points:
+            prediction[k] = PredictedMask(predicted_masks[k], used_points[k])
+
+        return prediction

     def reset_reference_info(self) -> None:
         """Initialize reference information."""
         self._reference_features = np.zeros(
-            (0, 1, self.decoder_model.embed_dim), dtype=np.float32
+            (0, 1, self.decoder.embed_dim), dtype=np.float32
         )
         self._used_indices = np.array([], dtype=np.int64)

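A side note on the new return block in infer(): since it only repackages the two per-label dicts, it could equally be written as a dict comprehension, with identical behavior:

```python
prediction = {
    k: PredictedMask(predicted_masks[k], used_points[k]) for k in used_points
}
```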
@@ -383,7 +442,7 @@ def _predict_masks(

                 has_mask_input = self._has_mask_inputs[1]
                 y, x = np.nonzero(masks)
-                box_coords = self.decoder_model.apply_coords(
+                box_coords = self.decoder.apply_coords(
                     np.array(
                         [[x.min(), y.min()], [x.max(), y.max()]], dtype=np.float32
                     ),
@@ -402,13 +461,13 @@
                 )

             inputs.update({"mask_input": mask_input, "has_mask_input": has_mask_input})
-            prediction = self.decoder_model.infer_sync(inputs)
+            prediction = self.decoder.infer_sync(inputs)
             upscaled_masks, scores, logits = (
                 prediction["upscaled_masks"],
                 prediction["iou_predictions"],
                 prediction["low_res_masks"],
             )
-            masks = upscaled_masks > self.decoder_model.mask_threshold
+            masks = upscaled_masks > self.decoder.mask_threshold

         _, masks = _decide_masks(masks, logits, scores)
         return {"upscaled_masks": masks}
