Skip to content

Commit

Permalink
Add Classification And Detection Scene (#259)
Browse files Browse the repository at this point in the history
* Add classification scene

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add detection scene

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add tests

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add title to overlay

Signed-off-by: Ashwin Vaidya <[email protected]>

* Pass name and confidence separately

Signed-off-by: Ashwin Vaidya <[email protected]>

* Fix tests

Signed-off-by: Ashwin Vaidya <[email protected]>

---------

Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 authored Jan 23, 2025
1 parent 4d4bf20 commit 63056fc
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 9 deletions.
5 changes: 5 additions & 0 deletions src/python/model_api/models/result/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def label_names(self, value):

@property
def saliency_map(self):
"""Saliency map for XAI.
Returns:
np.ndarray: Saliency map in dim of (B, N_CLASSES, H, W).
"""
return self._saliency_map

@saliency_map.setter
Expand Down
4 changes: 4 additions & 0 deletions src/python/model_api/visualizer/layout/hstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import PIL

from model_api.visualizer.primitive import Overlay

from .layout import Layout

if TYPE_CHECKING:
Expand All @@ -31,6 +33,8 @@ def _compute_on_primitive(self, primitive: Type[Primitive], image: PIL.Image, sc
images = []
for _primitive in scene.get_primitives(primitive):
image_ = _primitive.compute(image.copy())
if isinstance(_primitive, Overlay):
image_ = Overlay.overlay_labels(image=image_, labels=_primitive.label)
images.append(image_)
return self._stitch(*images)
return None
Expand Down
31 changes: 30 additions & 1 deletion src/python/model_api/visualizer/primitive/overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from __future__ import annotations

from typing import Union

import numpy as np
import PIL
from PIL import ImageFont

from .primitive import Primitive

Expand All @@ -18,11 +21,18 @@ class Overlay(Primitive):
Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
label (str | None): Optional label name to overlay.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
def __init__(
self,
image: PIL.Image | np.ndarray,
opacity: float = 0.4,
label: Union[str, None] = None,
) -> None:
self.image = self._to_pil(image)
self.label = label
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
Expand All @@ -33,3 +43,22 @@ def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
def compute(self, image: PIL.Image) -> PIL.Image:
image_ = self.image.resize(image.size)
return PIL.Image.blend(image, image_, self.opacity)

@classmethod
def overlay_labels(cls, image: PIL.Image, labels: Union[list[str], str, None] = None) -> PIL.Image:
"""Draw labels at the bottom center of the image.
This is handy when you want to add a label to the image.
"""
if labels is not None:
labels = [labels] if isinstance(labels, str) else labels
font = ImageFont.load_default(size=18)
buffer_y = 5
dummy_image = PIL.Image.new("RGB", (1, 1))
draw = PIL.ImageDraw.Draw(dummy_image)
textbox = draw.textbbox((0, 0), ", ".join(labels), font=font)
image_ = PIL.Image.new("RGB", (textbox[2] - textbox[0], textbox[3] + buffer_y - textbox[1]), "white")
draw = PIL.ImageDraw.Draw(image_)
draw.text((0, 0), ", ".join(labels), font=font, fill="black")
image.paste(image_, (image.width // 2 - image_.width // 2, image.height - image_.height - buffer_y))
return image
28 changes: 24 additions & 4 deletions src/python/model_api/visualizer/scene/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from typing import Union

import cv2
from PIL import Image

from model_api.models.result import ClassificationResult
from model_api.visualizer.layout import Flatten, Layout
from model_api.visualizer.primitive import Overlay
from model_api.visualizer.primitive import Label, Overlay

from .scene import Scene

Expand All @@ -18,9 +19,28 @@ class ClassificationScene(Scene):
"""Classification Scene."""

def __init__(self, image: Image, result: ClassificationResult, layout: Union[Layout, None] = None) -> None:
self.image = image
self.result = result
super().__init__(
base=image,
label=self._get_labels(result),
overlay=self._get_overlays(result),
layout=layout,
)

def _get_labels(self, result: ClassificationResult) -> list[Label]:
labels = []
if result.top_labels is not None and len(result.top_labels) > 0:
for label in result.top_labels:
if label.name is not None:
labels.append(Label(label=label.name, score=label.confidence))
return labels

def _get_overlays(self, result: ClassificationResult) -> list[Overlay]:
overlays = []
if result.saliency_map is not None and result.saliency_map.size > 0:
saliency_map = cv2.cvtColor(result.saliency_map, cv2.COLOR_BGR2RGB)
overlays.append(Overlay(saliency_map))
return overlays

@property
def default_layout(self) -> Layout:
return Flatten(Overlay)
return Flatten(Overlay, Label)
34 changes: 31 additions & 3 deletions src/python/model_api/visualizer/scene/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

from typing import Union

import cv2
from PIL import Image

from model_api.models.result import DetectionResult
from model_api.visualizer.layout import Layout
from model_api.visualizer.layout import Flatten, HStack, Layout
from model_api.visualizer.primitive import BoundingBox, Label, Overlay

from .scene import Scene

Expand All @@ -17,5 +19,31 @@ class DetectionScene(Scene):
"""Detection Scene."""

def __init__(self, image: Image, result: DetectionResult, layout: Union[Layout, None] = None) -> None:
self.image = image
self.result = result
super().__init__(
base=image,
bounding_box=self._get_bounding_boxes(result),
overlay=self._get_overlays(result),
layout=layout,
)

def _get_overlays(self, result: DetectionResult) -> list[Overlay]:
overlays = []
# Add only the overlays that are predicted
label_index_mapping = dict(zip(result.labels, result.label_names))
for label_index, label_name in label_index_mapping.items():
# Index 0 as it assumes only one batch
saliency_map = cv2.applyColorMap(result.saliency_map[0][label_index], cv2.COLORMAP_JET)
overlays.append(Overlay(saliency_map, label=label_name.title()))
return overlays

def _get_bounding_boxes(self, result: DetectionResult) -> list[BoundingBox]:
bounding_boxes = []
for score, label_name, bbox in zip(result.scores, result.label_names, result.bboxes):
x1, y1, x2, y2 = bbox
label = f"{label_name} ({score:.2f})"
bounding_boxes.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label))
return bounding_boxes

@property
def default_layout(self) -> Layout:
return HStack(Flatten(BoundingBox, Label), Overlay)
33 changes: 32 additions & 1 deletion tests/python/unit/visualizer/test_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np
from PIL import Image

from model_api.models.result import AnomalyResult
from model_api.models.result import AnomalyResult, ClassificationResult, DetectionResult
from model_api.models.result.classification import Label
from model_api.visualizer import Visualizer


Expand All @@ -32,3 +33,33 @@ def test_anomaly_scene(mock_image: Image, tmpdir: Path):
visualizer = Visualizer()
visualizer.save(mock_image, anomaly_result, tmpdir / "anomaly_scene.jpg")
assert Path(tmpdir / "anomaly_scene.jpg").exists()


def test_classification_scene(mock_image: Image, tmpdir: Path):
"""Test if the classification scene is created."""
classification_result = ClassificationResult(
top_labels=[
Label(name="cat", confidence=0.95),
Label(name="dog", confidence=0.90),
],
saliency_map=np.ones(mock_image.size, dtype=np.uint8),
)
visualizer = Visualizer()
visualizer.save(
mock_image, classification_result, tmpdir / "classification_scene.jpg"
)
assert Path(tmpdir / "classification_scene.jpg").exists()


def test_detection_scene(mock_image: Image, tmpdir: Path):
"""Test if the detection scene is created."""
detection_result = DetectionResult(
bboxes=np.array([[0, 0, 128, 128], [32, 32, 96, 96]]),
labels=np.array([0, 1]),
label_names=["person", "car"],
scores=np.array([0.85, 0.75]),
saliency_map=(np.ones((1, 2, 6, 8)) * 255).astype(np.uint8),
)
visualizer = Visualizer()
visualizer.save(mock_image, detection_result, tmpdir / "detection_scene.jpg")
assert Path(tmpdir / "detection_scene.jpg").exists()

0 comments on commit 63056fc

Please sign in to comment.