From 9a7a02e6bdebc5a80fa89d2d678b47533bd53708 Mon Sep 17 00:00:00 2001 From: scosman Date: Wed, 11 Sep 2024 16:19:59 -0400 Subject: [PATCH] Add output format checking --- libs/core/kiln_ai/datamodel/basemodel.py | 11 +++- libs/core/kiln_ai/datamodel/models.py | 52 +++++++++++----- .../kiln_ai/datamodel/test_example_models.py | 62 ++++++++++++++----- 3 files changed, 96 insertions(+), 29 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index e9c5ee8..e900d36 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -6,7 +6,14 @@ from pathlib import Path from typing import Optional, Self, Type, TypeVar -from pydantic import BaseModel, Field, computed_field, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + computed_field, + model_validator, +) # ID is a 10 digit hex string ID_FIELD = Field(default_factory=lambda: uuid.uuid4().hex[:10].upper()) @@ -20,6 +27,8 @@ def snake_case(s: str) -> str: class KilnBaseModel(BaseModel): + model_config = ConfigDict(validate_assignment=True) + v: int = 1 # schema_version path: Optional[Path] = Field(default=None, exclude=True) diff --git a/libs/core/kiln_ai/datamodel/models.py b/libs/core/kiln_ai/datamodel/models.py index ca3d1de..e01faf7 100644 --- a/libs/core/kiln_ai/datamodel/models.py +++ b/libs/core/kiln_ai/datamodel/models.py @@ -1,10 +1,19 @@ +from __future__ import annotations + +import json from enum import Enum, IntEnum -from typing import Dict, Self +from typing import TYPE_CHECKING, Dict, Self +import jsonschema +import jsonschema.exceptions from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str from pydantic import Field, model_validator from .basemodel import ID_TYPE, KilnBaseModel, KilnParentedModel +from .json_schema import validate_schema + +if TYPE_CHECKING: + from .models import Task # Conventions: # 1) Names are filename safe as they may be used as file names. They are informational and not to be used in prompts/training/validation. @@ -89,32 +98,47 @@ def parent_type(cls): # TODO validators for output and fixed_output: validate they follow the tas - def __init__(self, **data): - super().__init__(**data) - self.validate_requirement_rating_keys_manual() - @model_validator(mode="after") - def validate_requirement_rating_keys(self) -> Self: - return self.validate_requirement_rating_keys_manual() - - def validate_requirement_rating_keys_manual(self) -> Self: - if len(self.requirement_ratings) == 0: + def validate_output_format(self) -> Self: + task = self.task_for_validation() + if task is None: + # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) return self + + # validate output + if task.output_json_schema is not None: + try: + validate_schema(json.loads(self.output), task.output_json_schema) + except json.JSONDecodeError: + raise ValueError("Output is not a valid JSON object") + except jsonschema.exceptions.ValidationError as e: + raise ValueError(f"Output does not match task output schema: {e}") + return self + + def task_for_validation(self) -> Task | None: example = self.parent if example is None: - # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) - return self + return None if not isinstance(example, Example): raise ValueError("ExampleOutput must have a valid parent Example") task = example.parent if task is None: - # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) - return self + return None if not isinstance(task, Task): raise ValueError( "ExampleOutput's parent Example must have a valid parent Task" ) + return task + + @model_validator(mode="after") + def validate_requirement_rating_keys(self) -> Self: + if len(self.requirement_ratings) == 0: + return self + task = self.task_for_validation() + if task is None: + # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) + return self valid_requirement_ids = {req.id for req in task.requirements()} for key in self.requirement_ratings.keys(): diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py index e9a75ec..ba033f4 100644 --- a/libs/core/kiln_ai/datamodel/test_example_models.py +++ b/libs/core/kiln_ai/datamodel/test_example_models.py @@ -61,10 +61,14 @@ def test_example_relationship(): assert example.parent_type().__name__ == "Task" -def test_example_output_model_validation(): +def test_example_output_model_validation(tmp_path): # Valid example output + task = Task(name="Test Task", path=tmp_path / Task.base_filename()) + task.save_to_file() + example = Example(input="Test input", source=ExampleSource.human, parent=task) + example.save_to_file() valid_output = ExampleOutput( - path="/test/path", + parent=example, output="Test output", source=ExampleOutputSource.human, source_properties={"creator": "Jane Doe"}, @@ -117,18 +121,6 @@ def test_example_output_model_validation(): ) -def test_example_output_relationship(): - example_output = ExampleOutput( - path="/test/path", - output="Test output", - source=ExampleOutputSource.human, - source_properties={}, - requirement_ratings={}, - ) - assert example_output.relationship_name() == "outputs" - assert example_output.parent_type().__name__ == "Example" - - def test_structured_output_workflow(tmp_path): tmp_project_file = ( tmp_path / "test_structured_output_examples" / Project.base_filename() @@ -269,3 +261,45 @@ def test_example_output_requirement_rating_keys(tmp_path): }, ) output.save_to_file() + + +def test_example_output_schema_validation(tmp_path): + # Create a project, task, and example hierarchy + project = Project(name="Test Project", path=(tmp_path / "test_project")) + project.save_to_file() + task = Task( + name="Test Task", + parent=project, + output_json_schema=json.dumps( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + ), + ) + task.save_to_file() + example = Example(input="Test input", source="human", parent=task) + example.save_to_file() + + # Create an example output with a valid schema + valid_output = ExampleOutput( + output='{"name": "John Doe", "age": 30}', + source="human", + parent=example, + ) + valid_output.save_to_file() + + # changing to invalid output + with pytest.raises(ValueError): + valid_output.output = '{"name": "John Doe", "age": "thirty"}' + valid_output.save_to_file() + + # Invalid case: output does not match task output schema + with pytest.raises(ValueError): + output = ExampleOutput( + output='{"name": "John Doe", "age": "thirty"}', + source="human", + parent=example, + ) + output.save_to_file()