Skip to content

Commit

Permalink
Support configurable input size (#3788)
Browse files Browse the repository at this point in the history
* draft implementation

* draft implementation2

* check input size constant value

* update model part

* update interface

* implement adaptive input size draft version

* handle edge case

* add input_size_multiplier and pass it to datamodule in cli

* change typehint from sequence to tuple

* align with pre-commit

* write doc string

* implement unit test

* update unit test

* implement remaining unit tests

* align with develop branch

* fix typo

* exclude batch and num channel from input size

* update docstring

* update unit test

* adaptive input size supports non-square input

* update changelog

* fix typo

* fix typo

* update base data pipeline

* update keypoint detection

* align with pre-commit

* update docstring

* update unit test

* update auto_configurator to use None instead of none

* revert data module policy to apply input_size to subset cfg

* revert keypoint detection

* add comments to explain the reason for the priority in compute_robust_dataset_statistics

* add integration test

* update unit test

* apply input_size to anomaly task

* update docstring

* remove unused comment

* re-enable anomaly integration test

* apply configurable input size to keypoint detection

* update unit test

* update unit test

* update h-label head
  • Loading branch information
eunwoosh authored Aug 14, 2024
1 parent 2ecaac1 commit 0b5ed3b
Show file tree
Hide file tree
Showing 84 changed files with 1,679 additions and 488 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ All notable changes to this project will be documented in this file.
(https://github.com/openvinotoolkit/training_extensions/pull/3801)
- Update head and h-label format for hierarchical label classification
(https://github.com/openvinotoolkit/training_extensions/pull/3810)
- Support configurable input size
(https://github.com/openvinotoolkit/training_extensions/pull/3788)

### Enhancements

Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/action_classification/movinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MoViNet(OTXActionClsModel):
def __init__(
self,
label_info: LabelInfoTypes,
input_size: tuple[int, int] = (224, 224),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
Expand All @@ -40,6 +41,7 @@ def __init__(
self.load_from = "https://github.com/Atze00/MoViNet-pytorch/blob/main/weights/modelA0_statedict_v3?raw=true"
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand Down
2 changes: 2 additions & 0 deletions src/otx/algo/action_classification/x3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class X3D(OTXActionClsModel):
def __init__(
self,
label_info: LabelInfoTypes,
input_size: tuple[int, int] = (224, 224),
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
Expand All @@ -41,6 +42,7 @@ def __init__(
self.load_from = "https://download.openmmlab.com/mmaction/recognition/x3d/facebook/x3d_m_facebook_16x5x1_kinetics400_rgb_20201027-3f42382a.pth"
super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand Down
6 changes: 5 additions & 1 deletion src/otx/algo/anomaly/padim.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Padim(OTXAnomaly, AnomalibPadim):
task (Literal[
OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (256, 256)
"""

def __init__(
Expand All @@ -47,8 +49,9 @@ def __init__(
OTXTaskType.ANOMALY_DETECTION,
OTXTaskType.ANOMALY_SEGMENTATION,
] = OTXTaskType.ANOMALY_CLASSIFICATION,
input_size: tuple[int, int] = (256, 256),
) -> None:
OTXAnomaly.__init__(self)
OTXAnomaly.__init__(self, input_size)
AnomalibPadim.__init__(
self,
backbone=backbone,
Expand All @@ -57,6 +60,7 @@ def __init__(
n_features=n_features,
)
self.task = task
self.input_size = input_size

def configure_optimizers(self) -> tuple[list[Optimizer], list[Optimizer]] | None:
"""PADIM doesn't require optimization, therefore returns no optimizers."""
Expand Down
6 changes: 5 additions & 1 deletion src/otx/algo/anomaly/stfpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Stfpm(OTXAnomaly, AnomalibStfpm):
task (Literal[
OTXTaskType.ANOMALY_CLASSIFICATION, OTXTaskType.ANOMALY_DETECTION, OTXTaskType.ANOMALY_SEGMENTATION
], optional): Task type of Anomaly Task. Defaults to OTXTaskType.ANOMALY_CLASSIFICATION.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (256, 256)
"""

def __init__(
Expand All @@ -43,15 +45,17 @@ def __init__(
OTXTaskType.ANOMALY_DETECTION,
OTXTaskType.ANOMALY_SEGMENTATION,
] = OTXTaskType.ANOMALY_CLASSIFICATION,
input_size: tuple[int, int] = (256, 256),
**kwargs,
) -> None:
OTXAnomaly.__init__(self)
OTXAnomaly.__init__(self, input_size=input_size)
AnomalibStfpm.__init__(
self,
backbone=backbone,
layers=layers,
)
self.task = task
self.input_size = input_size

@property
def trainable_model(self) -> str:
Expand Down
5 changes: 4 additions & 1 deletion src/otx/algo/classification/backbones/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ class OTXEfficientNet(EfficientNet):
in_size : tuple of two ints. Spatial size of the expected input image.
"""

def __init__(self, version: EFFICIENTNET_VERSION, **kwargs):
def __init__(self, version: EFFICIENTNET_VERSION, input_size: tuple[int, int] | None = None, **kwargs):
self.model_name = "efficientnet_" + version

if version == "b0":
Expand Down Expand Up @@ -615,6 +615,9 @@ def __init__(self, version: EFFICIENTNET_VERSION, **kwargs):
msg = f"Unsupported EfficientNet version {version}"
raise ValueError(msg)

if input_size is not None:
in_size = input_size

init_block_channels = 32
layers = [1, 2, 2, 3, 3, 4, 1]
downsample = [1, 1, 1, 1, 0, 1, 0]
Expand Down
21 changes: 16 additions & 5 deletions src/otx/algo/classification/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from __future__ import annotations

from copy import deepcopy
from copy import copy, deepcopy
from math import ceil
from typing import TYPE_CHECKING, Literal

from torch import Tensor, nn
Expand Down Expand Up @@ -57,13 +58,15 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
train_type: Literal[OTXTrainType.SUPERVISED, OTXTrainType.SEMI_SUPERVISED] = OTXTrainType.SUPERVISED,
) -> None:
self.version = version
self.pretrained = pretrained

super().__init__(
label_info=label_info,
input_size=input_size,
optimizer=optimizer,
scheduler=scheduler,
metric=metric,
Expand All @@ -86,7 +89,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
neck = GlobalAveragePooling(dim=2)
loss = nn.CrossEntropyLoss(reduction="none")
if self.train_type == OTXTrainType.SEMI_SUPERVISED:
Expand Down Expand Up @@ -149,6 +152,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.version = version
self.pretrained = pretrained
Expand All @@ -159,6 +163,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -176,7 +181,7 @@ def _create_model(self) -> nn.Module:
return model

def _build_model(self, num_classes: int) -> nn.Module:
backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)
return ImageClassifier(
backbone=backbone,
neck=GlobalAveragePooling(dim=2),
Expand Down Expand Up @@ -229,6 +234,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
input_size: tuple[int, int] = (224, 224),
) -> None:
self.version = version
self.pretrained = pretrained
Expand All @@ -239,6 +245,7 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
Expand All @@ -262,15 +269,19 @@ def _build_model(self, head_config: dict) -> nn.Module:
if not isinstance(self.label_info, HLabelInfo):
raise TypeError(self.label_info)

backbone = OTXEfficientNet(version=self.version, pretrained=self.pretrained)
backbone = OTXEfficientNet(version=self.version, input_size=self.input_size, pretrained=self.pretrained)

copied_head_config = copy(head_config)
copied_head_config["step_size"] = (ceil(self.input_size[0] / 32), ceil(self.input_size[1] / 32))

return ImageClassifier(
backbone=backbone,
neck=nn.Identity(),
head=HierarchicalCBAMClsHead(
in_channels=backbone.num_features,
multiclass_loss=nn.CrossEntropyLoss(),
multilabel_loss=AsymmetricAngularLossWithIgnore(gamma_pos=0.0, gamma_neg=1.0, reduction="sum"),
**head_config,
**copied_head_config,
),
optimize_gap=False,
)
Expand Down
23 changes: 14 additions & 9 deletions src/otx/algo/classification/heads/hlabel_cls_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ class HierarchicalCBAMClsHead(HierarchicalClsHead):
thr (float, optional): Predictions with scores under the thresholds are considered
as negative. Defaults to 0.5.
init_cfg (dict | None, optional): Initialize configuration key-values, Defaults to None.
step_size (int, optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.
step_size (int | tuple[int, int], optional): Step size value for HierarchicalCBAMClsHead, Defaults to 7.
"""

def __init__(
Expand All @@ -435,7 +435,7 @@ def __init__(
multilabel_loss: nn.Module | None = None,
thr: float = 0.5,
init_cfg: dict | None = None,
step_size: int = 7,
step_size: int | tuple[int, int] = 7,
**kwargs,
):
super().__init__(
Expand All @@ -452,19 +452,19 @@ def __init__(
init_cfg=init_cfg,
**kwargs,
)
self.step_size = step_size
self.fc_superclass = nn.Linear(in_channels * step_size * step_size, num_multiclass_heads)
self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * step_size * step_size)
self.step_size = (step_size, step_size) if isinstance(step_size, int) else tuple(step_size)
self.fc_superclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_multiclass_heads)
self.attention_fc = nn.Linear(num_multiclass_heads, in_channels * self.step_size[0] * self.step_size[1])
self.cbam = CBAM(in_channels)
self.fc_subclass = nn.Linear(in_channels * step_size * step_size, num_single_label_classes)
self.fc_subclass = nn.Linear(in_channels * self.step_size[0] * self.step_size[1], num_single_label_classes)

self._init_layers()

def pre_logits(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
"""The process before the final classification head."""
if isinstance(feats, Sequence):
feats = feats[-1]
return feats.view(feats.size(0), self.in_channels * self.step_size * self.step_size)
return feats.view(feats.size(0), self.in_channels * self.step_size[0] * self.step_size[1])

def _init_layers(self) -> None:
"""Initialize weights of classification head."""
Expand All @@ -479,10 +479,15 @@ def forward(self, feats: tuple[torch.Tensor] | torch.Tensor) -> torch.Tensor:
attention_weights = torch.sigmoid(self.attention_fc(out_superclass))
attended_features = pre_logits * attention_weights

attended_features = attended_features.view(pre_logits.size(0), self.in_channels, self.step_size, self.step_size)
attended_features = attended_features.view(
pre_logits.size(0),
self.in_channels,
self.step_size[0],
self.step_size[1],
)
attended_features = self.cbam(attended_features)
attended_features = attended_features.view(
pre_logits.size(0),
self.in_channels * self.step_size * self.step_size,
self.in_channels * self.step_size[0] * self.step_size[1],
)
return self.fc_subclass(attended_features)
19 changes: 19 additions & 0 deletions src/otx/algo/classification/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor, nn
from transformers import AutoModelForImageClassification
from transformers.configuration_utils import PretrainedConfig

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import (
Expand All @@ -29,6 +31,10 @@
from otx.core.metrics import MetricCallable


DEFAULT_INPUT_SIZE = (224, 224)
logger = logging.getLogger(__name__)


class HuggingFaceModelForMulticlassCls(OTXMulticlassClsModel):
"""HuggingFaceModelForMulticlassCls is a class that represents a Hugging Face model for multiclass classification.
Expand All @@ -38,6 +44,8 @@ class HuggingFaceModelForMulticlassCls(OTXMulticlassClsModel):
optimizer (OptimizerCallable, optional): The optimizer callable for training the model.
scheduler (LRSchedulerCallable | LRSchedulerListCallable, optional): The learning rate scheduler callable.
torch_compile (bool, optional): Whether to compile the model using TorchScript. Defaults to False.
input_size (tuple[int, int], optional):
Model input size in the order of height and width. Defaults to (224, 224)
Example:
1. API
Expand All @@ -59,6 +67,7 @@ def __init__(
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
input_size: tuple[int, int] = DEFAULT_INPUT_SIZE,
) -> None:
self.model_name = model_name_or_path

Expand All @@ -68,13 +77,23 @@ def __init__(
scheduler=scheduler,
metric=metric,
torch_compile=torch_compile,
input_size=input_size,
)

def _create_model(self) -> nn.Module:
model_config, _ = PretrainedConfig.get_config_dict(self.model_name)
kwargs = {}
if "image_size" in model_config:
kwargs["image_size"] = self.input_size[0]
elif self.input_size != DEFAULT_INPUT_SIZE:
msg = "There is no 'image_size' argument in the model configuration. There may be unexpected results."
logger.warning(msg)

return AutoModelForImageClassification.from_pretrained(
pretrained_model_name_or_path=self.model_name,
num_labels=self.label_info.num_classes,
ignore_mismatched_sizes=True,
**kwargs,
)

def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
Expand Down
Loading

0 comments on commit 0b5ed3b

Please sign in to comment.