Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sooahleex committed Oct 15, 2024
1 parent 5a9fe0f commit 280eb39
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions tests/unit/components/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: MIT

from typing import List, Tuple
from typing import List, Optional, Tuple

import pytest

Expand All @@ -11,7 +11,7 @@
from datumaro.components.dataset import Dataset
from datumaro.components.dataset_base import DatasetItem
from datumaro.components.launcher import Launcher
from datumaro.components.transformer import ModelTransform
from datumaro.components.transformer import ModelTransform, TabularTransform


class MockLauncher(Launcher):
Expand Down Expand Up @@ -64,3 +64,30 @@ def test_model_transform(
assert item.annotations == [Annotation(id=0), Annotation(id=1)]
else:
assert item.annotations == [Annotation(id=1)]


class TabularTransformTest:
@pytest.fixture
def fxt_dataset(self):
return Dataset.from_iterable(
[DatasetItem(id=f"item_{i}", annotations=[Annotation(id=0)]) for i in range(10)]
)

@pytest.mark.parametrize("batch_size", [1, 10])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_tabular_transform(self, fxt_dataset, batch_size, num_workers):
class MockTabularTransform(TabularTransform):
def transform_item(self, item: DatasetItem) -> Optional[DatasetItem]:
# Mock transformation logic
item.annotations.append(Annotation(id=1))
return item

transform = MockTabularTransform(
extractor=fxt_dataset,
batch_size=batch_size,
num_workers=num_workers,
)

for idx, item in enumerate(transform):
assert item.id == f"item_{idx}"
assert item.annotations == [Annotation(id=0), Annotation(id=1)]

0 comments on commit 280eb39

Please sign in to comment.