Skip to content

Commit

Permalink
Add validators for requirement ID mappings
Browse files Browse the repository at this point in the history
Fix an issue I found where you can't open a file and then traverse to the parent. Parent now auto-discovered from current path if not pre-set
  • Loading branch information
scosman committed Sep 11, 2024
1 parent 4918e1b commit 264d838
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 26 deletions.
63 changes: 51 additions & 12 deletions libs/core/kiln_ai/datamodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from abc import ABCMeta, abstractmethod
from builtins import classmethod
from pathlib import Path
from typing import Optional, Type, TypeVar
from typing import Optional, Self, Type, TypeVar

from pydantic import BaseModel, Field, computed_field, field_validator
from pydantic import BaseModel, Field, computed_field, model_validator

# ID is a 10 digit hex string
# The default factory takes the first 10 characters of a fresh uuid4's hex form and uppercases them.
ID_FIELD = Field(default_factory=lambda: uuid.uuid4().hex[:10].upper())
# Annotation type used alongside ID_FIELD when declaring an ID attribute.
ID_TYPE = str
# Type variables bound (by forward reference) to the two model base classes; used to type
# classmethods such as load_from_file (T) and all_children_of_parent_path (PT).
T = TypeVar("T", bound="KilnBaseModel")
PT = TypeVar("PT", bound="KilnParentedModel")


def snake_case(s: str) -> str:
Expand Down Expand Up @@ -49,6 +50,8 @@ def load_from_file(cls: Type[T], path: Path) -> T:
# Once for model_type, once for model. Can't call model_validate with parsed json because enum types break; they get strings instead of enums.
parsed_json = json.loads(file_data)
m = cls.model_validate_json(file_data, strict=True)
if not isinstance(m, cls):
raise ValueError(f"Loaded model is not of type {cls.__name__}")
file_data = None
m.path = path
if m.v > m.max_schema_version():
Expand Down Expand Up @@ -91,7 +94,38 @@ def max_schema_version(self) -> int:

class KilnParentedModel(KilnBaseModel, metaclass=ABCMeta):
id: ID_TYPE = ID_FIELD
parent: Optional[KilnBaseModel] = Field(default=None, exclude=True)
_parent: KilnBaseModel | None = None

def __init__(self, **data):
    """Initialise the model, then apply an explicitly supplied parent.

    All keyword arguments are forwarded to the base initialiser. If a
    ``parent`` keyword was passed (even with the value ``None``), it is
    afterwards assigned through ``self.parent``.
    """
    super().__init__(**data)
    try:
        explicit_parent = data["parent"]
    except KeyError:
        # Caller did not pass a parent; nothing to assign.
        return
    self.parent = explicit_parent

@property
def parent(self) -> Optional["KilnBaseModel"]:
    """Return this model's parent, loading it from disk on first access.

    Resolution order:
      1. If a parent is already stored in ``_parent``, return it.
      2. If ``path`` is ``None`` there is nothing to load from; return ``None``.
      3. Otherwise build the parent file path as
         ``path.parent.parent.parent / parent_type().base_filename()``,
         load it with ``parent_type().load_from_file`` and cache the result
         in ``_parent``.

    Returns:
        The parent model, or ``None`` when no parent is set and no path is known.

    Raises:
        Whatever ``load_from_file`` raises for the computed parent path.
    """
    if self._parent is not None:
        return self._parent
    # lazy load parent from path
    if self.path is None:
        return None
    # TODO: this only works with base_filename. If we ever support custom names, we need to change this.
    parent_cls = self.parent_type()
    # A path join never evaluates to None, so no None-check is needed on the result.
    parent_path = self.path.parent.parent.parent / parent_cls.base_filename()
    self._parent = parent_cls.load_from_file(parent_path)
    return self._parent


@parent.setter
def parent(self, value: Optional["KilnBaseModel"]):
    """Set the parent, or clear it with ``None``.

    Args:
        value: The new parent. When not ``None`` it must be an instance of
            the class returned by ``parent_type()``.

    Raises:
        ValueError: if ``value`` is not ``None`` and not of the expected type.
    """
    if value is not None:
        expected_parent_type = self.parent_type()
        if not isinstance(value, expected_parent_type):
            raise ValueError(
                f"Parent must be of type {expected_parent_type}, but was {type(value)}"
            )
    self._parent = value

@classmethod
@abstractmethod
Expand All @@ -103,16 +137,15 @@ def relationship_name(cls) -> str:
def parent_type(cls) -> Type[KilnBaseModel]:
pass

