Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enhanced implementations of construct() function, to allow bypass validation… #209

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion hydrolib/core/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
"""
import logging
from abc import ABC, abstractclassmethod
from collections.abc import Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from enum import IntEnum
from inspect import isclass
from pathlib import Path
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from warnings import warn
Expand All @@ -19,7 +21,7 @@
from pydantic.fields import PrivateAttr

from hydrolib.core.io.base import DummmyParser, DummySerializer
from hydrolib.core.utils import to_key
from hydrolib.core.utils import to_key, to_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +56,73 @@ def __init__(self, **data: Any) -> None:
# If there is an identifier, include this in the ValidationError messages.
raise ValidationError([ErrorWrapper(e, loc=identifier)], self.__class__)

@classmethod
def construct(cls, _fields_set=None, **values):
"""
Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
Behaves as if `Config.extra = 'allow'` was set since it adds all passed values
Nested model fields are supported, and recursively populated using this same construct().
"""

# Code adapted from: https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836
m = cls.__new__(cls)

# Start by allowing all input values
fields_values = values.copy()

config = cls.__config__

# For all known class fields, replace values by a nested BaseModel when necessary
for name, field in cls.__fields__.items():
key = field.alias
if (
key not in values and config.allow_population_by_field_name
): # Added this to allow population by field name
key = name

if key in values:
if (
values[key] is None and not field.required
): # Moved this check since None value can be passed for Optional nested field
fields_values[name] = field.get_default()
else:
if (isclass(field.type_)) and issubclass(field.type_, BaseModel):
# If field is a list of BaseModels (and given value may or may not be a list)
if field.shape == 2:
fields_values[name] = [
field.type_.construct(**e)
if isinstance(e, Mapping)
else field.type_.construct(e)
for e in to_list(values[key])
]
else:
if isinstance(values[key], Mapping):
fields_values[name] = field.outer_type_.construct(
**values[key]
)
else:
# Input value need not be a dict, some BaseModel subclass
# may support this as a separate input argument.
fields_values[name] = field.outer_type_.construct(
values[key]
)
else:
# No BaseModel, simply set value
fields_values[name] = values[key]
if key != name:
# Remove earlier set-by-key value, only leave set-by-aliasname
del fields_values[key]
elif not field.required:
fields_values[name] = field.get_default()

object.__setattr__(m, "__dict__", fields_values)
if _fields_set is None:
_fields_set = set(values.keys())
object.__setattr__(m, "__fields_set__", _fields_set)
m._init_private_attributes()
return m

def is_file_link(self) -> bool:
"""Generic attribute for models backed by a file."""
return False
Expand Down Expand Up @@ -555,6 +624,21 @@ def _post_init_load(self) -> None:
"""
pass

@classmethod
def construct(cls, filepath: Optional[Path] = None, _fields_set=None, **values):
"""
Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data.
Default values are respected, but no other validation is performed.
Behaves as if `Config.extra = 'allow'` was set since it adds all passed values

This implementation takes specific care of a FileModel's filepath field.
"""

# Consistent with __init__(), explicitly assign filepath input argument to the 'filepath' field.
values.update({"filepath": filepath})
m = super(FileModel, cls).construct(**values)
return m

@property
def _resolved_filepath(self) -> Optional[Path]:
if self.filepath is None:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,28 @@ def test_synchronize_filepaths_updates_save_location_correctly(self):
assert forcing.save_location == self._resolve(forcing.filepath, other_dir) # type: ignore
assert not forcing.save_location.is_file() # type: ignore

def test_construc_model_without_validation(self):
# First *with* validation:
model_validated = FMModel(self._reference_model_path)

output_dir = (
test_output_dir / self.test_construc_model_without_validation.__name__
)

model_validated.save(filepath=output_dir / "validated.mdu", recurse=False)

# Second, manually parse input and construct model, to bypass validation:
data = FMModel()._parse(self._reference_model_path)
model_unvalidated = FMModel.construct(**data)
model_unvalidated.save(filepath=output_dir / "unvalidated.mdu", recurse=False)

files_in_output = list(output_dir.glob("**/*"))
assert len(files_in_output) == 2
assert files_in_output[0] == output_dir / "unvalidated.mdu"
assert files_in_output[1] == output_dir / "validated.mdu"

# ^^ No true test here, but intended to inspect produced MDU files by developer.


class TestContextManagerFileLoadContext:
def test_context_is_created_and_disposed_properly(self):
Expand Down