Skip to content

Commit

Permalink
Add Polygon Primitive (#254)
Browse files Browse the repository at this point in the history
* Refactor primitives

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add polygon primitive

Signed-off-by: Ashwin Vaidya <[email protected]>

* Add docstrings

Signed-off-by: Ashwin Vaidya <[email protected]>

---------

Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 authored Jan 16, 2025
1 parent 6e7d108 commit ae3241a
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/python/model_api/visualizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# SPDX-License-Identifier: Apache-2.0

from .layout import Flatten, HStack, Layout
from .primitive import BoundingBox, Overlay
from .primitive import BoundingBox, Overlay, Polygon
from .scene import Scene
from .visualizer import Visualizer

__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"]
11 changes: 11 additions & 0 deletions src/python/model_api/visualizer/primitive/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Primitive classes."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .bounding_box import BoundingBox
from .overlay import Overlay
from .polygon import Polygon
from .primitive import Primitive

__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"]
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
"""Base class for primitives."""
"""Bounding box primitive."""

# Copyright (C) 2024 Intel Corporation
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod

import numpy as np
import PIL
from PIL import Image, ImageDraw


class Primitive(ABC):
"""Primitive class."""

@abstractmethod
def compute(self, image: Image) -> Image:
pass
from .primitive import Primitive


class BoundingBox(Primitive):
Expand Down Expand Up @@ -71,27 +61,3 @@ def compute(self, image: Image) -> Image:
draw.text((0, 0), self.label, fill="white")
image.paste(label_image, (self.x1, self.y1))
return image


class Overlay(Primitive):
"""Overlay primitive.
Useful for XAI and Anomaly Maps.
Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_pil(image)
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
if isinstance(image, np.ndarray):
return PIL.Image.fromarray(image)
return image

def compute(self, image: PIL.Image) -> PIL.Image:
_image = self.image.resize(image.size)
return PIL.Image.blend(image, _image, self.opacity)
35 changes: 35 additions & 0 deletions src/python/model_api/visualizer/primitive/overlay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Overlay primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import numpy as np
import PIL

from .primitive import Primitive


class Overlay(Primitive):
"""Overlay primitive.
Useful for XAI and Anomaly Maps.
Args:
image (PIL.Image | np.ndarray): Image to be overlaid.
opacity (float): Opacity of the overlay.
"""

def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None:
self.image = self._to_pil(image)
self.opacity = opacity

def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image:
if isinstance(image, np.ndarray):
return PIL.Image.fromarray(image)
return image

def compute(self, image: PIL.Image) -> PIL.Image:
image_ = self.image.resize(image.size)
return PIL.Image.blend(image, image_, self.opacity)
93 changes: 93 additions & 0 deletions src/python/model_api/visualizer/primitive/polygon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Polygon primitive."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import TYPE_CHECKING

import cv2
from PIL import Image, ImageDraw

from .primitive import Primitive

if TYPE_CHECKING:
import numpy as np


class Polygon(Primitive):
"""Polygon primitive.
Args:
points: List of points.
mask: Mask to draw the polygon.
color: Color of the polygon.
Examples:
>>> polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
>>> polygon = Polygon(mask=mask, color="red")
>>> polygon.compute(image).save("polygon.jpg")
>>> polygon = Polygon(mask=mask, color="red")
>>> polygon.compute(image).save("polygon.jpg")
"""

def __init__(
self,
points: list[tuple[int, int]] | None = None,
mask: np.ndarray | None = None,
color: str | tuple[int, int, int] = "blue",
) -> None:
self.points = self._get_points(points, mask)
self.color = color

def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]:
"""Get points from either points or mask.
Note:
Either points or mask should be provided.
Args:
points: List of points.
mask: Mask to draw the polygon.
Returns:
List of points.
"""
if points is not None and mask is not None:
msg = "Either points or mask should be provided, not both."
raise ValueError(msg)
if points is not None:
points_ = points
elif mask is not None:
points_ = self._get_points_from_mask(mask)
else:
msg = "Either points or mask should be provided."
raise ValueError(msg)
return points_

def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]:
"""Get points from mask.
Args:
mask: Mask to draw the polygon.
Returns:
List of points.
"""
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
points_ = contours[0].squeeze().tolist()
return [tuple(point) for point in points_]

def compute(self, image: Image) -> Image:
"""Compute the polygon.
Args:
image: Image to draw the polygon on.
Returns:
Image with the polygon drawn on it.
"""
draw = ImageDraw.Draw(image)
draw.polygon(self.points, fill=self.color)
return image
20 changes: 20 additions & 0 deletions src/python/model_api/visualizer/primitive/primitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Base class for primitives."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import PIL


class Primitive(ABC):
"""Base class for primitives."""

@abstractmethod
def compute(self, image: PIL.Image) -> PIL.Image:
"""Compute the primitive."""
21 changes: 20 additions & 1 deletion tests/python/unit/visualizer/test_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import PIL
from PIL import ImageDraw

from model_api.visualizer import BoundingBox, Overlay
from model_api.visualizer import BoundingBox, Overlay, Polygon


def test_overlay(mock_image: PIL.Image):
Expand All @@ -32,3 +32,22 @@ def test_bounding_box(mock_image: PIL.Image):
draw.rectangle((10, 10, 100, 100), outline="blue", width=2)
bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100)
assert bounding_box.compute(mock_image) == expected_image


def test_polygon(mock_image: PIL.Image):
"""Test if the polygon is created correctly."""
# Test from points
expected_image = mock_image.copy()
draw = ImageDraw.Draw(expected_image)
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red")
assert polygon.compute(mock_image) == expected_image

# Test from mask
mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:100, 10:100] = 255
expected_image = mock_image.copy()
draw = ImageDraw.Draw(expected_image)
draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red")
polygon = Polygon(mask=mask, color="red")
assert polygon.compute(mock_image) == expected_image

0 comments on commit ae3241a

Please sign in to comment.