Skip to content

Commit

Permalink
#226: Refactor RtcBaseModel a little bit.
Browse files Browse the repository at this point in the history
  • Loading branch information
priscavdsluis committed Sep 27, 2022
1 parent a760d11 commit 00852e0
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions hydrolib/core/io/rtc/basemodel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from pydantic.class_validators import root_validator
from hydrolib.core.basemodel import BaseModel
from typing import Any, get_origin
from typing import get_origin
from hydrolib.core.utils import to_list
from pydantic.fields import ModelField

class RtcBaseModel(BaseModel):
@root_validator(pre=True)
def validate(cls, data) -> dict:
if not isinstance(data, dict):
return data

data = cls._get_root_data(data)
data = cls._get_list_data(data)

Expand Down Expand Up @@ -36,11 +40,37 @@ def _get_list_data(cls, data: dict) -> dict:
for key, value in data.items():

field = cls.__fields__.get(key)
field_type = field.outer_type_ # returns e.g. List[str] or Optional[List[str]]
non_generic_field_type = get_origin(field_type) or field_type.__origin__ # returns e.g. list
if not field:
continue

if non_generic_field_type is list:
field_type = cls._get_field_type(field)

if field_type is list:
data[key] = to_list(value)

return data

@classmethod
def _get_field_type(cls, field: ModelField):
"""Gets the non-generic field type of a model field.
For example, if the model field type is `List[str]`,
the returned type will be `list`.
Args:
field (ModelField): The model field.
Returns:
The non-generic type of the model field for example `list`.
"""
# `outer_type_` returns e.g. List[str]
field_type = field.outer_type_

# `origin` returns e.g. list
origin = get_origin(field_type)
if origin:
return origin

if hasattr(field_type, "__origin__"):
return field_type.__origin__

return None

0 comments on commit 00852e0

Please sign in to comment.