Skip to content

Commit

Permalink
align with pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed May 31, 2024
1 parent 63c9453 commit 4f4d565
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
2 changes: 1 addition & 1 deletion model_api/python/model_api/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,5 +529,5 @@ def __call__(self, inputs):
if self.is_trivial:
return inputs
if self.reverse_input_channels:
inputs = inputs[...,::-1]
inputs = inputs[..., ::-1]
return (inputs - self.means) / self.std_scales
45 changes: 33 additions & 12 deletions model_api/python/model_api/models/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from __future__ import annotations

from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np

from model_api.adapters.utils import RESIZE_TYPES, InputTransform

from .model import Model
from .utils import ClassificationResult, load_labels
from .types import BooleanValue, ListValue, NumericalValue, StringValue
from .utils import ClassificationResult, load_labels

if TYPE_CHECKING:
from model_api.adapters.inference_adapter import InferenceAdapter
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
self,
inference_adapter: InferenceAdapter,
configuration: dict[str, Any] = dict(),
preload: bool = False
preload: bool = False,
) -> None:
"""Action classaification model constructor
Expand All @@ -76,11 +76,17 @@ def __init__(
self.image_blob_name = self.image_blob_names[0]
self.nscthw_layout = "NSCTHW" in self.inputs[self.image_blob_name].layout
if self.nscthw_layout:
self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[self.image_blob_name].shape
self.n, self.s, self.c, self.t, self.h, self.w = self.inputs[
self.image_blob_name
].shape
else:
self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[self.image_blob_name].shape
self.n, self.s, self.t, self.h, self.w, self.c = self.inputs[
self.image_blob_name
].shape
self.resize = RESIZE_TYPES[self.resize_type]
self.input_transform = InputTransform(self.reverse_input_channels, self.mean_values, self.scale_values)
self.input_transform = InputTransform(
self.reverse_input_channels, self.mean_values, self.scale_values
)
if self.path_to_labels:
self.labels = load_labels(self.path_to_labels)

Expand Down Expand Up @@ -146,7 +152,9 @@ def _get_inputs(self) -> tuple[list[str], list[str]]:
)
return image_blob_names, image_info_blob_names

def preprocess(self, inputs: np.ndarray) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, ...]]]:
def preprocess(
self, inputs: np.ndarray
) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, ...]]]:
"""Data preprocess method
It performs basic preprocessing of a single image:
Expand All @@ -171,8 +179,14 @@ def preprocess(self, inputs: np.ndarray) -> tuple[dict[str, np.ndarray], dict[st
}
- the input metadata, which might be used in `postprocess` method
"""
meta = {"original_shape": inputs.shape, "resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w)}
resized_inputs = [self.resize(frame, (self.w, self.h), pad_value=self.pad_value) for frame in inputs]
meta = {
"original_shape": inputs.shape,
"resized_shape": (self.n, self.s, self.c, self.t, self.h, self.w),
}
resized_inputs = [
self.resize(frame, (self.w, self.h), pad_value=self.pad_value)
for frame in inputs
]
frames = self.input_transform(np.array(resized_inputs))
np_frames = self._change_layout(frames)
dict_inputs = {self.image_blob_name: np_frames}
Expand All @@ -192,8 +206,15 @@ def _change_layout(self, inputs: list[np.ndarray]) -> np.ndarray:
return np_inputs.transpose(0, 1, -1, 2, 3, 4) # [1, 1, C, T, H, W]
return np_inputs

def postprocess(self, outputs: dict[str, np.ndarray], meta: dict[str, Any]) -> ClassificationResult:
def postprocess(
self, outputs: dict[str, np.ndarray], meta: dict[str, Any]
) -> ClassificationResult:
"""Post-process."""
logits = next(iter(outputs.values())).squeeze()
index = np.argmax(logits)
return ClassificationResult([(index, self.labels[index], logits[index])], np.ndarray(0), np.ndarray(0), np.ndarray(0))
return ClassificationResult(
[(index, self.labels[index], logits[index])],
np.ndarray(0),
np.ndarray(0),
np.ndarray(0),
)
2 changes: 1 addition & 1 deletion tests/python/accuracy/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def main():
download_otx_model(
client, otx_models_dir, "cls_efficient_b0_shuffled_outputs"
),
download_otx_model( client, otx_models_dir, "action_cls_xd3_kinetic"),
download_otx_model(client, otx_models_dir, "action_cls_xd3_kinetic"),
)


Expand Down

0 comments on commit 4f4d565

Please sign in to comment.