diff --git a/hydrolib/core/basemodel.py b/hydrolib/core/basemodel.py index 4e413023c..67eff9c44 100644 --- a/hydrolib/core/basemodel.py +++ b/hydrolib/core/basemodel.py @@ -6,9 +6,11 @@ """ import logging from abc import ABC, abstractclassmethod +from collections.abc import Mapping from contextlib import contextmanager from contextvars import ContextVar from enum import IntEnum +from inspect import isclass from pathlib import Path from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar from warnings import warn @@ -19,7 +21,7 @@ from pydantic.fields import PrivateAttr from hydrolib.core.io.base import DummmyParser, DummySerializer -from hydrolib.core.utils import to_key +from hydrolib.core.utils import to_key, to_list logger = logging.getLogger(__name__) @@ -54,6 +56,73 @@ def __init__(self, **data: Any) -> None: # If there is an identifier, include this in the ValidationError messages. raise ValidationError([ErrorWrapper(e, loc=identifier)], self.__class__) + @classmethod + def construct(cls, _fields_set=None, **values): + """ + Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. + Default values are respected, but no other validation is performed. + Behaves as if `Config.extra = 'allow'` was set since it adds all passed values + Nested model fields are supported, and recursively populated using this same construct(). + """ + + # Code adapted from: https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836 + m = cls.__new__(cls) + + # Start by allowing all input values + fields_values = values.copy() + + config = cls.__config__ + + # For all known class fields, replace values by a nested BaseModel when necessary + for name, field in cls.__fields__.items(): + key = field.alias + if ( + key not in values and config.allow_population_by_field_name + ): # Added this to allow population by field name + key = name + + if key in values: + if ( + values[key] is None and not field.required + ): # Moved this check since None value can be passed for Optional nested field + fields_values[name] = field.get_default() + else: + if (isclass(field.type_)) and issubclass(field.type_, BaseModel): + # If field is a list of BaseModels (and given value may or may not be a list) + if field.shape == 2: + fields_values[name] = [ + field.type_.construct(**e) + if isinstance(e, Mapping) + else field.type_.construct(e) + for e in to_list(values[key]) + ] + else: + if isinstance(values[key], Mapping): + fields_values[name] = field.outer_type_.construct( + **values[key] + ) + else: + # Input value need not be a dict, some BaseModel subclass + # may support this as a separate input argument. + fields_values[name] = field.outer_type_.construct( + values[key] + ) + else: + # No BaseModel, simply set value + fields_values[name] = values[key] + if key != name: + # Remove earlier set-by-key value, only leave set-by-aliasname + del fields_values[key] + elif not field.required: + fields_values[name] = field.get_default() + + object.__setattr__(m, "__dict__", fields_values) + if _fields_set is None: + _fields_set = set(values.keys()) + object.__setattr__(m, "__fields_set__", _fields_set) + m._init_private_attributes() + return m + def is_file_link(self) -> bool: """Generic attribute for models backed by a file.""" return False @@ -555,6 +624,21 @@ def _post_init_load(self) -> None: """ pass + @classmethod + def construct(cls, filepath: Optional[Path] = None, _fields_set=None, **values): + """ + Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. + Default values are respected, but no other validation is performed. + Behaves as if `Config.extra = 'allow'` was set since it adds all passed values + + This implementation takes specific care of a FileModel's filepath field. + """ + + # Consistent with __init__(), explicitly assign filepath input argument to the 'filepath' field. + values.update({"filepath": filepath}) + m = super(FileModel, cls).construct(**values) + return m + @property def _resolved_filepath(self) -> Optional[Path]: if self.filepath is None: diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py index aa6151e5a..7c5d22935 100644 --- a/tests/test_basemodel.py +++ b/tests/test_basemodel.py @@ -284,6 +284,28 @@ def test_synchronize_filepaths_updates_save_location_correctly(self): assert forcing.save_location == self._resolve(forcing.filepath, other_dir) # type: ignore assert not forcing.save_location.is_file() # type: ignore + def test_construc_model_without_validation(self): + # First *with* validation: + model_validated = FMModel(self._reference_model_path) + + output_dir = ( + test_output_dir / self.test_construc_model_without_validation.__name__ + ) + + model_validated.save(filepath=output_dir / "validated.mdu", recurse=False) + + # Second, manually parse input and construct model, to bypass validation: + data = FMModel()._parse(self._reference_model_path) + model_unvalidated = FMModel.construct(**data) + model_unvalidated.save(filepath=output_dir / "unvalidated.mdu", recurse=False) + + files_in_output = list(output_dir.glob("**/*")) + assert len(files_in_output) == 2 + assert files_in_output[0] == output_dir / "unvalidated.mdu" + assert files_in_output[1] == output_dir / "validated.mdu" + + # ^^ No true test here, but intended to inspect produced MDU files by developer. + class TestContextManagerFileLoadContext: def test_context_is_created_and_disposed_properly(self):