From ec627db4f01c73d56b599a50c3a2cce86bb227bc Mon Sep 17 00:00:00 2001 From: scosman Date: Mon, 30 Sep 2024 16:31:26 -0400 Subject: [PATCH] Add validation for source_properties --- .../kiln_ai/adapters/test_prompt_builders.py | 4 + libs/core/kiln_ai/datamodel/__init__.py | 28 ++++++ .../kiln_ai/datamodel/test_example_models.py | 98 ++++++++++++++++++- 3 files changed, 129 insertions(+), 1 deletion(-) diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py index 03946c0..a8c50fa 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_builders.py +++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py @@ -95,6 +95,7 @@ def test_multi_shot_prompt_builder(tmp_path): e1 = Example( input='{"subject": "Cows"}', source=ExampleSource.human, + source_properties={"creator": "john_doe"}, parent=task, ) e1.save_to_file() @@ -104,6 +105,7 @@ def test_multi_shot_prompt_builder(tmp_path): eo1 = ExampleOutput( output='{"joke": "Moo I am a cow joke."}', source=ExampleSource.human, + source_properties={"creator": "john_doe"}, parent=e1, ) eo1.save_to_file() @@ -116,12 +118,14 @@ def test_multi_shot_prompt_builder(tmp_path): e2 = Example( input='{"subject": "Dogs"}', source=ExampleSource.human, + source_properties={"creator": "john_doe"}, parent=task, ) e2.save_to_file() eo2 = ExampleOutput( output='{"joke": "This is a ruff joke."}', source=ExampleSource.human, + source_properties={"creator": "john_doe"}, parent=e2, rating=ReasonRating(rating=4, reason="Bark"), ) diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py index a2c34cc..6db1c87 100644 --- a/libs/core/kiln_ai/datamodel/__init__.py +++ b/libs/core/kiln_ai/datamodel/__init__.py @@ -153,6 +153,34 @@ def validate_requirement_rating_keys(self) -> Self: ) return self + @model_validator(mode="after") + def validate_source_properties(self) -> Self: + if self.source == ExampleOutputSource.synthetic: + required_keys = { + "adapter_name", 
+ "model_name", + "model_provider", + "prompt_builder_name", + } + elif self.source == ExampleOutputSource.human: + required_keys = {"creator"} + else: + raise ValueError(f"Invalid source type: {self.source}") + + missing_keys = [] + for key in required_keys: + if key not in self.source_properties: + missing_keys.append(key) + elif self.source_properties[key] == "": + raise ValueError( + f"example output source_properties[{key}] must not be empty string for {self.source} outputs" + ) + if len(missing_keys) > 0: + raise ValueError( + f"example output source_properties must include {missing_keys} for {self.source} outputs" + ) + return self + class ExampleSource(str, Enum): """ diff --git a/libs/core/kiln_ai/datamodel/test_example_models.py b/libs/core/kiln_ai/datamodel/test_example_models.py index 74a70eb..d79a3bf 100644 --- a/libs/core/kiln_ai/datamodel/test_example_models.py +++ b/libs/core/kiln_ai/datamodel/test_example_models.py @@ -174,6 +174,7 @@ def test_structured_output_workflow(tmp_path): output = ExampleOutput( output='{"name": "John Doe", "age": 30}', source=ExampleOutputSource.human, + source_properties={"creator": "john_doe"}, parent=example, ) output.save_to_file() @@ -242,6 +243,7 @@ def test_example_output_requirement_rating_keys(tmp_path): valid_output = ExampleOutput( output="Test output", source="human", + source_properties={"creator": "john_doe"}, parent=example, requirement_ratings={ req1.id: {"rating": 5, "reason": "Excellent"}, @@ -259,6 +261,7 @@ def test_example_output_requirement_rating_keys(tmp_path): output = ExampleOutput( output="Test output", source="human", + source_properties={"creator": "john_doe"}, parent=example, requirement_ratings={ "unknown_id": {"rating": 4, "reason": "Good"}, @@ -283,13 +286,19 @@ def test_example_output_schema_validation(tmp_path): ), ) task.save_to_file() - example = Example(input="Test input", source="human", parent=task) + example = Example( + input="Test input", + source="human", + parent=task, + 
source_properties={"creator": "john_doe"}, + ) example.save_to_file() # Create an example output with a valid schema valid_output = ExampleOutput( output='{"name": "John Doe", "age": 30}', source="human", + source_properties={"creator": "john_doe"}, parent=example, ) valid_output.save_to_file() @@ -304,6 +313,7 @@ def test_example_output_schema_validation(tmp_path): output = ExampleOutput( output='{"name": "John Doe", "age": "thirty"}', source="human", + source_properties={"creator": "john_doe"}, parent=example, ) output.save_to_file() @@ -347,3 +357,89 @@ def test_example_input_schema_validation(tmp_path): parent=task, ) example.save_to_file() + + +def test_valid_human_example_output(): + output = ExampleOutput( + output="Test output", + source=ExampleOutputSource.human, + source_properties={"creator": "John Doe"}, + ) + assert output.source == ExampleOutputSource.human + assert output.source_properties["creator"] == "John Doe" + + +def test_invalid_human_example_output_missing_creator(): + with pytest.raises( + ValidationError, + match="must include \['creator'\]", + ): + ExampleOutput( + output="Test output", source=ExampleOutputSource.human, source_properties={} + ) + + +def test_invalid_human_example_output_empty_creator(): + with pytest.raises(ValidationError, match="must not be empty string"): + ExampleOutput( + output="Test output", + source=ExampleOutputSource.human, + source_properties={"creator": ""}, + ) + + +def test_valid_synthetic_example_output(): + output = ExampleOutput( + output="Test output", + source=ExampleOutputSource.synthetic, + source_properties={ + "adapter_name": "TestAdapter", + "model_name": "GPT-4", + "model_provider": "OpenAI", + "prompt_builder_name": "TestPromptBuilder", + }, + ) + assert output.source == ExampleOutputSource.synthetic + assert output.source_properties["adapter_name"] == "TestAdapter" + assert output.source_properties["model_name"] == "GPT-4" + assert output.source_properties["model_provider"] == "OpenAI" + assert 
output.source_properties["prompt_builder_name"] == "TestPromptBuilder" + + +def test_invalid_synthetic_example_output_missing_keys(): + with pytest.raises( + ValidationError, match="example output source_properties must include" + ): + ExampleOutput( + output="Test output", + source=ExampleOutputSource.synthetic, + source_properties={"adapter_name": "TestAdapter", "model_name": "GPT-4"}, + ) + + +def test_invalid_synthetic_example_output_empty_values(): + with pytest.raises(ValidationError, match="must not be empty string"): + ExampleOutput( + output="Test output", + source=ExampleOutputSource.synthetic, + source_properties={ + "adapter_name": "TestAdapter", + "model_name": "", + "model_provider": "OpenAI", + "prompt_builder_name": "TestPromptBuilder", + }, + ) + + +def test_invalid_synthetic_example_output_non_string_values(): + with pytest.raises(ValidationError, match="Input should be a valid string"): + ExampleOutput( + output="Test output", + source=ExampleOutputSource.synthetic, + source_properties={ + "adapter_name": "TestAdapter", + "model_name": "GPT-4", + "model_provider": "OpenAI", + "prompt_builder_name": 123, + }, + )