From 77cd08f2b55964f7d5a0074809f3cbb3f6258aee Mon Sep 17 00:00:00 2001 From: scosman Date: Mon, 7 Oct 2024 20:19:05 -0400 Subject: [PATCH] Add the ability to do nested validation and saving for our REST APIs --- libs/core/kiln_ai/datamodel/__init__.py | 26 +-- libs/core/kiln_ai/datamodel/basemodel.py | 83 +++++++++- libs/core/kiln_ai/datamodel/test_models.py | 3 +- .../kiln_ai/datamodel/test_nested_save.py | 151 ++++++++++++++++++ 4 files changed, 251 insertions(+), 12 deletions(-) create mode 100644 libs/core/kiln_ai/datamodel/test_nested_save.py diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index 6db1c87..a5d68e5 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -2,14 +2,14 @@ import json from enum import Enum, IntEnum -from typing import TYPE_CHECKING, Dict, Self +from typing import TYPE_CHECKING, Dict, Self, Type import jsonschema import jsonschema.exceptions from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str from pydantic import Field, model_validator -from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel +from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel, KilnParentModel from .json_schema import validate_schema if TYPE_CHECKING: @@ -191,7 +191,7 @@ class ExampleSource(str, Enum): synthetic = "synthetic" -class Example(KilnParentedModel): +class Example(KilnParentedModel, KilnParentModel, parent_of={"outputs": ExampleOutput}): """ An example input to a specific Task. """ @@ -216,8 +216,9 @@ def relationship_name(cls): def parent_type(cls): return Task + # Needed for typechecking. TODO P2: fix this in KilnParentModel def outputs(self) -> list[ExampleOutput]: - return ExampleOutput.all_children_of_parent_path(self.path) + return super().outputs() # type: ignore def parent_task(self) -> Task | None: if not isinstance(self.parent, Task): @@ -267,7 +268,11 @@ class TaskDeterminism(str, Enum): flexible = "flexible" # Flexible on semantic output. Eval should be custom based on parsing requirements. -class Task(KilnParentedModel): +class Task( + KilnParentedModel, + KilnParentModel, + parent_of={"requirements": TaskRequirement, "examples": Example}, +): name: str = NAME_FIELD description: str = Field(default="") priority: Priority = Field(default=Priority.p2) @@ -295,16 +300,19 @@ def relationship_name(cls): def parent_type(cls): return Project + # Needed for typechecking. TODO P2: fix this in KilnParentModel def requirements(self) -> list[TaskRequirement]: - return TaskRequirement.all_children_of_parent_path(self.path) + return super().requirements() # type: ignore + # Needed for typechecking. TODO P2: fix this in KilnParentModel def examples(self) -> list[Example]: - return Example.all_children_of_parent_path(self.path) + return super().examples() # type: ignore -class Project(KilnBaseModel): +class Project(KilnParentModel, parent_of={"tasks": Task}): name: str = NAME_FIELD description: str = Field(default="") + # Needed for typechecking. TODO P2: fix this in KilnParentModel def tasks(self) -> list[Task]: - return Task.all_children_of_parent_path(self.path) + return super().tasks() # type: ignore diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index 588885b..8826ec2 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -5,7 +5,16 @@ from builtins import classmethod from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Optional, Self, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Self, + Type, + TypeVar, +) from kiln_ai.utils.config import Config from pydantic import ( @@ -223,3 +232,75 @@ def all_children_of_parent_path( children.append(child) return children + + +# Parent create methods for all child relationships +# You must pass in parent_of in the subclass definition, defining the child relationships +class KilnParentModel(KilnBaseModel, metaclass=ABCMeta): + @classmethod + def _create_child_method( + cls, relationship_name: str, child_class: Type[KilnParentedModel] + ): + def child_method(self) -> list[child_class]: + return child_class.all_children_of_parent_path(self.path) + + child_method.__name__ = relationship_name + child_method.__annotations__ = {"return": List[child_class]} + setattr(cls, relationship_name, child_method) + + @classmethod + def __init_subclass__(cls, parent_of: Dict[str, Type[KilnParentedModel]], **kwargs): + super().__init_subclass__(**kwargs) + cls._parent_of = parent_of + for relationship_name, child_class in parent_of.items(): + cls._create_child_method(relationship_name, child_class) + + @classmethod + def validate_and_save_with_subrelations( + cls, data: Dict[str, Any], path: Path | None = None + ): + # Validate first, then save. Don't want error half way through, and partly persisted + # TODO P2: save to tmp dir, then move atomically. But need to merge directories so later. + cls._validate_nested(data, save=False, path=path) + instance = cls._validate_nested(data, save=True, path=path) + return instance + + @classmethod + def _validate_nested( + cls, + data: Dict[str, Any], + save: bool = False, + parent: KilnBaseModel | None = None, + path: Path | None = None, + ): + instance = cls.model_validate(data) + if path is not None: + instance.path = path + if parent is not None and isinstance(instance, KilnParentedModel): + instance.parent = parent + if save: + instance.save_to_file() + for key, value_list in data.items(): + if key in cls._parent_of: + parent_type = cls._parent_of[key] + if not isinstance(value_list, list): + raise ValueError( + f"Expected a list for {key}, but got {type(value_list)}" + ) + for value in value_list: + if issubclass(parent_type, KilnParentModel): + parent_type._validate_nested( + data=value, save=save, parent=instance + ) + elif issubclass(parent_type, KilnBaseModel): + # Root node + subinstance = parent_type.model_validate(value) + subinstance.parent = instance + if save: + subinstance.save_to_file() + else: + raise ValueError( + f"Invalid type {parent_type}. Should be KilnBaseModel based." + ) + + return instance diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py index 7bd828e..7e3bf60 100644 --- a/libs/core/kiln_ai/datamodel/test_models.py +++ b/libs/core/kiln_ai/datamodel/test_models.py @@ -123,8 +123,7 @@ def test_load_tasks(test_project_file): task3.save_to_file() # Load tasks from the project - # tasks = project.tasks() - tasks = Task.all_children_of_parent_path(test_project_file) + tasks = project.tasks() # Verify that all tasks are loaded correctly assert len(tasks) == 3 diff --git a/libs/core/kiln_ai/datamodel/test_nested_save.py b/libs/core/kiln_ai/datamodel/test_nested_save.py new file mode 100644 index 0000000..3309a03 --- /dev/null +++ b/libs/core/kiln_ai/datamodel/test_nested_save.py @@ -0,0 +1,151 @@ +import pytest +from kiln_ai.datamodel.basemodel import KilnParentedModel, KilnParentModel +from pydantic import Field, ValidationError + + +class ModelC(KilnParentedModel): + code: str = Field(..., pattern=r"^[A-Z]{3}$") + + @classmethod + def relationship_name(cls) -> str: + return "cs" + + @classmethod + def parent_type(cls): + return ModelB + + +class ModelB(KilnParentedModel, KilnParentModel, parent_of={"cs": ModelC}): + value: int = Field(..., ge=0) + + @classmethod + def relationship_name(cls) -> str: + return "bs" + + @classmethod + def parent_type(cls): + return ModelA + + +# Define the hierarchy +class ModelA(KilnParentModel, parent_of={"bs": ModelB}): + name: str = Field(..., min_length=3) + + +def test_validation_error_in_c_level(): + data = { + "name": "Root", + "bs": [ + { + "value": 10, + "cs": [ + {"code": "ABC"}, + {"code": "DEF"}, + {"code": "invalid"}, # This should cause a validation error + ], + } + ], + } + + with pytest.raises(ValidationError) as exc_info: + ModelA.validate_and_save_with_subrelations(data) + + assert "String should match pattern" in str(exc_info.value) + + +def test_persist_three_level_hierarchy(tmp_path): + # Set up temporary paths + root_path = tmp_path / "model_a.kiln" + + data = { + "name": "Root", + "bs": [ + {"value": 10, "cs": [{"code": "ABC"}, {"code": "DEF"}]}, + {"value": 20, "cs": [{"code": "XYZ"}]}, + ], + } + + instance = ModelA.validate_and_save_with_subrelations(data, path=root_path) + + assert isinstance(instance, ModelA) + assert instance.name == "Root" + assert instance.path == root_path + assert len(instance.bs()) == 2 + + # Load the instance back from the file to double-check + instance = ModelA.load_from_file(root_path) + + bs = instance.bs() + assert len(bs) == 2 + + # Check for the existence of both expected B models + b_values = [b.value for b in bs] + assert 10 in b_values + assert 20 in b_values + + # Find the B models by their values + b10 = next(b for b in bs if b.value == 10) + b20 = next(b for b in bs if b.value == 20) + + assert len(b10.cs()) == 2 + assert len(b20.cs()) == 1 + + # Check C models for b10 + c_codes_b10 = [c.code for c in b10.cs()] + assert "ABC" in c_codes_b10 + assert "DEF" in c_codes_b10 + + # Check C model for b20 + c_codes_b20 = [c.code for c in b20.cs()] + assert "XYZ" in c_codes_b20 + + # Check that all objects have their parent set correctly + assert all(b.parent == instance for b in bs) + assert all(c.parent.id == b10.id for c in b10.cs()) + assert all(c.parent.id == b20.id for c in b20.cs()) + + +def test_persist_model_a_without_children(tmp_path): + # Set up temporary path + root_path = tmp_path / "model_a_no_children.kiln" + + data = {"name": "RootNoChildren"} + + instance = ModelA.validate_and_save_with_subrelations(data, path=root_path) + + assert isinstance(instance, ModelA) + assert instance.name == "RootNoChildren" + assert instance.path == root_path + assert len(instance.bs()) == 0 + + # Verify that the file was created + assert root_path.exists() + + # Load the instance back from the file to double-check + loaded_instance = ModelA.load_from_file(root_path) + assert loaded_instance.name == "RootNoChildren" + assert len(loaded_instance.bs()) == 0 + + +def test_validate_without_saving(tmp_path): + data = { + "name": "ValidateOnly", + "bs": [ + {"value": 30, "cs": [{"code": "GHI"}, {"code": "JKL"}]}, + {"value": 40, "cs": [{"code": "MNO"}]}, + ], + } + + # Validate the data without saving + ModelA._validate_nested(data, save=False) + + data = { + "name": "ValidateOnly", + "bs": [ + {"value": 30, "cs": [{"code": "GHI"}, {"code": "JKL"}]}, + {"value": 40, "cs": [{"code": 123}]}, + ], + } + + with pytest.raises(ValidationError): + ModelA._validate_nested(data, save=False)