diff --git a/hydrolib/core/io/rtc/basemodel.py b/hydrolib/core/io/rtc/basemodel.py index bcf49f8e3..e085b04b7 100644 --- a/hydrolib/core/io/rtc/basemodel.py +++ b/hydrolib/core/io/rtc/basemodel.py @@ -1,11 +1,15 @@ from pydantic.class_validators import root_validator from hydrolib.core.basemodel import BaseModel -from typing import Any, get_origin +from typing import get_origin from hydrolib.core.utils import to_list +from pydantic.fields import ModelField class RtcBaseModel(BaseModel): @root_validator(pre=True) def validate(cls, data) -> dict: + if not isinstance(data, dict): + return data + data = cls._get_root_data(data) data = cls._get_list_data(data) @@ -36,11 +40,37 @@ def _get_list_data(cls, data: dict) -> dict: for key, value in data.items(): field = cls.__fields__.get(key) - field_type = field.outer_type_ # returns e.g. List[str] or Optional[List[str]] - non_generic_field_type = get_origin(field_type) or field_type.__origin__ # returns e.g. list + if not field: + continue - if non_generic_field_type is list: + field_type = cls._get_field_type(field) + + if field_type is list: data[key] = to_list(value) return data + @classmethod + def _get_field_type(cls, field: ModelField): + """Gets the non-generic field type of a model field. + For example, if the model field type is `List[str]`, + the returned type will be `list`. + + Args: + field (ModelField): The model field. + + Returns: + The non-generic type of the model field for example `list`. + """ + # `outer_type_` returns e.g. List[str] + field_type = field.outer_type_ + + # `origin` returns e.g. list + origin = get_origin(field_type) + if origin: + return origin + + if hasattr(field_type, "__origin__"): + return field_type.__origin__ + + return None