From e04645ffa4243492ba2f234366ac4e13e3f1e702 Mon Sep 17 00:00:00 2001 From: faph Date: Fri, 17 Nov 2023 11:20:34 +0000 Subject: [PATCH] Support annotated enums --- src/py_avro_schema/_schemas.py | 4 +++- tests/test_primitives.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index ba33ca3..2cdcc81 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -674,6 +674,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti :param options: Schema generation options. """ super().__init__(py_type, namespace=namespace, options=options) + py_type = _type_from_annotated(py_type) self.name = py_type.__name__ def __str__(self): @@ -707,7 +708,7 @@ class EnumSchema(NamedSchema): @classmethod def handles_type(cls, py_type: Type) -> bool: """Whether this schema class can represent a given Python class""" - return inspect.isclass(py_type) and issubclass(py_type, enum.Enum) + return _is_class(py_type, enum.Enum) def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, options: Option = Option(0)): """ @@ -718,6 +719,7 @@ def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, op :param options: Schema generation options. """ super().__init__(py_type, namespace=namespace, options=options) + py_type = _type_from_annotated(py_type) self.symbols = [member.value for member in py_type] symbol_types = {type(symbol) for symbol in self.symbols} if symbol_types != {str}: diff --git a/tests/test_primitives.py b/tests/test_primitives.py index df9d1f8..f25afcb 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -366,6 +366,23 @@ class PyType(enum.Enum): assert_schema(PyType, expected) +def test_enum_annotated(): + class PyType(enum.Enum): + RED = "RED" + GREEN = "GREEN" + + expected = { + "type": "enum", + "name": "PyType", + "symbols": [ + "RED", + "GREEN", + ], + "default": "RED", + } + assert_schema(Annotated[PyType, ...], expected) + + def test_enum_str_subclass(): class PyType(str, enum.Enum): RED = "RED"