Skip to content

Commit

Permalink
✅ [Pass] test in multiclass label&dynamic shape
Browse files Browse the repository at this point in the history
  • Loading branch information
henrytsui000 committed Nov 5, 2024
1 parent 3092710 commit 2522f72
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/test_tools/test_data_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_mosaic():

# Mock parent with image_size and get_more_data method
class MockParent:
image_size = (100, 100)
base_size = 100

def get_more_data(self, num_images):
return [(img, boxes) for _ in range(num_images)]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_tools/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ def test_training_data_loader_correctness(train_dataloader: DataLoader):
def test_validation_data_loader_correctness(validation_dataloader: DataLoader):
batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
assert batch_size == 4
assert images.shape == (4, 3, 640, 640)
assert images.shape == (4, 3, 512, 768)
assert targets.shape == (4, 18, 5)
assert reverse_tensors.shape == (4, 5)
expected_paths = [
Path("tests/data/images/val/000000151480.jpg"),
Path("tests/data/images/val/000000284106.jpg"),
Path("tests/data/images/val/000000323571.jpg"),
Path("tests/data/images/val/000000151480.jpg"),
Path("tests/data/images/val/000000570456.jpg"),
Path("tests/data/images/val/000000323571.jpg"),
]
assert list(image_paths) == list(expected_paths)

Expand Down
12 changes: 10 additions & 2 deletions tests/test_utils/test_bounding_box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_bbox_nms():
dtype=float32,
)

nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5, max_bbox=400)

# Batch 1:
# - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
Expand All @@ -197,16 +197,24 @@ def test_bbox_nms():
[
[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
[1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
[0.0, 160.0, 120.0, 320.0, 240.0, 0.5744],
[2.0, 0.0, 0.0, 160.0, 120.0, 0.5498],
[1.0, 16.0, 12.0, 176.0, 132.0, 0.5498],
[2.0, 160.0, 120.0, 320.0, 240.0, 0.5250],
],
[
[0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
[2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
[1.0, 0.0, 0.0, 160.0, 120.0, 0.5622],
[2.0, 0.0, 0.0, 160.0, 120.0, 0.5498],
[1.0, 0.0, 120.0, 160.0, 240.0, 0.5498],
[0.0, 0.0, 120.0, 160.0, 240.0, 0.5374],
],
]
)

output = bbox_nms(cls_dist, bbox, nms_cfg)

print(output)
for out, exp in zip(output, expected_output):
assert allclose(out, exp, atol=1e-4), f"Output: {out} Expected: {exp}"

Expand Down

0 comments on commit 2522f72

Please sign in to comment.