From 89bf6b966d44ffe46c3aa54fc5210e6be07bd41c Mon Sep 17 00:00:00 2001 From: scosman Date: Wed, 11 Sep 2024 16:30:08 -0400 Subject: [PATCH] Add input validation --- libs/core/kiln_ai/datamodel/basemodel.py | 1 - libs/core/kiln_ai/datamodel/models.py | 21 +++++++ .../kiln_ai/datamodel/test_example_models.py | 58 ++++++++++++++++--- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/libs/core/kiln_ai/datamodel/basemodel.py b/libs/core/kiln_ai/datamodel/basemodel.py index e900d36..5827800 100644 --- a/libs/core/kiln_ai/datamodel/basemodel.py +++ b/libs/core/kiln_ai/datamodel/basemodel.py @@ -10,7 +10,6 @@ BaseModel, ConfigDict, Field, - ValidationError, computed_field, model_validator, ) diff --git a/libs/core/kiln_ai/datamodel/models.py b/libs/core/kiln_ai/datamodel/models.py index e01faf7..c8279a6 100644 --- a/libs/core/kiln_ai/datamodel/models.py +++ b/libs/core/kiln_ai/datamodel/models.py @@ -186,6 +186,27 @@ def parent_type(cls): def outputs(self) -> list[ExampleOutput]: return ExampleOutput.all_children_of_parent_path(self.path) + @model_validator(mode="after") + def validate_input_format(self) -> Self: + task = self.parent + if task is None: + # don't validate this relationship until we have a path or parent. Give them time to build it (but will catch it before saving) + return self + if not isinstance(task, Task): + raise ValueError( + "ExampleOutput's parent Example must have a valid parent Task" + ) + + # validate output + if task.input_json_schema is not None: + try: + validate_schema(json.loads(self.input), task.input_json_schema) + except json.JSONDecodeError: + raise ValueError("Input is not a valid JSON object") + except jsonschema.exceptions.ValidationError as e: + raise ValueError(f"Input does not match task input schema: {e}") + return self + class TaskRequirement(KilnParentedModel): name: str = NAME_FIELD diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py index ba033f4..16d5ddd 100644 --- a/libs/core/kiln_ai/datamodel/test_example_models.py +++ b/libs/core/kiln_ai/datamodel/test_example_models.py @@ -15,10 +15,12 @@ from pydantic import ValidationError -def test_example_model_validation(): +def test_example_model_validation(tmp_path): # Valid example + task = Task(name="Test Task", path=tmp_path / Task.base_filename()) + task.save_to_file() valid_example = Example( - path="/test/path", + parent=task, input="Test input", source=ExampleSource.human, source_properties={"creator": "John Doe"}, @@ -30,7 +32,7 @@ def test_example_model_validation(): # Invalid source with pytest.raises(ValidationError): Example( - path="/test/path", + parent=task, input="Test input", source="invalid_source", source_properties={}, @@ -38,21 +40,23 @@ def test_example_model_validation(): # Missing required field with pytest.raises(ValidationError): - Example(path="/test/path", source=ExampleSource.human, source_properties={}) + Example(parent=task, source=ExampleSource.human, source_properties={}) # Invalid source_properties type with pytest.raises(ValidationError): Example( - path="/test/path", + parent=task, input="Test input", source=ExampleSource.human, source_properties="invalid", ) -def test_example_relationship(): +def test_example_relationship(tmp_path): + task = Task(name="Test Task", path=tmp_path / Task.base_filename()) + task.save_to_file() example = Example( - path="/test/path", + parent=task, input="Test input", source=ExampleSource.human, source_properties={}, @@ -303,3 +307,43 @@ def test_example_output_schema_validation(tmp_path): parent=example, ) output.save_to_file() + + +def test_example_input_schema_validation(tmp_path): + # Create a project and task hierarchy + project = Project(name="Test Project", path=(tmp_path / "test_project")) + project.save_to_file() + task = Task( + name="Test Task", + parent=project, + input_json_schema=json.dumps( + { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "required": ["name", "age"], + } + ), + ) + task.save_to_file() + + # Create an example with a valid input schema + valid_example = Example( + input='{"name": "John Doe", "age": 30}', + source=ExampleSource.human, + parent=task, + ) + valid_example.save_to_file() + + # Changing to invalid input + with pytest.raises(ValueError): + valid_example.input = '{"name": "John Doe", "age": "thirty"}' + valid_example.save_to_file() + + # Invalid case: input does not match task input schema + with pytest.raises(ValueError): + example = Example( + input='{"name": "John Doe", "age": "thirty"}', + source=ExampleSource.human, + parent=task, + ) + example.save_to_file()