Skip to content

Commit

Permalink
Support annotated enums
Browse files Browse the repository at this point in the history
  • Loading branch information
faph committed Nov 17, 2023
1 parent 441eb97 commit e04645f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/py_avro_schema/_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti
:param options: Schema generation options.
"""
super().__init__(py_type, namespace=namespace, options=options)
py_type = _type_from_annotated(py_type)
self.name = py_type.__name__

def __str__(self):
Expand Down Expand Up @@ -707,7 +708,7 @@ class EnumSchema(NamedSchema):
@classmethod
def handles_type(cls, py_type: Type) -> bool:
"""Whether this schema class can represent a given Python class"""
return inspect.isclass(py_type) and issubclass(py_type, enum.Enum)
return _is_class(py_type, enum.Enum)

def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, options: Option = Option(0)):
"""
Expand All @@ -718,6 +719,7 @@ def __init__(self, py_type: Type[enum.Enum], namespace: Optional[str] = None, op
:param options: Schema generation options.
"""
super().__init__(py_type, namespace=namespace, options=options)
py_type = _type_from_annotated(py_type)
self.symbols = [member.value for member in py_type]
symbol_types = {type(symbol) for symbol in self.symbols}
if symbol_types != {str}:
Expand Down
17 changes: 17 additions & 0 deletions tests/test_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,23 @@ class PyType(enum.Enum):
assert_schema(PyType, expected)


def test_enum_annotated():
class PyType(enum.Enum):
RED = "RED"
GREEN = "GREEN"

expected = {
"type": "enum",
"name": "PyType",
"symbols": [
"RED",
"GREEN",
],
"default": "RED",
}
assert_schema(Annotated[PyType, ...], expected)


def test_enum_str_subclass():
class PyType(str, enum.Enum):
RED = "RED"
Expand Down

0 comments on commit e04645f

Please sign in to comment.