diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 2d6a5f4..5690ba3 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -484,8 +484,13 @@ def _size(cls, py_type: Type) -> Tuple[int, int]: @staticmethod def _validate_size_tuple(tuple_: Tuple) -> bool: - """Checks whether a given tuple is a tuple of 2 integers (precision, scale)""" - return len(tuple_) == 2 and all(isinstance(item, int) for item in tuple_) + """ + Checks whether a given tuple is a tuple of (precision, scale) + + Must be integers. + Scale must not be greater than precision + """ + return len(tuple_) == 2 and all(isinstance(item, int) for item in tuple_) and tuple_[0] >= tuple_[1] def data(self, names: NamesType) -> JSONObj: """Return the schema data""" diff --git a/tests/test_logicals.py b/tests/test_logicals.py index ea60353..b8f589a 100644 --- a/tests/test_logicals.py +++ b/tests/test_logicals.py @@ -112,40 +112,28 @@ def test_annotated_decimal_neg_scale(): assert_schema(py_type, expected) -def test_annotated_decimal_bad_no_tuple(): +def test_annotated_decimal_scale_too_big(): + py_type = Annotated[decimal.Decimal, (5, 6)] + with pytest.raises(pas.TypeNotSupportedError): + assert_schema(py_type, {}) + + +def test_annotated_decimal_no_tuple(): py_type = Annotated[decimal.Decimal, ...] - expected = { - "type": "bytes", - "logicalType": "decimal", - "precision": 5, - "scale": 2, - } with pytest.raises(pas.TypeNotSupportedError): - assert_schema(py_type, expected) + assert_schema(py_type, {}) def test_annotated_decimal_tuple_wrong_length(): py_type = Annotated[decimal.Decimal, (3, 2, 1)] - expected = { - "type": "bytes", - "logicalType": "decimal", - "precision": 5, - "scale": 2, - } with pytest.raises(pas.TypeNotSupportedError): - assert_schema(py_type, expected) + assert_schema(py_type, {}) def test_annotated_decimal_tuple_wrong_type(): py_type = Annotated[decimal.Decimal, ("a", 1)] - expected = { - "type": "bytes", - "logicalType": "decimal", - "precision": 5, - "scale": 2, - } with pytest.raises(pas.TypeNotSupportedError): - assert_schema(py_type, expected) + assert_schema(py_type, {}) def test_multiple_decimals():