Skip to content

Commit

Permalink
Add validation for source_properties
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Sep 30, 2024
1 parent 50b7a1a commit ec627db
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 1 deletion.
4 changes: 4 additions & 0 deletions libs/core/kiln_ai/adapters/test_prompt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_multi_shot_prompt_builder(tmp_path):
e1 = Example(
input='{"subject": "Cows"}',
source=ExampleSource.human,
source_properties={"creator": "john_doe"},
parent=task,
)
e1.save_to_file()
Expand All @@ -104,6 +105,7 @@ def test_multi_shot_prompt_builder(tmp_path):
eo1 = ExampleOutput(
output='{"joke": "Moo I am a cow joke."}',
source=ExampleSource.human,
source_properties={"creator": "john_doe"},
parent=e1,
)
eo1.save_to_file()
Expand All @@ -116,12 +118,14 @@ def test_multi_shot_prompt_builder(tmp_path):
e2 = Example(
input='{"subject": "Dogs"}',
source=ExampleSource.human,
source_properties={"creator": "john_doe"},
parent=task,
)
e2.save_to_file()
eo2 = ExampleOutput(
output='{"joke": "This is a ruff joke."}',
source=ExampleSource.human,
source_properties={"creator": "john_doe"},
parent=e2,
rating=ReasonRating(rating=4, reason="Bark"),
)
Expand Down
28 changes: 28 additions & 0 deletions libs/core/kiln_ai/datamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,34 @@ def validate_requirement_rating_keys(self) -> Self:
)
return self

@model_validator(mode="after")
def validate_source_properties(self) -> Self:
    """Check that ``source_properties`` has the keys required by ``source``.

    Synthetic outputs must carry ``adapter_name``, ``model_name``,
    ``model_provider`` and ``prompt_builder_name``; human outputs must
    carry ``creator``. Every required key that is present must not be an
    empty string.

    Returns:
        The validated instance, unchanged.

    Raises:
        ValueError: If ``source`` is neither synthetic nor human, if a
            required key is present but set to ``""``, or if one or more
            required keys are missing.
    """
    if self.source == ExampleOutputSource.synthetic:
        required_keys = {
            "adapter_name",
            "model_name",
            "model_provider",
            "prompt_builder_name",
        }
    elif self.source == ExampleOutputSource.human:
        required_keys = {"creator"}
    else:
        raise ValueError(f"Invalid source type: {self.source}")

    missing_keys = []
    # Walk the keys in sorted order: iteration order of a set of strings is
    # not stable between interpreter runs, which would make the reported
    # key (or the order of the missing-key list) vary from run to run.
    for key in sorted(required_keys):
        if key not in self.source_properties:
            missing_keys.append(key)
        elif self.source_properties[key] == "":
            # An empty value is reported immediately, ahead of any
            # missing-key report built up so far.
            raise ValueError(
                f"example output source_properties[{key}] must not be empty string for {self.source} outputs"
            )
    if missing_keys:
        raise ValueError(
            f"example output source_properties must include {missing_keys} for {self.source} outputs"
        )
    return self


