diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 0190231..1d8f487 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -456,12 +456,11 @@ class DecimalSchema(Schema): @classmethod def handles_type(cls, py_type: Type) -> bool: """Whether this schema class can represent a given Python class""" - try: - # A decimal.Decimal type with annotations indicating precision and optionally scale. - cls._decimal_meta(py_type) - return True - except TypeError: - return False + # Here we are greedy: we catch any decimal.Decimal. However, data() might fail if the annotation is not correct. + return ( + _is_class(py_type, decimal.Decimal) # Using DecimalMeta + or get_origin(py_type) is decimal.Decimal # Deprecated: DecimalType + ) @classmethod def _decimal_meta(cls, py_type: Type) -> py_avro_schema._typing.DecimalMeta: @@ -473,8 +472,8 @@ def _decimal_meta(cls, py_type: Type) -> py_avro_schema._typing.DecimalMeta: try: # At least one of the annotations should be a DecimalMeta object (meta,) = (arg for arg in args[1:] if isinstance(arg, py_avro_schema._typing.DecimalMeta)) - except ValueError: # not enough values to unpack - raise TypeError(f"{py_type} is not annotated with a 'py_avro_schema.DecimalMeta` object") + except ValueError: # not enough/too many values to unpack + raise TypeError(f"{py_type} is not annotated with a single 'py_avro_schema.DecimalMeta' object") return meta elif origin is decimal.Decimal: # Deprecated pas.DecimalType[4, 2] diff --git a/src/py_avro_schema/_typing.py b/src/py_avro_schema/_typing.py index ef3e675..f479b63 100644 --- a/src/py_avro_schema/_typing.py +++ b/src/py_avro_schema/_typing.py @@ -21,7 +21,7 @@ import typeguard -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) # Needs to be hashable to work in unioned types class DecimalMeta: """ Meta data to annotate a :class:`decimal.Decimal` with precision and scale information @@ -39,6 +39,20 @@ class DecimalMeta: precision: int scale: Optional[int] = None + def __post_init__(self): + """ + Validate input data + + See: https://avro.apache.org/docs/1.11.1/specification/#decimal + """ + if self.precision < 1: + raise ValueError(f"Precision must be at least 1. Given value: {self.precision}") + if self.scale is not None: + if self.scale < 0: + raise ValueError(f"Scale must be positive. Given value: {self.scale}") + elif self.scale > self.precision: + raise ValueError(f"Scale must be no more than precision of {self.precision}. Given value: {self.scale}") + class DecimalType: """ diff --git a/tests/test_dataclass.py b/tests/test_dataclass.py index c944f63..0ee9cd5 100644 --- a/tests/test_dataclass.py +++ b/tests/test_dataclass.py @@ -625,7 +625,7 @@ class PyType: def test_decimal_field_default(): @dataclasses.dataclass class PyType: - field_a: pas.DecimalType[4, 2] = decimal.Decimal("3.14") + field_a: Annotated[decimal.Decimal, pas.DecimalMeta(4, 2)] = decimal.Decimal("3.14") expected = { "type": "record", @@ -649,7 +649,7 @@ class PyType: def test_decimal_field_default_precision_too_big(): @dataclasses.dataclass class PyType: - field_a: pas.DecimalType[4, 2] = decimal.Decimal("123.45") + field_a: Annotated[decimal.Decimal, pas.DecimalMeta(4, 2)] = decimal.Decimal("123.45") with pytest.raises( ValueError, match="Default value 123.45 has precision 5 which is greater than the schema's precision 4" @@ -660,7 +660,7 @@ class PyType: def test_decimal_field_default_scale_too_big(): @dataclasses.dataclass class PyType: - field_a: pas.DecimalType[4, 2] = decimal.Decimal("1.234") + field_a: Annotated[decimal.Decimal, pas.DecimalMeta(4, 2)] = decimal.Decimal("1.234") with pytest.raises(ValueError, match="Default value 1.234 has scale 3 which is greater than the schema's scale 2"): assert_schema(PyType, {}) diff --git a/tests/test_logicals.py b/tests/test_logicals.py index 337dff8..71d9fe4 100644 --- a/tests/test_logicals.py +++ b/tests/test_logicals.py @@ -11,8 +11,9 @@ import datetime import decimal +import re import uuid -from typing import Annotated, Any, Dict, List +from typing import Annotated, Any, Dict, List, Union import pytest @@ -149,17 +150,6 @@ def test_annotated_decimal_default_scale(): assert_schema(py_type, expected) -def test_annotated_decimal_neg_scale(): - py_type = Annotated[decimal.Decimal, pas.DecimalMeta(precision=5, scale=-2)] - expected = { - "type": "bytes", - "logicalType": "decimal", - "precision": 5, - "scale": -2, - } - assert_schema(py_type, expected) - - def test_annotated_decimal_additional_meta(): py_type = Annotated[decimal.Decimal, "something else", pas.DecimalMeta(precision=5, scale=2)] expected = { @@ -171,15 +161,53 @@ def test_annotated_decimal_additional_meta(): assert_schema(py_type, expected) +def test_annotated_decimal_in_union(): + py_type = Union[Annotated[decimal.Decimal, pas.DecimalMeta(precision=5, scale=2)], None] + expected = [ + { + "type": "bytes", + "logicalType": "decimal", + "precision": 5, + "scale": 2, + }, + "null", + ] + assert_schema(py_type, expected) + + def test_annotated_decimal_no_meta(): py_type = Annotated[decimal.Decimal, ...] - with pytest.raises(pas.TypeNotSupportedError): + with pytest.raises( + TypeError, + match=re.escape( + "typing.Annotated[decimal.Decimal, Ellipsis] is not annotated with a single 'py_avro_schema.DecimalMeta' " + "object" + ), + ): + assert_schema(py_type, {}) + + +def test_annotated_decimal_2_meta(): + py_type = Annotated[decimal.Decimal, pas.DecimalMeta(precision=5, scale=2), pas.DecimalMeta(precision=4)] + with pytest.raises( + TypeError, + match=re.escape( + "typing.Annotated[decimal.Decimal, DecimalMeta(precision=5, scale=2), DecimalMeta(precision=4, scale=None)]" + " is not annotated with a single 'py_avro_schema.DecimalMeta' object" + ), + ): assert_schema(py_type, {}) def test_annotated_decimal_tuple(): py_type = Annotated[decimal.Decimal, (5, 2)] - with pytest.raises(pas.TypeNotSupportedError): + with pytest.raises( + TypeError, + match=re.escape( + "typing.Annotated[decimal.Decimal, (5, 2)] is not annotated with a single 'py_avro_schema.DecimalMeta' " + "object" + ), + ): assert_schema(py_type, {}) diff --git a/tests/test_typing.py b/tests/test_typing.py index dec3b25..ec658d5 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -15,7 +15,39 @@ import pytest import typeguard -from py_avro_schema._typing import DecimalType +from py_avro_schema._typing import DecimalMeta, DecimalType + + +def test_decimal_meta(): + meta = DecimalMeta(precision=4, scale=2) + assert meta.precision == 4 + assert meta.scale == 2 + + +def test_decimal_meta_hashable(): + meta = DecimalMeta(precision=4, scale=2) + assert hash(meta) + + +def test_decimal_default_scale(): + meta = DecimalMeta(precision=4) + assert meta.precision == 4 + assert meta.scale is None + + +def test_decimal_precision_must_be_positive(): + with pytest.raises(ValueError, match=re.escape("Precision must be at least 1. Given value: 0")): + DecimalMeta(precision=0) + + +def test_decimal_scale_must_be_positive(): + with pytest.raises(ValueError, match=re.escape("Scale must be positive. Given value: -1")): + DecimalMeta(precision=4, scale=-1) + + +def test_decimal_scale_must_not_exceed_precision(): + with pytest.raises(ValueError, match=re.escape("Scale must be no more than precision of 4. Given value: 5")): + DecimalMeta(precision=4, scale=5) def test_decimal_type():