Skip to content

Commit

Permalink
Validate decimal meta args
Browse files Browse the repository at this point in the history
  • Loading branch information
faph committed Nov 5, 2023
1 parent 7c9878f commit b314a2d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Tuple,
Type,
Union,
cast,
get_args,
get_origin,
get_type_hints,
Expand Down Expand Up @@ -469,12 +470,22 @@ def _size(cls, py_type: Type) -> Tuple[int, int]:
args = get_args(py_type)
if origin is Annotated and args and args[0] is decimal.Decimal:
# Annotated[decimal.Decimal, (4, 2)]
return args[1]
size_args = args[1]
elif origin is decimal.Decimal:
# Deprecated pas.DecimalType[4, 2]
return get_args(py_type)
# Anything else is not a supported decimal type
raise TypeError(f"{py_type} is not a decimal type")
size_args = args
else:
# Anything else is not a supported decimal type
raise TypeError(f"{py_type} is not a decimal type")
if cls._validate_size_tuple(size_args):
return cast(Tuple[int, int], size_args)
else:
raise TypeError(f"{py_type} is not annotated with a tuple of integers (precision, scale)")

@staticmethod
def _validate_size_tuple(tuple_: Tuple) -> bool:
"""Checks whether a given tuple is a tuple of 2 integers (precision, scale)"""
return len(tuple_) == 2 and all(isinstance(item, int) for item in tuple_)

def data(self, names: NamesType) -> JSONObj:
"""Return the schema data"""
Expand Down
49 changes: 49 additions & 0 deletions tests/test_logicals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import uuid
from typing import Annotated, Any, Dict, List

import pytest

import py_avro_schema as pas
from py_avro_schema._testing import assert_schema

Expand Down Expand Up @@ -99,6 +101,53 @@ def test_annotated_decimal():
assert_schema(py_type, expected)


def test_annotated_decimal_neg_scale():
py_type = Annotated[decimal.Decimal, (5, -2)]
expected = {
"type": "bytes",
"logicalType": "decimal",
"precision": 5,
"scale": -2,
}
assert_schema(py_type, expected)


def test_annotated_decimal_bad_no_tuple():
py_type = Annotated[decimal.Decimal, ...]
expected = {
"type": "bytes",
"logicalType": "decimal",
"precision": 5,
"scale": 2,
}
with pytest.raises(pas.TypeNotSupportedError):
assert_schema(py_type, expected)


def test_annotated_decimal_tuple_wrong_length():
py_type = Annotated[decimal.Decimal, (3, 2, 1)]
expected = {
"type": "bytes",
"logicalType": "decimal",
"precision": 5,
"scale": 2,
}
with pytest.raises(pas.TypeNotSupportedError):
assert_schema(py_type, expected)


def test_annotated_decimal_tuple_wrong_type():
py_type = Annotated[decimal.Decimal, ("a", 1)]
expected = {
"type": "bytes",
"logicalType": "decimal",
"precision": 5,
"scale": 2,
}
with pytest.raises(pas.TypeNotSupportedError):
assert_schema(py_type, expected)


def test_multiple_decimals():
# Test the magic with _GenericAlias!
py_type_1 = pas.DecimalType[5, 2]
Expand Down

0 comments on commit b314a2d

Please sign in to comment.