class ExampleSource(str, Enum):
"""
Expand Down
98 changes: 97 additions & 1 deletion libs/core/kiln_ai/datamodel/test_example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_structured_output_workflow(tmp_path):
output = ExampleOutput(
output='{"name": "John Doe", "age": 30}',
source=ExampleOutputSource.human,
source_properties={"creator": "john_doe"},
parent=example,
)
output.save_to_file()
Expand Down Expand Up @@ -242,6 +243,7 @@ def test_example_output_requirement_rating_keys(tmp_path):
valid_output = ExampleOutput(
output="Test output",
source="human",
source_properties={"creator": "john_doe"},
parent=example,
requirement_ratings={
req1.id: {"rating": 5, "reason": "Excellent"},
Expand All @@ -259,6 +261,7 @@ def test_example_output_requirement_rating_keys(tmp_path):
output = ExampleOutput(
output="Test output",
source="human",
source_properties={"creator": "john_doe"},
parent=example,
requirement_ratings={
"unknown_id": {"rating": 4, "reason": "Good"},
Expand All @@ -283,13 +286,19 @@ def test_example_output_schema_validation(tmp_path):
),
)
task.save_to_file()
example = Example(input="Test input", source="human", parent=task)
example = Example(
input="Test input",
source="human",
parent=task,
source_properties={"creator": "john_doe"},
)
example.save_to_file()

# Create an example output with a valid schema
valid_output = ExampleOutput(
output='{"name": "John Doe", "age": 30}',
source="human",
source_properties={"creator": "john_doe"},
parent=example,
)
valid_output.save_to_file()
Expand All @@ -304,6 +313,7 @@ def test_example_output_schema_validation(tmp_path):
output = ExampleOutput(
output='{"name": "John Doe", "age": "thirty"}',
source="human",
source_properties={"creator": "john_doe"},
parent=example,
)
output.save_to_file()
Expand Down Expand Up @@ -347,3 +357,89 @@ def test_example_input_schema_validation(tmp_path):
parent=task,
)
example.save_to_file()


def test_valid_human_example_output():
    """A human-sourced output that names its creator is accepted."""
    creator_props = {"creator": "John Doe"}
    human_output = ExampleOutput(
        output="Test output",
        source=ExampleOutputSource.human,
        source_properties=creator_props,
    )

    # The constructed model exposes the source and creator it was given.
    assert human_output.source == ExampleOutputSource.human
    assert human_output.source_properties["creator"] == "John Doe"


def test_invalid_human_example_output_missing_creator():
    """A human-sourced output without a ``creator`` key is rejected."""
    # Raw string: the pattern escapes the literal brackets of the list repr,
    # and "\[" / "\]" are invalid escape sequences in a normal string literal.
    with pytest.raises(
        ValidationError,
        match=r"must include \['creator'\]",
    ):
        ExampleOutput(
            output="Test output", source=ExampleOutputSource.human, source_properties={}
        )


def test_invalid_human_example_output_empty_creator():
    """A human-sourced output whose ``creator`` is an empty string is rejected."""
    blank_creator = {"creator": ""}
    expected_message = "must not be empty string"

    with pytest.raises(ValidationError, match=expected_message):
        ExampleOutput(
            output="Test output",
            source=ExampleOutputSource.human,
            source_properties=blank_creator,
        )


def test_valid_synthetic_example_output():
    """A synthetic output carrying all four required properties is accepted."""
    synthetic_props = {
        "adapter_name": "TestAdapter",
        "model_name": "GPT-4",
        "model_provider": "OpenAI",
        "prompt_builder_name": "TestPromptBuilder",
    }
    synthetic_output = ExampleOutput(
        output="Test output",
        source=ExampleOutputSource.synthetic,
        source_properties=synthetic_props,
    )

    assert synthetic_output.source == ExampleOutputSource.synthetic
    # Each property is checked against its literal expected value.
    stored = synthetic_output.source_properties
    assert stored["adapter_name"] == "TestAdapter"
    assert stored["model_name"] == "GPT-4"
    assert stored["model_provider"] == "OpenAI"
    assert stored["prompt_builder_name"] == "TestPromptBuilder"


def test_invalid_synthetic_example_output_missing_keys():
    """A synthetic output lacking some required properties is rejected."""
    # Only two of the four required keys are supplied.
    partial_props = {"adapter_name": "TestAdapter", "model_name": "GPT-4"}
    expected_message = "example output source_properties must include"

    with pytest.raises(ValidationError, match=expected_message):
        ExampleOutput(
            output="Test output",
            source=ExampleOutputSource.synthetic,
            source_properties=partial_props,
        )


def test_invalid_synthetic_example_output_empty_values():
    """A synthetic output with an empty-string required property is rejected."""
    # All keys are present, but model_name is blank.
    props_with_blank = {
        "adapter_name": "TestAdapter",
        "model_name": "",
        "model_provider": "OpenAI",
        "prompt_builder_name": "TestPromptBuilder",
    }
    expected_message = "must not be empty string"

    with pytest.raises(ValidationError, match=expected_message):
        ExampleOutput(
            output="Test output",
            source=ExampleOutputSource.synthetic,
            source_properties=props_with_blank,
        )


def test_invalid_synthetic_example_output_non_string_values():
    """A synthetic output with a non-string property value is rejected."""
    # prompt_builder_name is an int rather than a string.
    props_with_int = {
        "adapter_name": "TestAdapter",
        "model_name": "GPT-4",
        "model_provider": "OpenAI",
        "prompt_builder_name": 123,
    }
    expected_message = "Input should be a valid string"

    with pytest.raises(ValidationError, match=expected_message):
        ExampleOutput(
            output="Test output",
            source=ExampleOutputSource.synthetic,
            source_properties=props_with_int,
        )

0 comments on commit ec627db

Please sign in to comment.