From f4f86dfc837f2b18130f9adacfc18e03ae3f8fb6 Mon Sep 17 00:00:00 2001 From: Reuven <44209964+reuvenperetz@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:20:47 +0300 Subject: [PATCH] Raise exception for invalid onnx metadata values types (#104) Co-authored-by: reuvenp --- mct_quantizers/pytorch/metadata.py | 3 +++ tests/pytorch_tests/test_pytorch_load_model.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/mct_quantizers/pytorch/metadata.py b/mct_quantizers/pytorch/metadata.py index 7ffbbfb..dad71a6 100644 --- a/mct_quantizers/pytorch/metadata.py +++ b/mct_quantizers/pytorch/metadata.py @@ -82,6 +82,7 @@ def get_metadata(model): def add_onnx_metadata(model: onnx.ModelProto, metadata: Dict): """ Init the metadata dictionary and verify its compliance, then add it to the model metadata_props. + Metadata values have to be of byte type. Args: model (ModelProto): onnx model to add the metadata to. @@ -101,6 +102,8 @@ def add_onnx_metadata(model: onnx.ModelProto, metadata: Dict): for k, v in metadata.items(): meta = model.metadata_props.add() + if not isinstance(v, (bytes, str)): + Logger.critical(f"ONNX metadata must be of byte type, but {v} has type {type(v)}") meta.key, meta.value = k, v return model diff --git a/tests/pytorch_tests/test_pytorch_load_model.py b/tests/pytorch_tests/test_pytorch_load_model.py index c273d96..b62a0ea 100644 --- a/tests/pytorch_tests/test_pytorch_load_model.py +++ b/tests/pytorch_tests/test_pytorch_load_model.py @@ -255,3 +255,8 @@ def test_save_and_load_metadata(self): self.assertTrue(get_onnx_metadata(onnx_model) == get_onnx_metadata(loaded_onnx_model)) os.remove(tmp_onnx_file) + + # Make sure assertion is raised in cases of invalid metadata value type. + with self.assertRaisesRegex(Exception, r"ONNX metadata must be of byte type, but 4.2 has type "): + add_onnx_metadata(onnx_model, {'test': 'test456', 'foo': 4.2}) +