From 612df44e33a7d600cd20599094f5c845aea371c7 Mon Sep 17 00:00:00 2001
From: Arthur van Dam
Date: Mon, 31 Jan 2022 16:08:00 +0100
Subject: [PATCH] Add enhanced implementations of construct() function, to allow
 bypassing validation of nested BaseModel input files.

I'm not too happy yet with how much custom code is required to bypass
validation, but this is still in line with the recommendations typically
found online.
---
 hydrolib/core/basemodel.py | 86 +++++++++++++++++++++++++++++++++++++-
 tests/test_basemodel.py    | 22 ++++++++++
 2 files changed, 107 insertions(+), 1 deletion(-)

diff --git a/hydrolib/core/basemodel.py b/hydrolib/core/basemodel.py
index 4e413023c..67eff9c44 100644
--- a/hydrolib/core/basemodel.py
+++ b/hydrolib/core/basemodel.py
@@ -6,9 +6,11 @@
 """
 import logging
 from abc import ABC, abstractclassmethod
+from collections.abc import Mapping
 from contextlib import contextmanager
 from contextvars import ContextVar
 from enum import IntEnum
+from inspect import isclass
 from pathlib import Path
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
 from warnings import warn
@@ -19,7 +21,7 @@
 from pydantic.fields import PrivateAttr
 
 from hydrolib.core.io.base import DummmyParser, DummySerializer
-from hydrolib.core.utils import to_key
+from hydrolib.core.utils import to_key, to_list
 
 logger = logging.getLogger(__name__)
 
@@ -54,6 +56,73 @@ def __init__(self, **data: Any) -> None:
                 # If there is an identifier, include this in the ValidationError messages.
                 raise ValidationError([ErrorWrapper(e, loc=identifier)], self.__class__)
 
+    @classmethod
+    def construct(cls, _fields_set=None, **values):
+        """
+        Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data.
+        Default values are respected, but no other validation is performed.
+        Behaves as if `Config.extra = 'allow'` was set since it adds all passed values.
+        Nested model fields are supported, and recursively populated using this same construct().
+        """
+
+        # Code adapted from: https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836
+        m = cls.__new__(cls)
+
+        # Start by allowing all input values
+        fields_values = values.copy()
+
+        config = cls.__config__
+
+        # For all known class fields, replace values by a nested BaseModel when necessary
+        for name, field in cls.__fields__.items():
+            key = field.alias
+            if (
+                key not in values and config.allow_population_by_field_name
+            ):  # Added this to allow population by field name
+                key = name
+
+            if key in values:
+                if (
+                    values[key] is None and not field.required
+                ):  # Moved this check since None value can be passed for Optional nested field
+                    fields_values[name] = field.get_default()
+                else:
+                    if (isclass(field.type_)) and issubclass(field.type_, BaseModel):
+                        # If field is a list of BaseModels (and given value may or may not be a list)
+                        if field.shape == 2:
+                            fields_values[name] = [
+                                field.type_.construct(**e)
+                                if isinstance(e, Mapping)
+                                else field.type_.construct(e)
+                                for e in to_list(values[key])
+                            ]
+                        else:
+                            if isinstance(values[key], Mapping):
+                                fields_values[name] = field.outer_type_.construct(
+                                    **values[key]
+                                )
+                            else:
+                                # Input value need not be a dict, some BaseModel subclass
+                                # may support this as a separate input argument.
+                                fields_values[name] = field.outer_type_.construct(
+                                    values[key]
+                                )
+                    else:
+                        # No BaseModel, simply set value
+                        fields_values[name] = values[key]
+                if key != name:
+                    # Remove earlier set-by-key value, only leave set-by-aliasname
+                    del fields_values[key]
+            elif not field.required:
+                fields_values[name] = field.get_default()
+
+        object.__setattr__(m, "__dict__", fields_values)
+        if _fields_set is None:
+            _fields_set = set(values)
+        object.__setattr__(m, "__fields_set__", _fields_set)
+        m._init_private_attributes()
+        return m
+
     def is_file_link(self) -> bool:
         """Generic attribute for models backed by a file."""
         return False
@@ -555,6 +624,21 @@ def _post_init_load(self) -> None:
         """
         pass
 
+    @classmethod
+    def construct(cls, filepath: Optional[Path] = None, _fields_set=None, **values):
+        """
+        Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data.
+        Default values are respected, but no other validation is performed.
+        Behaves as if `Config.extra = 'allow'` was set since it adds all passed values.
+
+        This implementation takes specific care of a FileModel's filepath field.
+        """
+
+        # Consistent with __init__(), explicitly assign filepath input argument to the 'filepath' field.
+        values["filepath"] = filepath
+        m = super(FileModel, cls).construct(_fields_set=_fields_set, **values)
+        return m
+
     @property
     def _resolved_filepath(self) -> Optional[Path]:
         if self.filepath is None:
diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py
index aa6151e5a..7c5d22935 100644
--- a/tests/test_basemodel.py
+++ b/tests/test_basemodel.py
@@ -284,6 +284,28 @@ def test_synchronize_filepaths_updates_save_location_correctly(self):
         assert forcing.save_location == self._resolve(forcing.filepath, other_dir)  # type: ignore
         assert not forcing.save_location.is_file()  # type: ignore
 
+    def test_construc_model_without_validation(self):
+        # First *with* validation:
+        model_validated = FMModel(self._reference_model_path)
+
+        output_dir = (
+            test_output_dir / self.test_construc_model_without_validation.__name__
+        )
+
+        model_validated.save(filepath=output_dir / "validated.mdu", recurse=False)
+
+        # Second, manually parse input and construct model, to bypass validation:
+        data = FMModel()._parse(self._reference_model_path)
+        model_unvalidated = FMModel.construct(**data)
+        model_unvalidated.save(filepath=output_dir / "unvalidated.mdu", recurse=False)
+
+        files_in_output = sorted(output_dir.glob("**/*"))
+        assert len(files_in_output) == 2
+        assert files_in_output[0] == output_dir / "unvalidated.mdu"
+        assert files_in_output[1] == output_dir / "validated.mdu"
+
+        # ^^ No true test here, but intended to inspect produced MDU files by developer.
+
 
 class TestContextManagerFileLoadContext:
     def test_context_is_created_and_disposed_properly(self):