Skip to content

Commit

Permalink
Add the ability to do nested validation and saving for our REST APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Oct 8, 2024
1 parent 9041a06 commit 77cd08f
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 12 deletions.
26 changes: 17 additions & 9 deletions libs/core/kiln_ai/datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import json
from enum import Enum, IntEnum
from typing import TYPE_CHECKING, Dict, Self
from typing import TYPE_CHECKING, Dict, Self, Type

import jsonschema
import jsonschema.exceptions
from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
from pydantic import Field, model_validator

from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel
from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel, KilnParentModel
from .json_schema import validate_schema

if TYPE_CHECKING:
Expand Down Expand Up @@ -191,7 +191,7 @@ class ExampleSource(str, Enum):
synthetic = "synthetic"


class Example(KilnParentedModel):
class Example(KilnParentedModel, KilnParentModel, parent_of={"outputs": ExampleOutput}):
"""
An example input to a specific Task.
"""
Expand All @@ -216,8 +216,9 @@ def relationship_name(cls):
def parent_type(cls):
return Task

# Needed for typechecking. TODO P2: fix this in KilnParentModel
def outputs(self) -> list[ExampleOutput]:
return ExampleOutput.all_children_of_parent_path(self.path)
return super().outputs() # type: ignore

def parent_task(self) -> Task | None:
if not isinstance(self.parent, Task):
Expand Down Expand Up @@ -267,7 +268,11 @@ class TaskDeterminism(str, Enum):
flexible = "flexible" # Flexible on semantic output. Eval should be custom based on parsing requirements.


class Task(KilnParentedModel):
class Task(
KilnParentedModel,
KilnParentModel,
parent_of={"requirements": TaskRequirement, "examples": Example},
):
name: str = NAME_FIELD
description: str = Field(default="")
priority: Priority = Field(default=Priority.p2)
Expand Down Expand Up @@ -295,16 +300,19 @@ def relationship_name(cls):
def parent_type(cls):
return Project

# Needed for typechecking. TODO P2: fix this in KilnParentModel
def requirements(self) -> list[TaskRequirement]:
return TaskRequirement.all_children_of_parent_path(self.path)
return super().requirements() # type: ignore

# Needed for typechecking. TODO P2: fix this in KilnParentModel
def examples(self) -> list[Example]:
return Example.all_children_of_parent_path(self.path)
return super().examples() # type: ignore


class Project(KilnBaseModel):
class Project(KilnParentModel, parent_of={"tasks": Task}):
name: str = NAME_FIELD
description: str = Field(default="")

# Needed for typechecking. TODO P2: fix this in KilnParentModel
def tasks(self) -> list[Task]:
return Task.all_children_of_parent_path(self.path)
return super().tasks() # type: ignore
83 changes: 82 additions & 1 deletion libs/core/kiln_ai/datamodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from builtins import classmethod
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Self, Type, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Optional,
Self,
Type,
TypeVar,
)

