Skip to content

Commit

Permalink
Add output format checking
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Sep 11, 2024
1 parent 264d838 commit 9a7a02e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 29 deletions.
11 changes: 10 additions & 1 deletion libs/core/kiln_ai/datamodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from pathlib import Path
from typing import Optional, Self, Type, TypeVar

from pydantic import BaseModel, Field, computed_field, model_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
ValidationError,
computed_field,
model_validator,
)

# ID is a 10 digit hex string
ID_FIELD = Field(default_factory=lambda: uuid.uuid4().hex[:10].upper())
Expand All @@ -20,6 +27,8 @@ def snake_case(s: str) -> str:


class KilnBaseModel(BaseModel):
model_config = ConfigDict(validate_assignment=True)

v: int = 1 # schema_version
path: Optional[Path] = Field(default=None, exclude=True)

Expand Down
52 changes: 38 additions & 14 deletions libs/core/kiln_ai/datamodel/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from __future__ import annotations

import json
from enum import Enum, IntEnum
from typing import Dict, Self
from typing import TYPE_CHECKING, Dict, Self

import jsonschema
import jsonschema.exceptions
from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
from pydantic import Field, model_validator

from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel
from .json_schema import validate_schema

if TYPE_CHECKING:
from .models import Task

# Conventions:
# 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation.
Expand Down Expand Up @@ -89,32 +98,47 @@ def parent_type(cls):

# TODO validators for output and fixed_output: validate they follow the tas

def __init__(self, **data):
super().__init__(**data)
self.validate_requirement_rating_keys_manual()

@model_validator(mode="after")
def validate_requirement_rating_keys(self) -> Self:
return self.validate_requirement_rating_keys_manual()

def validate_requirement_rating_keys_manual(self) -> Self:
if len(self.requirement_ratings) == 0:
def validate_output_format(self) -> Self:
task = self.task_for_validation()
if task is None:
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
return self

# validate output
if task.output_json_schema is not None:
try:
validate_schema(json.loads(self.output), task.output_json_schema)
except json.JSONDecodeError:
raise ValueError("Output is not a valid JSON object")
except jsonschema.exceptions.ValidationError as e:
raise ValueError(f"Output does not match task output schema: {e}")
return self

def task_for_validation(self) -> Task | None:
example = self.parent
if example is None:
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
return self
return None
if not isinstance(example, Example):
raise ValueError("ExampleOutput must have a valid parent Example")

task = example.parent
if task is None:
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
return self
return None
if not isinstance(task, Task):
raise ValueError(
"ExampleOutput's parent Example must have a valid parent Task"
)
return task

@model_validator(mode="after")
def validate_requirement_rating_keys(self) -> Self:
if len(self.requirement_ratings) == 0:
return self
task = self.task_for_validation()
if task is None:
# don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving)
return self

valid_requirement_ids = {req.id for req in task.requirements()}
for key in self.requirement_ratings.keys():
Expand Down
62 changes: 48 additions & 14 deletions libs/core/kiln_ai/datamodel/test_example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,14 @@ def test_example_relationship():
assert example.parent_type().__name__ == "Task"


def test_example_output_model_validation():
def test_example_output_model_validation(tmp_path):
# Valid example output
task = Task(name="Test Task", path=tmp_path / Task.base_filename())
task.save_to_file()
example = Example(input="Test input", source=ExampleSource.human, parent=task)
example.save_to_file()
valid_output = ExampleOutput(
path="/test/path",
parent=example,
output="Test output",
source=ExampleOutputSource.human,
source_properties={"creator": "Jane Doe"},
Expand Down Expand Up @@ -117,18 +121,6 @@ def test_example_output_model_validation():
)


def test_example_output_relationship():
example_output = ExampleOutput(
path="/test/path",
output="Test output",
source=ExampleOutputSource.human,
source_properties={},
requirement_ratings={},
)
assert example_output.relationship_name() == "outputs"
assert example_output.parent_type().__name__ == "Example"


def test_structured_output_workflow(tmp_path):
tmp_project_file = (
tmp_path / "test_structured_output_examples" / Project.base_filename()
Expand Down Expand Up @@ -269,3 +261,45 @@ def test_example_output_requirement_rating_keys(tmp_path):
},
)
output.save_to_file()


def test_example_output_schema_validation(tmp_path):
# Create a project, task, and example hierarchy
project = Project(name="Test Project", path=(tmp_path / "test_project"))
project.save_to_file()
task = Task(
name="Test Task",
parent=project,
output_json_schema=json.dumps(
{
"type": "object",
"properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
"required": ["name", "age"],
}
),
)
task.save_to_file()
example = Example(input="Test input", source="human", parent=task)
example.save_to_file()

# Create an example output with a valid schema
valid_output = ExampleOutput(
output='{"name": "John Doe", "age": 30}',
source="human",
parent=example,
)
valid_output.save_to_file()

# changing to invalid output
with pytest.raises(ValueError):
valid_output.output = '{"name": "John Doe", "age": "thirty"}'
valid_output.save_to_file()

# Invalid case: output does not match task output schema
with pytest.raises(ValueError):
output = ExampleOutput(
output='{"name": "John Doe", "age": "thirty"}',
source="human",
parent=example,
)
output.save_to_file()

0 comments on commit 9a7a02e

Please sign in to comment.