Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support configurable input size #3788

Merged
merged 47 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
27df73c
draft implementation
eunwoosh Jul 31, 2024
bb9b66e
draft implementation2
eunwoosh Jul 31, 2024
b6a7685
check input size constant value
eunwoosh Jul 31, 2024
4c0781f
update model part
eunwoosh Aug 5, 2024
bf0c736
draft implementation
eunwoosh Aug 5, 2024
a40db40
update interface
eunwoosh Aug 5, 2024
ac506ef
implement adaptive input size draft version
eunwoosh Aug 7, 2024
b1121f0
handle edge case
eunwoosh Aug 7, 2024
659a9cf
add input_size_multiplier and pass it to datamodule in cli
eunwoosh Aug 7, 2024
e329029
change typehint from sequence to tuple
eunwoosh Aug 8, 2024
8474b07
align with pre-commit
eunwoosh Aug 8, 2024
28dcd63
write doc string
eunwoosh Aug 8, 2024
01cb629
implement unit test
eunwoosh Aug 9, 2024
30172dd
update unit test
eunwoosh Aug 9, 2024
73598ab
implement left unit test
eunwoosh Aug 9, 2024
14b7a59
Merge branch 'develop' into configurable_input_size
eunwoosh Aug 9, 2024
4df5674
align with develop branch
eunwoosh Aug 9, 2024
4fcc627
fix typo
eunwoosh Aug 9, 2024
39b0650
exclude batch and num channel from input size
eunwoosh Aug 9, 2024
aee7600
update docstring
eunwoosh Aug 9, 2024
5d6c481
update unit test
eunwoosh Aug 9, 2024
ff8ecf9
adaptive input size supports not square
eunwoosh Aug 9, 2024
82e41e0
update changelog
eunwoosh Aug 9, 2024
4e8ce70
fix typo
eunwoosh Aug 9, 2024
9260a8c
fix typo
eunwoosh Aug 12, 2024
d40a9f0
update base data pipeline
eunwoosh Aug 12, 2024
c049044
update keypoint detection
eunwoosh Aug 12, 2024
9097b3d
align with pre-commit
eunwoosh Aug 12, 2024
0adb7ea
update docstring
eunwoosh Aug 12, 2024
1791e84
Merge branch 'develop' into configurable_input_size
eunwoosh Aug 12, 2024
0896ee3
update unit test
eunwoosh Aug 12, 2024
05483ce
update auto_configurator to use None intead of none
eunwoosh Aug 12, 2024
8d34d2c
revert data module policy to apply input_size to subset cfg
eunwoosh Aug 12, 2024
982c985
revert keypoint detection
eunwoosh Aug 12, 2024
f8f9e28
add comments to explain a reason of priority in compute_robust_datase…
eunwoosh Aug 12, 2024
09772d8
add integration test
eunwoosh Aug 12, 2024
5b3f198
update unit test
eunwoosh Aug 13, 2024
5fe6777
apply input_size to anomaly task
eunwoosh Aug 13, 2024
9dccf2d
update docstring
eunwoosh Aug 13, 2024
081c94b
remove unused comment
eunwoosh Aug 13, 2024
4ee1551
re-enable anomaly integration test
eunwoosh Aug 13, 2024
9127e5d
apply configurable input size to keypoint detection
eunwoosh Aug 13, 2024
a6b922d
update unit test
eunwoosh Aug 13, 2024
255a4e0
update unit test
eunwoosh Aug 13, 2024
900faaf
Merge branch 'develop' into configurable_input_size
eunwoosh Aug 13, 2024
8e6f8f8
update h-label head
eunwoosh Aug 13, 2024
4868ba5
Merge branch 'develop' into configurable_input_size
eunwoosh Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ All notable changes to this project will be documented in this file.
(https://github.com/openvinotoolkit/training_extensions/pull/3781)
- Update head and h-label format for hierarchical label classification
(https://github.com/openvinotoolkit/training_extensions/pull/3810)
- Support configurable input size
(https://github.com/openvinotoolkit/training_extensions/pull/3788)

### Enhancements

Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/action_classification/movinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MoViNet(OTXActionClsModel):
def __init__(
self,
label_info: LabelInfoTypes,
input_size: tuple[int, int] = (224, 224),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
Expand All @@ -40,6 +41,7 @@ def __init__(
self.load_from = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA0_statedict_v3?raw=true"
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/action_classification/x3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class X3D(OTXActionClsModel):
def __init__(
self,
label_info: LabelInfoTypes,
input_size: tuple[int, int] = (224, 224),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
Expand All @@ -41,6 +42,7 @@ def __init__(
self.load_from = "https://download.openmmlab.com/mmaction/recognition/x3d/facebook/x3d_m_facebook_16x5x1_kinetics400_rgb_20201027-3f42382a.pth"
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand Down
6 changes: 5 additions & 1 deletion src/otx/algo/anomaly/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Padim(OTXAnomaly, AnomalibPadim):
task (Literal[
OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (256, 256)
"""

def __init__(
Expand All @@ -47,8 +49,9 @@ def __init__(
OTXTaskType.ANOMALY_DETECTION,
OTXTaskType.ANOMALY_SEGMENTATION,
] = OTXTaskType.ANOMALY_CLASSIFICATION,
input_size: tuple[int, int] = (256, 256),
) -> None:
OTXAnomaly.__init__(self)
OTXAnomaly.__init__(self, input_size)
AnomalibPadim.__init__(
self,
backbone=backbone,
Expand All @@ -57,6 +60,7 @@ def __init__(
n_features=n_features,
)
self.task = task
self.input_size = input_size

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""PADIM doesn't require optimization, therefore returns no optimizers."""
Expand Down
6 changes: 5 additions & 1 deletion src/otx/algo/anomaly/stfpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Stfpm(OTXAnomaly, AnomalibStfpm):
task (Literal[
OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (256, 256)
"""

def __init__(
Expand All @@ -43,15 +45,17 @@ def __init__(
OTXTaskType.ANOMALY_DETECTION,
OTXTaskType.ANOMALY_SEGMENTATION,
] = OTXTaskType.ANOMALY_CLASSIFICATION,
input_size: tuple[int, int] = (256, 256),
**kwargs,
) -> None:
OTXAnomaly.__init__(self)
OTXAnomaly.__init__(self, input_size=input_size)
AnomalibStfpm.__init__(
self,
backbone=backbone,
layers=layers,
)
self.task = task
self.input_size = input_size

@property
def trainable_model(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion src/otx/algo/classification/backbones/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ class OTXEfficientNet(EfficientNet):
in_size : tuple of two ints. Spatial size of the expected input image.
"""

def __init__(self, version: EFFICIENTNET_VERSION, **kwargs):
def __init__(self, version: EFFICIENTNET_VERSION, input_size: tuple[int, int] | None = None, **kwargs):
self.model_name = "efficientnet_" + version

if version == "b0":
Expand Down Expand Up @@ -615,6 +615,9 @@ def __init__(self, version: EFFICIENTNET_VERSION, **kwargs):
msg = f"Unsupported EfficientNet version {version}"
raise ValueError(msg)

if input_size is not None:
in_size = input_size

init_block_channels = 32
layers = [1, 2, 2, 3, 3, 4, 1]
downsample = [1, 1, 1, 1, 0, 1, 0]
Expand Down
12 changes: 9 additions & 3 deletions src/otx/algo/classification/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.version = version
self.pretrained = pretrained

super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand All @@ -86,7 +88,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
neck = GlobalAveragePooling(dim=2)
loss = nn.CrossEntropyLoss(reduction="none")
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
Expand Down Expand Up @@ -149,6 +151,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.version = version
self.pretrained = pretrained
Expand All @@ -159,6 +162,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -176,7 +180,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
return ImageClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
Expand Down Expand Up @@ -229,6 +233,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.version = version
self.pretrained = pretrained
Expand All @@ -239,6 +244,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -262,7 +268,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
if not isinstance(self.label_info, HLabelInfo):
raise TypeError(self.label_info)

backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
return ImageClassifier(
backbone=backbone,
neck=nn.Identity(),
Expand Down
19 changes: 19 additions & 0 deletions src/otx/algo/classification/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor, nn
from transformers import AutoModelForImageClassification
from transformers.configuration_utils import PretrainedConfig

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
Expand All @@ -29,6 +31,10 @@
from otx.core.metrics import MetricCallable


DEFAULT_INPUT_SIZE = (224, 224)
logger = logging.getLogger(__name__)


class HuggingFaceModelForMulticlassCls(OTXMulticlassClsModel):
"""HuggingFaceModelForMulticlassCls is a class that represents a Hugging Face model for multiclass classification.

Expand All @@ -38,6 +44,8 @@ class HuggingFaceModelForMulticlassCls(OTXMulticlassClsModel):
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (224, 224)

Example:
1. API
Expand All @@ -59,6 +67,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = DEFAULT_INPUT_SIZE,
) -> None:
self.model_name = model_name_or_path

Expand All @@ -68,13 +77,23 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
model_config, _ = PretrainedConfig.get_config_dict(self.model_name)
kwargs = {}
if "image_size" in model_config:
kwargs["image_size"] = self.input_size[0]
elif self.input_size != DEFAULT_INPUT_SIZE:
msg = "There is no 'image_size' argument in the model configuration. There may be unexpected results."
logger.warning(msg)

return AutoModelForImageClassification.from_pretrained(
pretrained_model_name_or_path=self.model_name,
num_labels=self.label_info.num_classes,
ignore_mismatched_sizes=True,
**kwargs,
)

def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
Expand Down
18 changes: 13 additions & 5 deletions src/otx/algo/classification/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ class MobileNetV3ForMulticlassCls(OTXMulticlassClsModel):
metric (MetricCallable, optional): The metric callable. Defaults to MultiClassClsMetricCallable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
freeze_backbone (bool, optional): Whether to freeze the backbone layers during training. Defaults to False.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (224, 224)
"""

def __init__(
Expand All @@ -72,6 +74,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.mode = mode
Expand All @@ -82,6 +85,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
train_type=train_type,
)

Expand All @@ -100,7 +104,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = OTXMobileNetV3(mode=self.mode)
backbone = OTXMobileNetV3(mode=self.mode, input_size=self.input_size)
neck = GlobalAveragePooling(dim=2)
loss = nn.CrossEntropyLoss(reduction="none")
in_channels = 960 if self.mode == "large" else 576
Expand Down Expand Up @@ -163,6 +167,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.mode = mode
super().__init__(
Expand All @@ -171,6 +176,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -189,7 +195,7 @@ def _create_model(self) -> nn.Module:

def _build_model(self, num_classes: int) -> nn.Module:
return ImageClassifier(
backbone=OTXMobileNetV3(mode=self.mode),
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=GlobalAveragePooling(dim=2),
head=MultiLabelNonLinearClsHead(
num_classes=num_classes,
Expand Down Expand Up @@ -246,7 +252,7 @@ def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, 224, 224),
input_size=(1, 3, *self.input_size),
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
resize_mode="standard",
Expand Down Expand Up @@ -292,6 +298,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.mode = mode
super().__init__(
Expand All @@ -300,6 +307,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -324,7 +332,7 @@ def _build_model(self, head_config: dict) -> nn.Module:
raise TypeError(self.label_info)

return ImageClassifier(
backbone=OTXMobileNetV3(mode=self.mode),
backbone=OTXMobileNetV3(mode=self.mode, input_size=self.input_size),
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=960,
Expand Down Expand Up @@ -403,7 +411,7 @@ def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, 224, 224),
input_size=(1, 3, *self.input_size),
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
resize_mode="standard",
Expand Down
Loading
Loading