@field_validator("parent")
@classmethod
def check_parent_type(cls, v: Optional[KilnBaseModel]) -> Optional[KilnBaseModel]:
if v is not None:
expected_parent_type = cls.parent_type()
if not isinstance(v, expected_parent_type):
@model_validator(mode="after")
def check_parent_type(self) -> Self:
if self._parent is not None:
expected_parent_type = self.__class__.parent_type()
if not isinstance(self._parent, expected_parent_type):
raise ValueError(
f"Parent must be of type {expected_parent_type}, but was {type(v)}"
f"Parent must be of type {expected_parent_type}, but was {type(self._parent)}"
)
return v
return self

def build_child_dirname(self) -> Path:
# Default implementation for readable folder names.
Expand Down Expand Up @@ -146,7 +179,9 @@ def build_path(self) -> Path | None:
)

@classmethod
def all_children_of_parent_path(cls: Type[T], parent_path: Path | None) -> list[T]:
def all_children_of_parent_path(
cls: Type[PT], parent_path: Path | None
) -> list[PT]:
if parent_path is None:
raise ValueError("Parent path must be set to load children")
# Determine the parent folder
Expand All @@ -155,6 +190,10 @@ def all_children_of_parent_path(cls: Type[T], parent_path: Path | None) -> list[
else:
parent_folder = parent_path

parent = cls.parent_type().load_from_file(parent_path)
if parent is None:
raise ValueError("Parent must be set to load children")

# Ignore type error: this is abstract base class, but children must implement relationship_name
relationship_folder = parent_folder / Path(cls.relationship_name()) # type: ignore

Expand Down
40 changes: 37 additions & 3 deletions libs/core/kiln_ai/datamodel/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum, IntEnum
from typing import Dict
from typing import Dict, Self

from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
from pydantic import Field
from pydantic import Field, model_validator

from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel

Expand Down Expand Up @@ -88,7 +88,41 @@ def parent_type(cls):
return Example

# TODO validators for output and fixed_output: validate they follow the task
# TODO validator that requirement_rating keys are requirement IDs

def __init__(self, **data):
    """Construct the model, then run the requirement-rating key check explicitly.

    All keyword arguments are forwarded to the base initialiser; once it has
    returned, ``validate_requirement_rating_keys_manual`` is called on the
    finished instance.
    """
    super().__init__(**data)
    # Explicit pass over the rating keys once construction has completed.
    self.validate_requirement_rating_keys_manual()

@model_validator(mode="after")
def validate_requirement_rating_keys(self) -> Self:
    """After-validation hook that delegates to the manual rating-key check.

    Returns:
        The value returned by ``validate_requirement_rating_keys_manual``.
    """
    checked = self.validate_requirement_rating_keys_manual()
    return checked

def validate_requirement_rating_keys_manual(self) -> Self:
    """Check that every key of ``requirement_ratings`` is a requirement ID of the owning task.

    The task is reached through ``self.parent`` (which must be an ``Example``)
    and then that example's ``parent`` (which must be a ``Task``). The check
    is skipped when there are no ratings or when either link is ``None``.

    Returns:
        ``self``.

    Raises:
        ValueError: if a parent in the chain has the wrong type, or if a
            rating key is not among the IDs from ``task.requirements()``.
    """
    ratings = self.requirement_ratings
    if len(ratings) == 0:
        return self

    owning_example = self.parent
    if owning_example is None:
        # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
        return self
    if not isinstance(owning_example, Example):
        raise ValueError("ExampleOutput must have a valid parent Example")

    owning_task = owning_example.parent
    if owning_task is None:
        # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
        return self
    if not isinstance(owning_task, Task):
        raise ValueError(
            "ExampleOutput's parent Example must have a valid parent Task"
        )

    known_ids = {requirement.id for requirement in owning_task.requirements()}
    for rating_key in ratings:
        if rating_key in known_ids:
            continue
        # The first unknown key aborts validation.
        raise ValueError(
            f"Requirement ID '{rating_key}' is not a valid requirement ID for this task"
        )
    return self


class ExampleSource(str, Enum):
Expand Down
30 changes: 29 additions & 1 deletion libs/core/kiln_ai/datamodel/test_basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_parented_model_path_gen(tmp_path):


class BaseParentExample(KilnBaseModel):
pass
name: Optional[str] = None


# Instance of the parented model for abstract methods, with default name builder
Expand Down Expand Up @@ -234,3 +234,31 @@ def test_load_from_folder(test_base_parented_file):

loaded_child1 = DefaultParentedModel.load_from_folder(child1.path.parent)
assert loaded_child1.name == "Child1"


def test_lazy_load_parent(tmp_path):
    """A child loaded from its file resolves its parent from disk on first access."""
    # Create and persist a parent
    parent = BaseParentExample(
        name="Parent", path=(tmp_path / BaseParentExample.base_filename())
    )
    parent.save_to_file()

    # Create and persist a child attached to that parent
    child = DefaultParentedModel(parent=parent, name="Child")
    child.save_to_file()

    # Load the child by path only; no parent object is handed to it
    loaded_child = DefaultParentedModel.load_from_file(child.path)

    # Access the parent to trigger lazy loading
    loaded_parent = loaded_child.parent

    # Verify that the parent is now loaded and correct
    assert loaded_parent is not None
    assert loaded_parent.name == "Parent"
    assert loaded_parent.path == parent.path

    # Verify that the _parent attribute now holds the loaded instance
    assert hasattr(loaded_child, "_parent")
    assert loaded_child._parent is loaded_parent
63 changes: 53 additions & 10 deletions libs/core/kiln_ai/datamodel/test_example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,12 @@ def test_example_output_model_validation():
output="Test output",
source=ExampleOutputSource.human,
source_properties={"creator": "Jane Doe"},
requirement_ratings={
"req1": ReasonRating(rating=4, reason="Good performance"),
"req2": ReasonRating(rating=3, reason="Meets expectations"),
},
requirement_ratings={},
)
assert valid_output.output == "Test output"
assert valid_output.source == ExampleOutputSource.human
assert valid_output.source_properties == {"creator": "Jane Doe"}
assert len(valid_output.requirement_ratings) == 2
assert len(valid_output.requirement_ratings) == 0

