From 4c835ed48711aeb4aec6350bead935a6279f0262 Mon Sep 17 00:00:00 2001 From: Peizhao Zhang Date: Fri, 8 Sep 2023 21:33:34 -0700 Subject: [PATCH] add utils for visualization. Summary: add utils for visualization. * able to draw images with labels as a grid, preserving the resolutions of the input images. Reviewed By: zechenghe Differential Revision: D48223608 --- mobile_cv/common/misc/visualize_utils.py | 247 ++++++++++++++++++ .../common/tests/test_misc_visualize_utils.py | 164 ++++++++++++ 2 files changed, 411 insertions(+) create mode 100644 mobile_cv/common/misc/visualize_utils.py create mode 100644 mobile_cv/common/tests/test_misc_visualize_utils.py diff --git a/mobile_cv/common/misc/visualize_utils.py b/mobile_cv/common/misc/visualize_utils.py new file mode 100644 index 00000000..ad33a6f2 --- /dev/null +++ b/mobile_cv/common/misc/visualize_utils.py @@ -0,0 +1,247 @@ +import math +import os +from typing import Any, Dict, List, Optional + +from mobile_cv.common.misc.oss_utils import fb_overwritable + +from PIL import Image, ImageDraw, ImageFont + + +@fb_overwritable() +def get_font_path() -> Optional[str]: + return None + + +def save_as_image_grids( + rows: List[Dict[str, Any]], + output_dir: str, + path_manager, + max_rows_per_image: Optional[int] = None, + grid_padding_rows: int = 15, + grid_padding_cols: int = 10, + font_size: int = 10, +) -> List[str]: + """ + Draw and save image grids, preserve the image sizes. + rows: List of dicts that represent each row in the grid. Each dict must have + the keys "images", and optional keys "row_title", "titles", and "labels": + * images (List[str]): List of image paths for the row, + * row_title (Optional[str]): The title of the row + * titles (Optional[List[str]]): List of titles for each image in the row, + drawing on top of each image + * labels (Optional[List[str]]): List of labels for each image in the row, + drawing under each image + output_dir: output folder of the saved images, image will be saved with file + names "grid_{start_row_idx}.png" + path_manager: path manager for io + max_rows_per_image: maximum number of rows per image, multiple images could + be generated. + grid_padding_rows: Number of pixels between each row in the grid + grid_padding_cols: Number of pixels between each column in the grid + font_size: Size of font + """ + if not path_manager.exists(output_dir): + path_manager.mkdirs(output_dir) + + if max_rows_per_image is None: + max_rows_per_image = len(rows) + + ret = [] + start_row = 0 + while start_row < len(rows): + end_row = min(len(rows), (start_row + max_rows_per_image)) + grid_img = draw_image_grid_by_rows( + path_manager, + rows[start_row:end_row], + grid_padding_rows=grid_padding_rows, + grid_padding_cols=grid_padding_cols, + font_size=font_size, + ) + + output_filepath = os.path.join(output_dir, f"grid_{start_row:05d}.png") + with path_manager.open(output_filepath, "wb") as fp: + grid_img.save(fp) + start_row = end_row + ret.append(output_filepath) + + return ret + + +def draw_image_grid_by_rows( + path_manager, + rows: List[Dict[str, Any]], + grid_padding_rows: int = 15, + grid_padding_cols: int = 10, + font_size: int = 10, +) -> Image.Image: + """ + Draw a grid of images into a single image, preserve the image sizes. + rows: List of dicts that represent each row in the grid. Each dict must have + the keys "images", and optional keys "row_title", "titles", and "labels": + * images (List[str]): List of image paths for the row, + * row_title (Optional[str]): The title of the row + * titles (Optional[List[str]]): List of titles for each image in the row, + drawing on top of each image + * labels (Optional[List[str]]): List of labels for each image in the row, + drawing under each image + grid_padding_rows: Number of pixels between each row in the grid + grid_padding_cols: Number of pixels between each column in the grid + font_size: Size of font + """ + image_paths = [] + columns: Optional[int] = None + row_titles = [] + image_titles = [] + image_labels = [] + for item in rows: + # images + images = item["images"] + assert isinstance(images, list) + if columns is None: + columns = len(images) + else: + assert len(images) == columns + image_paths.extend(images) + + # row title + row_title = item.get("row_title", None) + row_titles.append(row_title) + + # image_titles + image_title = item.get("titles", None) + if image_title is not None: + assert isinstance(image_title, list) and len(image_title) == columns + image_titles.extend(image_title) + else: + image_titles.extend([None] * columns) + + # image_labels + image_label = item.get("labels", None) + if image_label is not None: + assert isinstance(image_label, list) and len(image_label) == columns + image_labels.extend(image_label) + else: + image_labels.extend([None] * columns) + + return draw_image_grid( + path_manager, + image_paths, + columns=columns, + row_titles=row_titles, + image_titles=image_titles, + image_labels=image_labels, + grid_padding_rows=grid_padding_rows, + grid_padding_cols=grid_padding_cols, + font_size=font_size, + ) + + +def draw_image_grid( + path_manager, + image_paths: List[str], + columns: int, + row_titles: Optional[List[str]] = None, + image_titles: Optional[List[str]] = None, + image_labels: Optional[List[str]] = None, + grid_padding_rows: int = 15, + grid_padding_cols: int = 10, + font_size: int = 10, +) -> Image.Image: + """ + Draw a grid of images into a single image, preserve the image sizes. + image_paths: List of images paths + columns: number of columns in the grid + """ + + num_images = len(image_paths) + rows = math.ceil(num_images / columns) + + if row_titles is not None: + assert len(row_titles) == rows, f"{len(row_titles), rows}" + if image_titles is not None: + assert len(image_titles) == num_images, f"{len(image_titles), num_images}" + if image_labels is not None: + assert len(image_labels) == num_images, f"{len(image_labels), num_images}" + + # Load the images from file paths + images = [] + for ip in image_paths: + with path_manager.open(ip, "rb") as fp: + images.append(Image.open(fp)) + images[-1].load() + + # Calculate grid dimensions + grid_width = columns * (max(image.width for image in images) + grid_padding_cols) + grid_height = rows * (max(image.height for image in images) + grid_padding_rows * 3) + + # Create a blank canvas for the grid + grid_image = Image.new("RGB", (grid_width, grid_height), color="white") + draw = ImageDraw.Draw(grid_image) + + # # Load fonts + font_path = get_font_path() + if font_path is not None: + font_path = path_manager.get_local_path(font_path) + title_font = ImageFont.truetype(font_path, font_size) + label_font = ImageFont.truetype(font_path, font_size) + else: + title_font = ImageFont.load_default() + label_font = ImageFont.load_default() + + # Draw the images and labels on the grid + for row in range(rows): + # Calculate starting position for the current row + start_x = 0 + start_y = row * (max(image.height for image in images) + grid_padding_rows * 3) + + # Draw the title for the current row + if row_titles is not None: + draw.text( + (start_x, start_y), row_titles[row], font=title_font, fill="black" + ) + + # Draw the images and labels for the current row + for col in range(columns): + # Calculate the position for the current image and label + image_index = row * columns + col + if image_index >= num_images: + break + + image = images[image_index] + label = image_labels[image_index] if image_labels is not None else None + image_title = ( + image_titles[image_index] if image_titles is not None else None + ) + + image_x = start_x + col * (image.width + grid_padding_cols) + image_y = start_y + grid_padding_rows * 2 + + # Paste the image onto the grid + grid_image.paste(image, (image_x, image_y)) + + # Draw the label below the image + if label is not None: + label_width, label_height = label_font.getsize(label) + label_x = ( + image_x + (image.width - label_width) // 2 + ) # Center horizontally + label_y = image_y + image.height + + draw.text((label_x, label_y), label, font=label_font, fill="black") + + # Draw the image title above the image + if image_title is not None: + image_title_width, image_title_height = title_font.getsize(image_title) + image_title_x = ( + image_x + (image.width - image_title_width) // 2 + ) # Center horizontally + image_title_y = start_y + grid_padding_rows + + draw.text( + (image_title_x, image_title_y), + image_title, + font=title_font, + fill="black", + ) + + return grid_image diff --git a/mobile_cv/common/tests/test_misc_visualize_utils.py b/mobile_cv/common/tests/test_misc_visualize_utils.py new file mode 100644 index 00000000..dda61e8c --- /dev/null +++ b/mobile_cv/common/tests/test_misc_visualize_utils.py @@ -0,0 +1,164 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import os +import tempfile +import unittest + +import mobile_cv.common.misc.visualize_utils as vu +import numpy as np +import PIL +from mobile_cv.common.utils_io import get_path_manager + + +SAVE_IMAGE = False + + +def _create_image(output_dir, idx, height, width, color=None): + os.makedirs(output_dir, exist_ok=True) + ret = os.path.join(output_dir, f"test_image_{idx}.png") + PIL.Image.new("RGB", (height, width), color=color).save(ret) + return ret + + +class TestMiscVisualizationUtils(unittest.TestCase): + def test_draw_image_grid(self): + """ + buck2 run @mode/dev-nosan //mobile-vision/mobile_cv/mobile_cv/common:tests -- -r test_draw_image_grid$ + """ + pm = get_path_manager() + + num_images = 12 + num_columns = 4 + image_size = 64 + + with tempfile.TemporaryDirectory() as output_dir: + images = [ + _create_image( + output_dir, + idx, + image_size, + image_size, + (idx * 20, idx * 20, idx * 20), + ) + for idx in range(num_images) + ] + ret = vu.draw_image_grid( + pm, + images, + columns=num_columns, + grid_padding_rows=0, + grid_padding_cols=0, + ) + self.assertEqual(ret.height, image_size * (num_images // num_columns)) + self.assertEqual(ret.width, image_size * num_columns) + + img = np.array(ret) + for row in range(3): + for col in range(num_columns): + rs = row * image_size + cs = col * image_size + self.assertEqual( + np.linalg.norm( + img[rs : rs + image_size, cs : cs + image_size, :] + - ((row * num_columns + col) * 20) + ), + 0.0, + ) + + if SAVE_IMAGE: + with pm.open(os.path.join(output_dir, "test_vis.png"), "wb") as fp: + ret.save(fp) + print(output_dir) + + def test_draw_image_grid_text(self): + """ + buck2 run @mode/dev-nosan //mobile-vision/mobile_cv/mobile_cv/common:tests -- -r test_draw_image_grid_text$ + """ + + pm = get_path_manager() + + num_images = 6 + num_columns = 3 + num_rows = num_images // num_columns + image_size = 64 + + with tempfile.TemporaryDirectory() as output_dir: + images = [ + _create_image( + output_dir, + idx, + image_size, + image_size, + (idx * 30, idx * 30, idx * 30), + ) + for idx in range(num_images) + ] + row_titles = [f"title_{x}" for x in range(num_rows)] + image_titles = [f"image_{x}" for x in range(num_images)] + image_labels = [f"label_{x}" for x in range(num_images)] + ret = vu.draw_image_grid( + pm, + images, + columns=num_columns, + row_titles=row_titles, + image_titles=image_titles, + image_labels=image_labels, + grid_padding_rows=12, + grid_padding_cols=5, + font_size=10, + ) + self.assertEqual(ret.height, (image_size + 12 * 3) * num_rows) + self.assertEqual(ret.width, (image_size + 5) * num_columns) + + img = np.array(ret) + for row in range(num_rows): + for col in range(num_columns): + rs = row * (image_size + 12 * 3) + 12 * 2 + cs = col * (image_size + 5) + self.assertEqual( + np.linalg.norm( + img[rs : rs + image_size, cs : cs + image_size, :] + - ((row * num_columns + col) * 30) + ), + 0.0, + ) + + if SAVE_IMAGE: + with pm.open(os.path.join(output_dir, "test_vis.png"), "wb") as fp: + ret.save(fp) + print(output_dir) + + def test_save_as_image_grids(self): + """ + buck2 run @mode/dev-nosan //mobile-vision/mobile_cv/mobile_cv/common:tests -- -r test_save_as_image_grids$ + """ + pm = get_path_manager() + + num_columns = 3 + num_rows = 10 + image_size = 32 + + with tempfile.TemporaryDirectory() as output_dir: + src_output_dir = os.path.join(output_dir, "src") + rows = [ + { + "images": [ + _create_image(src_output_dir, idx, image_size, image_size) + for idx in range(num_columns) + ], + "row_title": "row_title", + "titles": [f"title_{idx}" for idx in range(num_columns)], + "labels": [f"label_{idx}" for idx in range(num_columns)], + } + ] * num_rows + + dst_output_dir = os.path.join(output_dir, "dst") + ret = vu.save_as_image_grids( + rows=rows, + output_dir=dst_output_dir, + path_manager=pm, + max_rows_per_image=4, + ) + self.assertEqual(len(ret), 3) + for img_path in ret: + self.assertTrue(os.path.exists(img_path))