diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index f1f7ce7..e9c5ee8 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -4,14 +4,15 @@ from abc import ABCMeta, abstractmethod from builtins import classmethod from pathlib import Path -from typing import Optional, Type, TypeVar +from typing import Optional, Self, Type, TypeVar -from pydantic import BaseModel, Field, computed_field, field_validator +from pydantic import BaseModel, Field, computed_field, model_validator # ID is a 10 digit hex string ID_FIELD = Field(default_factory=lambda: uuid.uuid4().hex[:10].upper()) ID_TYPE = str T = TypeVar("T", bound="KilnBaseModel") +PT = TypeVar("PT", bound="KilnParentedModel") def snake_case(s: str) -> str: @@ -49,6 +50,8 @@ def load_from_file(cls: Type[T], path: Path) -> T: # Once for model_type, once for model. Can't call model_validate with parsed json because enum types break; they get strings instead of enums. parsed_json = json.loads(file_data) m = cls.model_validate_json(file_data, strict=True) + if not isinstance(m, cls): + raise ValueError(f"Loaded model is not of type {cls.__name__}") file_data = None m.path = path if m.v > m.max_schema_version(): @@ -91,7 +94,38 @@ def max_schema_version(self) -> int: class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta): id: ID_TYPE = ID_FIELD - parent: Optional[KilnBaseModel] = Field(default=None, exclude=True) + _parent: KilnBaseModel | None = None + + def __init__(self, **data): + super().__init__(**data) + if "parent" in data: + self.parent = data["parent"] + + @property + def parent(self) -> Optional[KilnBaseModel]: + if self._parent is not None: + return self._parent + # lazy load parent from path + if self.path is None: + return None + # TODO: this only works with base_filename. If we ever support custom names, we need to change this. 
+ parent_path = ( + self.path.parent.parent.parent / self.parent_type().base_filename() + ) + if parent_path is None: + return None + self._parent = self.parent_type().load_from_file(parent_path) + return self._parent + + @parent.setter + def parent(self, value: Optional[KilnBaseModel]): + if value is not None: + expected_parent_type = self.parent_type() + if not isinstance(value, expected_parent_type): + raise ValueError( + f"Parent must be of type {expected_parent_type}, but was {type(value)}" + ) + self._parent = value @classmethod @abstractmethod @@ -103,16 +137,15 @@ def relationship_name(cls) -> str: def parent_type(cls) -> Type[KilnBaseModel]: pass - @field_validator("parent") - @classmethod - def check_parent_type(cls, v: Optional[KilnBaseModel]) -> Optional[KilnBaseModel]: - if v is not None: - expected_parent_type = cls.parent_type() - if not isinstance(v, expected_parent_type): + @model_validator(mode="after") + def check_parent_type(self) -> Self: + if self._parent is not None: + expected_parent_type = self.__class__.parent_type() + if not isinstance(self._parent, expected_parent_type): raise ValueError( - f"Parent must be of type {expected_parent_type}, but was {type(v)}" + f"Parent must be of type {expected_parent_type}, but was {type(self._parent)}" ) - return v + return self def build_child_dirname(self) -> Path: # Default implementation for readable folder names. 
@@ -146,7 +179,9 @@ def build_path(self) -> Path | None: ) @classmethod - def all_children_of_parent_path(cls: Type[T], parent_path: Path | None) -> list[T]: + def all_children_of_parent_path( + cls: Type[PT], parent_path: Path | None + ) -> list[PT]: if parent_path is None: raise ValueError("Parent path must be set to load children") # Determine the parent folder @@ -155,6 +190,10 @@ def all_children_of_parent_path(cls: Type[T], parent_path: Path | None) -> list[ else: parent_folder = parent_path + parent = cls.parent_type().load_from_file(parent_path) + if parent is None: + raise ValueError("Parent must be set to load children") + # Ignore type error: this is abstract base class, but children must implement relationship_name relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore diff --git a/libs/core/kiln_ai/datamodel/models.py b/libs/core/kiln_ai/datamodel/models.py index 5be0c8d..ca3d1de 100644 --- a/libs/core/kiln_ai/datamodel/models.py +++ b/libs/core/kiln_ai/datamodel/models.py @@ -1,8 +1,8 @@ from enum import Enum, IntEnum -from typing import Dict +from typing import Dict, Self from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str -from pydantic import Field +from pydantic import Field, model_validator from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel @@ -88,7 +88,41 @@ def parent_type(cls): return Example # TODO validators for output and fixed_output: validate they follow the tas - # TODO validator that requirement_rating keys are requirement IDs + + def __init__(self, **data): + super().__init__(**data) + self.validate_requirement_rating_keys_manual() + + @model_validator(mode="after") + def validate_requirement_rating_keys(self) -> Self: + return self.validate_requirement_rating_keys_manual() + + def validate_requirement_rating_keys_manual(self) -> Self: + if len(self.requirement_ratings) == 0: + return self + example = self.parent + if example is None: + # don't validate this 
relationship until we have a path or parent. Give them time to build it (but will catch it before saving) + return self + if not isinstance(example, Example): + raise ValueError("ExampleOutput must have a valid parent Example") + + task = example.parent + if task is None: + # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) + return self + if not isinstance(task, Task): + raise ValueError( + "ExampleOutput's parent Example must have a valid parent Task" + ) + + valid_requirement_ids = {req.id for req in task.requirements()} + for key in self.requirement_ratings.keys(): + if key not in valid_requirement_ids: + raise ValueError( + f"Requirement ID '{key}' is not a valid requirement ID for this task" + ) + return self class ExampleSource(str, Enum): diff --git a/libs/core/kiln_ai/datamodel/test_basemodel.py b/libs/core/kiln_ai/datamodel/test_basemodel.py index b862737..f0889d3 100644 --- a/libs/core/kiln_ai/datamodel/test_basemodel.py +++ b/libs/core/kiln_ai/datamodel/test_basemodel.py @@ -94,7 +94,7 @@ def test_parented_model_path_gen(tmp_path): class BaseParentExample(KilnBaseModel): - pass + name: Optional[str] = None # Instance of the parented model for abstract methods, with default name builder @@ -234,3 +234,31 @@ def test_load_from_folder(test_base_parented_file): loaded_child1 = DefaultParentedModel.load_from_folder(child1.path.parent) assert loaded_child1.name == "Child1" + + +def test_lazy_load_parent(tmp_path): + # Create a parent + parent = BaseParentExample( + name="Parent", path=(tmp_path / BaseParentExample.base_filename()) + ) + parent.save_to_file() + print("PARENT", parent.path) + + # Create a child + child = DefaultParentedModel(parent=parent, name="Child") + child.save_to_file() + + # Load the child by path + loaded_child = DefaultParentedModel.load_from_file(child.path) + + # Access the parent to trigger lazy loading + loaded_parent = loaded_child.parent + + # Verify 
that the parent is now loaded and correct + assert loaded_parent is not None + assert loaded_parent.name == "Parent" + assert loaded_parent.path == parent.path + + # Verify that the _parent attribute is now set + assert hasattr(loaded_child, "_parent") + assert loaded_child._parent is loaded_parent diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py index 46be4cc..e9a75ec 100644 --- a/libs/core/kiln_ai/datamodel/test_example_models.py +++ b/libs/core/kiln_ai/datamodel/test_example_models.py @@ -68,15 +68,12 @@ def test_example_output_model_validation(): output="Test output", source=ExampleOutputSource.human, source_properties={"creator": "Jane Doe"}, - requirement_ratings={ - "req1": ReasonRating(rating=4, reason="Good performance"), - "req2": ReasonRating(rating=3, reason="Meets expectations"), - }, + requirement_ratings={}, ) assert valid_output.output == "Test output" assert valid_output.source == ExampleOutputSource.human assert valid_output.source_properties == {"creator": "Jane Doe"} - assert len(valid_output.requirement_ratings) == 2 + assert len(valid_output.requirement_ratings) == 0 # Invalid source with pytest.raises(ValidationError): @@ -133,9 +130,11 @@ def test_example_output_relationship(): def test_structured_output_workflow(tmp_path): - tmp_project_dir = tmp_path / "test_structured_output_examples" + tmp_project_file = ( + tmp_path / "test_structured_output_examples" / Project.base_filename() + ) # Create project - project = Project(name="Test Project", path=str(tmp_path / tmp_project_dir)) + project = Project(name="Test Project", path=str(tmp_project_file)) project.save_to_file() # Create task with requirements @@ -198,7 +197,7 @@ def test_structured_output_workflow(tmp_path): outputs[0].save_to_file() # Load from disk and validate - loaded_project = Project.load_from_file(tmp_project_dir) + loaded_project = Project.load_from_file(tmp_project_file) loaded_task = 
loaded_project.tasks()[0] assert loaded_task.name == "Structured Output Task" @@ -207,8 +206,9 @@ def test_structured_output_workflow(tmp_path): loaded_examples = loaded_task.examples() for example in loaded_examples: - assert len(example.outputs()) == 1 - output = example.outputs()[0] + outputs = example.outputs() + assert len(outputs) == 1 + output = outputs[0] assert output.rating is not None assert len(output.requirement_ratings) == 2 @@ -226,3 +226,46 @@ def test_structured_output_workflow(tmp_path): example_with_fixed_output.outputs()[0].fixed_output == '{"name": "John Doe", "age": 31}' ) + + +def test_example_output_requirement_rating_keys(tmp_path): + # Create a project, task, and example hierarchy + project = Project(name="Test Project", path=(tmp_path / "test_project")) + project.save_to_file() + task = Task(name="Test Task", parent=project) + task.save_to_file() + example = Example(input="Test input", source="human", parent=task) + example.save_to_file() + + # Create task requirements + req1 = TaskRequirement(name="Requirement 1", parent=task) + req1.save_to_file() + req2 = TaskRequirement(name="Requirement 2", parent=task) + req2.save_to_file() + # Valid case: all requirement IDs are valid + valid_output = ExampleOutput( + output="Test output", + source="human", + parent=example, + requirement_ratings={ + req1.id: {"rating": 5, "reason": "Excellent"}, + req2.id: {"rating": 4, "reason": "Good"}, + }, + ) + valid_output.save_to_file() + assert valid_output.requirement_ratings is not None + + # Invalid case: unknown requirement ID + with pytest.raises( + ValueError, + match="Requirement ID .* is not a valid requirement ID for this task", + ): + output = ExampleOutput( + output="Test output", + source="human", + parent=example, + requirement_ratings={ + "unknown_id": {"rating": 4, "reason": "Good"}, + }, + ) + output.save_to_file()