From 4f4d565260b57fdfec9e5b4bfa6cce8277985069 Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Fri, 31 May 2024 16:33:43 +0900 Subject: [PATCH] align with pre-commit --- model_api/python/model_api/adapters/utils.py | 2 +- .../model_api/models/action_classification.py | 45 ++++++++++++++----- tests/python/accuracy/prepare_data.py | 2 +- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/model_api/python/model_api/adapters/utils.py b/model_api/python/model_api/adapters/utils.py index e5a4d987..7228e66e 100644 --- a/model_api/python/model_api/adapters/utils.py +++ b/model_api/python/model_api/adapters/utils.py @@ -529,5 +529,5 @@ def __call__(self, inputs): if self.is_trivial: return inputs if self.reverse_input_channels: - inputs = inputs[...,::-1] + inputs = inputs[..., ::-1] return (inputs - self.means) / self.std_scales diff --git a/model_api/python/model_api/models/action_classification.py b/model_api/python/model_api/models/action_classification.py index faa5daa9..f4799dfa 100644 --- a/model_api/python/model_api/models/action_classification.py +++ b/model_api/python/model_api/models/action_classification.py @@ -16,14 +16,14 @@ from __future__ import annotations -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np - from model_api.adapters.utils import RESIZE_TYPES, InputTransform + from .model import Model -from .utils import ClassificationResult, load_labels from .types import BooleanValue, ListValue, NumericalValue, StringValue +from .utils import ClassificationResult, load_labels if TYPE_CHECKING: from model_api.adapters.inference_adapter import InferenceAdapter @@ -57,7 +57,7 @@ def __init__( self, inference_adapter: InferenceAdapter, configuration: dict[str, Any] = dict(), - preload: bool = False + preload: bool = False, ) -> None: """Action classaification model constructor @@ -76,11 +76,17 @@ def __init__( self.image_blob_name = self.image_blob_names[0] self.nscthw_layout = "NSCTHW" in self.inputs[self.image_blob_name].layout if self.nscthw_layout: - self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[self.image_blob_name].shape + self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[ + self.image_blob_name + ].shape else: - self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape + self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[ + self.image_blob_name + ].shape self.resize = RESIZE_TYPES[self.resize_type] - self.input_transform = InputTransform(self.reverse_input_channels, self.mean_values, self.scale_values) + self.input_transform = InputTransform( + self.reverse_input_channels, self.mean_values, self.scale_values + ) if self.path_to_labels: self.labels = load_labels(self.path_to_labels) @@ -146,7 +152,9 @@ def _get_inputs(self) -> tuple[list[str], list[str]]: ) return image_blob_names, image_info_blob_names - def preprocess(self, inputs: np.ndarray) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, ...]]]: + def preprocess( + self, inputs: np.ndarray + ) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, ...]]]: """Data preprocess method It performs basic preprocessing of a single image: @@ -171,8 +179,14 @@ def preprocess(self, inputs: np.ndarray) -> tuple[dict[str, np.ndarray], dict[st } - the input metadata, which might be used in `postprocess` method """ - meta = {"original_shape": inputs.shape, "resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w)} - resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.pad_value) for frame in inputs] + meta = { + "original_shape": inputs.shape, + "resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w), + } + resized_inputs = [ + self.resize(frame, (self.w, self.h), pad_value=self.pad_value) + for frame in inputs + ] frames = self.input_transform(np.array(resized_inputs)) np_frames = self._change_layout(frames) dict_inputs = {self.image_blob_name: np_frames} @@ -192,8 +206,15 @@ def _change_layout(self, inputs: list[np.ndarray]) -> np.ndarray: return np_inputs.transpose(0, 1, -1, 2, 3, 4) # [1, 1, C, T, H, W] return np_inputs - def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> ClassificationResult: + def postprocess( + self, outputs: dict[str, np.ndarray], meta: dict[str, Any] + ) -> ClassificationResult: """Post-process.""" logits = next(iter(outputs.values())).squeeze() index = np.argmax(logits) - return ClassificationResult([(index, self.labels[index], logits[index])], np.ndarray(0), np.ndarray(0), np.ndarray(0)) + return ClassificationResult( + [(index, self.labels[index], logits[index])], + np.ndarray(0), + np.ndarray(0), + np.ndarray(0), + ) diff --git a/tests/python/accuracy/prepare_data.py b/tests/python/accuracy/prepare_data.py index 79b28c33..e11ced03 100644 --- a/tests/python/accuracy/prepare_data.py +++ b/tests/python/accuracy/prepare_data.py @@ -112,7 +112,7 @@ async def main(): download_otx_model( client, otx_models_dir, "cls_efficient_b0_shuffled_outputs" ), - download_otx_model( client, otx_models_dir, "action_cls_xd3_kinetic"), + download_otx_model(client, otx_models_dir, "action_cls_xd3_kinetic"), )