# Invalid source
with pytest.raises(ValidationError):
Expand Down Expand Up @@ -133,9 +130,11 @@ def test_example_output_relationship():


def test_structured_output_workflow(tmp_path):
tmp_project_dir = tmp_path / "test_structured_output_examples"
tmp_project_file = (
tmp_path / "test_structured_output_examples" / Project.base_filename()
)
# Create project
project = Project(name="Test Project", path=str(tmp_path / tmp_project_dir))
project = Project(name="Test Project", path=str(tmp_project_file))
project.save_to_file()

# Create task with requirements
Expand Down Expand Up @@ -198,7 +197,7 @@ def test_structured_output_workflow(tmp_path):
outputs[0].save_to_file()

# Load from disk and validate
loaded_project = Project.load_from_file(tmp_project_dir)
loaded_project = Project.load_from_file(tmp_project_file)
loaded_task = loaded_project.tasks()[0]

assert loaded_task.name == "Structured Output Task"
Expand All @@ -207,8 +206,9 @@ def test_structured_output_workflow(tmp_path):

loaded_examples = loaded_task.examples()
for example in loaded_examples:
assert len(example.outputs()) == 1
output = example.outputs()[0]
outputs = example.outputs()
assert len(outputs) == 1
output = outputs[0]
assert output.rating is not None
assert len(output.requirement_ratings) == 2

Expand All @@ -226,3 +226,46 @@ def test_structured_output_workflow(tmp_path):
example_with_fixed_output.outputs()[0].fixed_output
== '{"name": "John Doe", "age": 31}'
)


def test_example_output_requirement_rating_keys(tmp_path):
    """Rating keys on an ExampleOutput must be IDs of the owning task's requirements."""
    # Build and persist the Project -> Task -> Example chain
    project = Project(name="Test Project", path=(tmp_path / "test_project"))
    project.save_to_file()
    task = Task(name="Test Task", parent=project)
    task.save_to_file()
    example = Example(input="Test input", source="human", parent=task)
    example.save_to_file()

    # Persist two requirements under the task, one after the other
    saved_requirements = []
    for requirement_name in ("Requirement 1", "Requirement 2"):
        requirement = TaskRequirement(name=requirement_name, parent=task)
        requirement.save_to_file()
        saved_requirements.append(requirement)
    first_req, second_req = saved_requirements

    # Valid case: every rating key is the ID of a saved requirement
    known_ratings = {
        first_req.id: {"rating": 5, "reason": "Excellent"},
        second_req.id: {"rating": 4, "reason": "Good"},
    }
    valid_output = ExampleOutput(
        output="Test output",
        source="human",
        parent=example,
        requirement_ratings=known_ratings,
    )
    valid_output.save_to_file()
    assert valid_output.requirement_ratings is not None

    # Invalid case: a rating key that matches no requirement
    unknown_ratings = {"unknown_id": {"rating": 4, "reason": "Good"}}
    expected_message = "Requirement ID .* is not a valid requirement ID for this task"
    with pytest.raises(ValueError, match=expected_message):
        output = ExampleOutput(
            output="Test output",
            source="human",
            parent=example,
            requirement_ratings=unknown_ratings,
        )
        output.save_to_file()

0 comments on commit 264d838

Please sign in to comment.