from kiln_ai.utils.config import Config
from pydantic import (
Expand Down Expand Up @@ -223,3 +232,75 @@ def all_children_of_parent_path(
children.append(child)

return children


# Parent create methods for all child relationships
# You must pass in parent_of in the subclass definition, defining the child relationships
class KilnParentModel(KilnBaseModel, metaclass=ABCMeta):
@classmethod
def _create_child_method(
cls, relationship_name: str, child_class: Type[KilnParentedModel]
):
def child_method(self) -> list[child_class]:
return child_class.all_children_of_parent_path(self.path)

child_method.__name__ = relationship_name
child_method.__annotations__ = {"return": List[child_class]}
setattr(cls, relationship_name, child_method)

@classmethod
def __init_subclass__(cls, parent_of: Dict[str, Type[KilnParentedModel]], **kwargs):
super().__init_subclass__(**kwargs)
cls._parent_of = parent_of
for relationship_name, child_class in parent_of.items():
cls._create_child_method(relationship_name, child_class)

@classmethod
def validate_and_save_with_subrelations(
cls, data: Dict[str, Any], path: Path | None = None
):
# Validate first, then save. Don't want error half way through, and partly persisted
# TODO P2: save to tmp dir, then move atomically. But need to merge directories so later.
cls._validate_nested(data, save=False, path=path)
instance = cls._validate_nested(data, save=True, path=path)
return instance

@classmethod
def _validate_nested(
cls,
data: Dict[str, Any],
save: bool = False,
parent: KilnBaseModel | None = None,
path: Path | None = None,
):
instance = cls.model_validate(data)
if path is not None:
instance.path = path
if parent is not None and isinstance(instance, KilnParentedModel):
instance.parent = parent
if save:
instance.save_to_file()
for key, value_list in data.items():
if key in cls._parent_of:
parent_type = cls._parent_of[key]
if not isinstance(value_list, list):
raise ValueError(
f"Expected a list for {key}, but got {type(value_list)}"
)
for value in value_list:
if issubclass(parent_type, KilnParentModel):
parent_type._validate_nested(
data=value, save=save, parent=instance
)
elif issubclass(parent_type, KilnBaseModel):
# Root node
subinstance = parent_type.model_validate(value)
subinstance.parent = instance
if save:
subinstance.save_to_file()
else:
raise ValueError(
f"Invalid type {parent_type}. Should be KilnBaseModel based."
)

return instance
3 changes: 1 addition & 2 deletions libs/core/kiln_ai/datamodel/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def test_load_tasks(test_project_file):
task3.save_to_file()

# Load tasks from the project
# tasks = project.tasks()
tasks = Task.all_children_of_parent_path(test_project_file)
tasks = project.tasks()

# Verify that all tasks are loaded correctly
assert len(tasks) == 3
Expand Down
151 changes: 151 additions & 0 deletions libs/core/kiln_ai/datamodel/test_nested_save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import pytest
from kiln_ai.datamodel.basemodel import KilnParentedModel, KilnParentModel
from pydantic import Field, ValidationError


class ModelC(KilnParentedModel):
code: str = Field(..., pattern=r"^[A-Z]{3}$")

@classmethod
def relationship_name(cls) -> str:
return "cs"

@classmethod
def parent_type(cls):
return ModelB


class ModelB(KilnParentedModel, KilnParentModel, parent_of={"cs": ModelC}):
value: int = Field(..., ge=0)

@classmethod
def relationship_name(cls) -> str:
return "bs"

@classmethod
def parent_type(cls):
return ModelA


# Define the hierarchy
class ModelA(KilnParentModel, parent_of={"bs": ModelB}):
name: str = Field(..., min_length=3)


def test_validation_error_in_c_level():
data = {
"name": "Root",
"bs": [
{
"value": 10,
"cs": [
{"code": "ABC"},
{"code": "DEF"},
{"code": "invalid"}, # This should cause a validation error
],
}
],
}

with pytest.raises(ValidationError) as exc_info:
ModelA.validate_and_save_with_subrelations(data)

assert "String should match pattern" in str(exc_info.value)


def test_persist_three_level_hierarchy(tmp_path):
# Set up temporary paths
root_path = tmp_path / "model_a.kiln"

data = {
"name": "Root",
"bs": [
{"value": 10, "cs": [{"code": "ABC"}, {"code": "DEF"}]},
{"value": 20, "cs": [{"code": "XYZ"}]},
],
}

instance = ModelA.validate_and_save_with_subrelations(data, path=root_path)

assert isinstance(instance, ModelA)
assert instance.name == "Root"
assert instance.path == root_path
assert len(instance.bs()) == 2

# Load the instance back from the file to double-check
instance = ModelA.load_from_file(root_path)

bs = instance.bs()
assert len(bs) == 2

# Check for the existence of both expected B models
b_values = [b.value for b in bs]
assert 10 in b_values
assert 20 in b_values

# Find the B models by their values
b10 = next(b for b in bs if b.value == 10)
b20 = next(b for b in bs if b.value == 20)

assert len(b10.cs()) == 2
assert len(b20.cs()) == 1

# Check C models for b10
c_codes_b10 = [c.code for c in b10.cs()]
assert "ABC" in c_codes_b10
assert "DEF" in c_codes_b10

# Check C model for b20
c_codes_b20 = [c.code for c in b20.cs()]
assert "XYZ" in c_codes_b20

# Check that all objects have their parent set correctly
assert all(b.parent == instance for b in bs)
assert all(c.parent.id == b10.id for c in b10.cs())
assert all(c.parent.id == b20.id for c in b20.cs())


def test_persist_model_a_without_children(tmp_path):
# Set up temporary path
root_path = tmp_path / "model_a_no_children.kiln"

data = {"name": "RootNoChildren"}

instance = ModelA.validate_and_save_with_subrelations(data, path=root_path)

assert isinstance(instance, ModelA)
assert instance.name == "RootNoChildren"
assert instance.path == root_path
assert len(instance.bs()) == 0

# Verify that the file was created
assert root_path.exists()

# Load the instance back from the file to double-check
loaded_instance = ModelA.load_from_file(root_path)
assert loaded_instance.name == "RootNoChildren"
assert len(loaded_instance.bs()) == 0


def test_validate_without_saving(tmp_path):
data = {
"name": "ValidateOnly",
"bs": [
{"value": 30, "cs": [{"code": "GHI"}, {"code": "JKL"}]},
{"value": 40, "cs": [{"code": "MNO"}]},
],
}

# Validate the data without saving
ModelA._validate_nested(data, save=False)

data = {
"name": "ValidateOnly",
"bs": [
{"value": 30, "cs": [{"code": "GHI"}, {"code": "JKL"}]},
{"value": 40, "cs": [{"code": 123}]},
],
}

with pytest.raises(ValidationError):
ModelA._validate_nested(data, save=False)

0 comments on commit 77cd08f

Please sign in to comment.