From ae3241a78a1dc4f7ebfc075922b6deaa8602b202 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Thu, 16 Jan 2025 14:04:47 +0100 Subject: [PATCH] Add Polygon Primitive (#254) * Refactor primitives Signed-off-by: Ashwin Vaidya * Add polygon primitive Signed-off-by: Ashwin Vaidya * Add docstrings Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya --- src/python/model_api/visualizer/__init__.py | 4 +- .../visualizer/primitive/__init__.py | 11 +++ .../bounding_box.py} | 40 +------- .../model_api/visualizer/primitive/overlay.py | 35 +++++++ .../model_api/visualizer/primitive/polygon.py | 93 +++++++++++++++++++ .../visualizer/primitive/primitive.py | 20 ++++ .../python/unit/visualizer/test_primitive.py | 21 ++++- 7 files changed, 184 insertions(+), 40 deletions(-) create mode 100644 src/python/model_api/visualizer/primitive/__init__.py rename src/python/model_api/visualizer/{primitive.py => primitive/bounding_box.py} (66%) create mode 100644 src/python/model_api/visualizer/primitive/overlay.py create mode 100644 src/python/model_api/visualizer/primitive/polygon.py create mode 100644 src/python/model_api/visualizer/primitive/primitive.py diff --git a/src/python/model_api/visualizer/__init__.py b/src/python/model_api/visualizer/__init__.py index 2c8a0062..4d7718e5 100644 --- a/src/python/model_api/visualizer/__init__.py +++ b/src/python/model_api/visualizer/__init__.py @@ -4,8 +4,8 @@ # SPDX-License-Identifier: Apache-2.0 from .layout import Flatten, HStack, Layout -from .primitive import BoundingBox, Overlay +from .primitive import BoundingBox, Overlay, Polygon from .scene import Scene from .visualizer import Visualizer -__all__ = ["BoundingBox", "Overlay", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] +__all__ = ["BoundingBox", "Overlay", "Polygon", "Scene", "Visualizer", "Layout", "Flatten", "HStack"] diff --git a/src/python/model_api/visualizer/primitive/__init__.py b/src/python/model_api/visualizer/primitive/__init__.py new file mode 100644 index 00000000..51837c59 --- /dev/null +++ b/src/python/model_api/visualizer/primitive/__init__.py @@ -0,0 +1,11 @@ +"""Primitive classes.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .bounding_box import BoundingBox +from .overlay import Overlay +from .polygon import Polygon +from .primitive import Primitive + +__all__ = ["Primitive", "BoundingBox", "Overlay", "Polygon"] diff --git a/src/python/model_api/visualizer/primitive.py b/src/python/model_api/visualizer/primitive/bounding_box.py similarity index 66% rename from src/python/model_api/visualizer/primitive.py rename to src/python/model_api/visualizer/primitive/bounding_box.py index 8a67c719..f9dcd534 100644 --- a/src/python/model_api/visualizer/primitive.py +++ b/src/python/model_api/visualizer/primitive/bounding_box.py @@ -1,23 +1,13 @@ -"""Base class for primitives.""" +"""Bounding box primitive.""" -# Copyright (C) 2024 Intel Corporation +# Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from abc import ABC, abstractmethod - -import numpy as np -import PIL from PIL import Image, ImageDraw - -class Primitive(ABC): - """Primitive class.""" - - @abstractmethod - def compute(self, image: Image) -> Image: - pass +from .primitive import Primitive class BoundingBox(Primitive): @@ -71,27 +61,3 @@ def compute(self, image: Image) -> Image: draw.text((0, 0), self.label, fill="white") image.paste(label_image, (self.x1, self.y1)) return image - - -class Overlay(Primitive): - """Overlay primitive. - - Useful for XAI and Anomaly Maps. - - Args: - image (PIL.Image | np.ndarray): Image to be overlaid. - opacity (float): Opacity of the overlay. - """ - - def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None: - self.image = self._to_pil(image) - self.opacity = opacity - - def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image: - if isinstance(image, np.ndarray): - return PIL.Image.fromarray(image) - return image - - def compute(self, image: PIL.Image) -> PIL.Image: - _image = self.image.resize(image.size) - return PIL.Image.blend(image, _image, self.opacity) diff --git a/src/python/model_api/visualizer/primitive/overlay.py b/src/python/model_api/visualizer/primitive/overlay.py new file mode 100644 index 00000000..f69bc958 --- /dev/null +++ b/src/python/model_api/visualizer/primitive/overlay.py @@ -0,0 +1,35 @@ +"""Overlay primitive.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import numpy as np +import PIL + +from .primitive import Primitive + + +class Overlay(Primitive): + """Overlay primitive. + + Useful for XAI and Anomaly Maps. + + Args: + image (PIL.Image | np.ndarray): Image to be overlaid. + opacity (float): Opacity of the overlay. + """ + + def __init__(self, image: PIL.Image | np.ndarray, opacity: float = 0.4) -> None: + self.image = self._to_pil(image) + self.opacity = opacity + + def _to_pil(self, image: PIL.Image | np.ndarray) -> PIL.Image: + if isinstance(image, np.ndarray): + return PIL.Image.fromarray(image) + return image + + def compute(self, image: PIL.Image) -> PIL.Image: + image_ = self.image.resize(image.size) + return PIL.Image.blend(image, image_, self.opacity) diff --git a/src/python/model_api/visualizer/primitive/polygon.py b/src/python/model_api/visualizer/primitive/polygon.py new file mode 100644 index 00000000..a6ccc99c --- /dev/null +++ b/src/python/model_api/visualizer/primitive/polygon.py @@ -0,0 +1,93 @@ +"""Polygon primitive.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import cv2 +from PIL import Image, ImageDraw + +from .primitive import Primitive + +if TYPE_CHECKING: + import numpy as np + + +class Polygon(Primitive): + """Polygon primitive. + + Args: + points: List of points. + mask: Mask to draw the polygon. + color: Color of the polygon. + + Examples: + >>> polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red") + >>> polygon = Polygon(mask=mask, color="red") + >>> polygon.compute(image).save("polygon.jpg") + + >>> polygon = Polygon(mask=mask, color="red") + >>> polygon.compute(image).save("polygon.jpg") + """ + + def __init__( + self, + points: list[tuple[int, int]] | None = None, + mask: np.ndarray | None = None, + color: str | tuple[int, int, int] = "blue", + ) -> None: + self.points = self._get_points(points, mask) + self.color = color + + def _get_points(self, points: list[tuple[int, int]] | None, mask: np.ndarray | None) -> list[tuple[int, int]]: + """Get points from either points or mask. + Note: + Either points or mask should be provided. + + Args: + points: List of points. + mask: Mask to draw the polygon. + + Returns: + List of points. + """ + if points is not None and mask is not None: + msg = "Either points or mask should be provided, not both." + raise ValueError(msg) + if points is not None: + points_ = points + elif mask is not None: + points_ = self._get_points_from_mask(mask) + else: + msg = "Either points or mask should be provided." + raise ValueError(msg) + return points_ + + def _get_points_from_mask(self, mask: np.ndarray) -> list[tuple[int, int]]: + """Get points from mask. + + Args: + mask: Mask to draw the polygon. + + Returns: + List of points. + """ + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + points_ = contours[0].squeeze().tolist() + return [tuple(point) for point in points_] + + def compute(self, image: Image) -> Image: + """Compute the polygon. + + Args: + image: Image to draw the polygon on. + + Returns: + Image with the polygon drawn on it. + """ + draw = ImageDraw.Draw(image) + draw.polygon(self.points, fill=self.color) + return image diff --git a/src/python/model_api/visualizer/primitive/primitive.py b/src/python/model_api/visualizer/primitive/primitive.py new file mode 100644 index 00000000..bda0967c --- /dev/null +++ b/src/python/model_api/visualizer/primitive/primitive.py @@ -0,0 +1,20 @@ +"""Base class for primitives.""" + +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import PIL + + +class Primitive(ABC): + """Base class for primitives.""" + + @abstractmethod + def compute(self, image: PIL.Image) -> PIL.Image: + """Compute the primitive.""" diff --git a/tests/python/unit/visualizer/test_primitive.py b/tests/python/unit/visualizer/test_primitive.py index d9a84624..2f8a540b 100644 --- a/tests/python/unit/visualizer/test_primitive.py +++ b/tests/python/unit/visualizer/test_primitive.py @@ -7,7 +7,7 @@ import PIL from PIL import ImageDraw -from model_api.visualizer import BoundingBox, Overlay +from model_api.visualizer import BoundingBox, Overlay, Polygon def test_overlay(mock_image: PIL.Image): @@ -32,3 +32,22 @@ def test_bounding_box(mock_image: PIL.Image): draw.rectangle((10, 10, 100, 100), outline="blue", width=2) bounding_box = BoundingBox(x1=10, y1=10, x2=100, y2=100) assert bounding_box.compute(mock_image) == expected_image + + +def test_polygon(mock_image: PIL.Image): + """Test if the polygon is created correctly.""" + # Test from points + expected_image = mock_image.copy() + draw = ImageDraw.Draw(expected_image) + draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red") + polygon = Polygon(points=[(10, 10), (100, 10), (100, 100), (10, 100)], color="red") + assert polygon.compute(mock_image) == expected_image + + # Test from mask + mask = np.zeros((100, 100), dtype=np.uint8) + mask[10:100, 10:100] = 255 + expected_image = mock_image.copy() + draw = ImageDraw.Draw(expected_image) + draw.polygon([(10, 10), (100, 10), (100, 100), (10, 100)], fill="red") + polygon = Polygon(mask=mask, color="red") + assert polygon.compute(mock_image) == expected_image