Add structured input and tests.
Using JSON-formatted input for the model for now, but fully validating it against the schema before calling the model.
scosman committed Sep 10, 2024
1 parent 65ca1ac commit 6344760
Showing 5 changed files with 137 additions and 20 deletions.
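
At a glance: a Task can now carry an input_json_schema, and BaseAdapter.invoke validates dictionary input against it before the model is called (structured output is still validated afterwards). A minimal caller-side sketch pieced together from the tests below — the model/provider identifiers and datamodel calls are taken from this diff, while the import paths are inferred from the file layout and the shortened instruction text is not the original:

from pathlib import Path

import kiln_ai.datamodel.models as models
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.datamodel.test_json_schema import json_triangle_schema


async def classify_triangle(tmp_path: Path):
    project = models.Project(name="demo", path=tmp_path / "demo.kiln")
    project.save_to_file()

    task = models.Task(
        parent=project,
        name="triangle classifier",
        instruction="Classify a triangle as equilateral, isosceles or scalene "
        "from the integer side lengths a, b and c.",
    )
    task.input_json_schema = json_triangle_schema  # input is now structured
    task.save_to_file()

    adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="groq")

    # A plain string raises ValueError, and a dict that fails the schema raises
    # jsonschema.exceptions.ValidationError -- both before any model call.
    return await adapter.invoke({"a": 2, "b": 2, "c": 2})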
39 changes: 27 additions & 12 deletions libs/core/kiln_ai/adapters/base_adapter.py
@@ -1,3 +1,4 @@
import json
from abc import ABCMeta, abstractmethod
from typing import Dict

@@ -8,27 +9,38 @@
class BaseAdapter(metaclass=ABCMeta):
def __init__(self, kiln_task: Task):
self.kiln_task = kiln_task
self._is_structured_output = self.kiln_task.output_json_schema is not None

async def invoke(self, input: str) -> Dict | str:
self.output_schema = self.kiln_task.output_json_schema

self.input_schema = self.kiln_task.input_json_schema

async def invoke(self, input: Dict | str) -> Dict | str:
# validate input
if self.input_schema is not None:
if not isinstance(input, dict):
raise ValueError(f"structured input is not a dict: {input}")
validate_schema(input, self.input_schema)

# Run
result = await self._run(input)
if self._is_structured_output:

# validate output
if self.output_schema is not None:
if not isinstance(result, dict):
raise RuntimeError(f"structured response is not a dict: {result}")
if self.kiln_task.output_json_schema is None:
raise ValueError(
f"output_json_schema is not set for task {self.kiln_task.name}"
)
validate_schema(result, self.kiln_task.output_json_schema)
validate_schema(result, self.output_schema)
else:
if not isinstance(result, str):
raise RuntimeError(
f"response is not a string for non-structured task: {result}"
)
return result

def has_structured_output(self) -> bool:
return self.output_schema is not None

@abstractmethod
async def _run(self, input: str) -> Dict | str:
async def _run(self, input: Dict | str) -> Dict | str:
pass

# override for adapter specific instructions (e.g. tool calling, json format, etc)
@@ -45,6 +57,9 @@ def __init__(self, task: Task, adapter: BaseAdapter | None = None):
def build_prompt(self) -> str:
pass

@abstractmethod
def build_user_message(self, input: str) -> str:
pass
# Can be overridden to add more information to the user message
def build_user_message(self, input: Dict | str) -> str:
if isinstance(input, Dict):
return f"The input is:\n{json.dumps(input, indent=2)}"

return f"The input is:\n{input}"
8 changes: 4 additions & 4 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -28,7 +28,7 @@ def __init__(
raise ValueError(
"model_name and provider must be provided if custom_model is not provided"
)
if self._is_structured_output:
if self.has_structured_output():
if not hasattr(self.model, "with_structured_output") or not callable(
getattr(self.model, "with_structured_output")
):
@@ -54,19 +54,19 @@ def __init__(

def adapter_specific_instructions(self) -> str | None:
# TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
if self._is_structured_output:
if self.has_structured_output():
return "Always respond with a tool call. Never respond with a human readable message."
return None

async def _run(self, input: str) -> Dict | str:
async def _run(self, input: Dict | str) -> Dict | str:
prompt = self.prompt_builder.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
SystemMessage(content=prompt),
HumanMessage(content=user_msg),
]
response = self.model.invoke(messages)
if self._is_structured_output:
if self.has_structured_output():
if (
not isinstance(response, dict)
or "parsed" not in response
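
For context on the "parsed" check above: when a LangChain chat model is wrapped with with_structured_output(..., include_raw=True), invoking it returns a dict with raw, parsed, and parsing_error keys, and the adapter reads the structured result out of parsed. A standalone illustration of that pattern — not this repo's code; the ChatOpenAI/gpt-4o-mini choice is only an example and needs provider credentials:

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

joke_schema = {
    "title": "task_response",
    "type": "object",
    "properties": {
        "setup": {"type": "string"},
        "punchline": {"type": "string"},
    },
    "required": ["setup", "punchline"],
}

model = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    joke_schema, include_raw=True
)
response = model.invoke(
    [
        SystemMessage(content="Always respond with a tool call."),
        HumanMessage(content="Tell a joke about triangles."),
    ]
)
parsed = response["parsed"]  # dict matching the schema, or None on a parse failure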
3 changes: 0 additions & 3 deletions libs/core/kiln_ai/adapters/prompt_builders.py
@@ -20,6 +20,3 @@ def build_prompt(self) -> str:
base_prompt += f"\n\n{adapter_instructions}"

return base_prompt

def build_user_message(self, input: str) -> str:
return f"The input is:\n{input}"
58 changes: 57 additions & 1 deletion libs/core/kiln_ai/adapters/test_structured_output.py
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import Dict

import jsonschema
import jsonschema.exceptions
import kiln_ai.datamodel.models as models
import pytest
from kiln_ai.adapters.base_adapter import BaseAdapter
@@ -9,7 +11,7 @@
built_in_models,
ollama_online,
)
from kiln_ai.datamodel.test_models import json_joke_schema
from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema


@pytest.mark.paid
@@ -143,3 +145,57 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
rating = int(rating)
assert rating >= 0
assert rating <= 10


def build_structured_input_test_task(tmp_path: Path):
project = models.Project(name="test", path=tmp_path / "test.kiln")
project.save_to_file()
task = models.Task(
parent=project,
name="test task",
instruction="You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\nAt the end of your response return the result in double square brackets. It should be plain text. It should be exactly one of the three following strings: '[[equilateral]]', or '[[isosceles]]', or '[[scalene]]'.",
)
task.input_json_schema = json_triangle_schema
schema = task.input_schema()
assert schema is not None
assert schema["properties"]["a"]["type"] == "integer"
assert schema["properties"]["b"]["type"] == "integer"
assert schema["properties"]["c"]["type"] == "integer"
assert schema["required"] == ["a", "b", "c"]
task.save_to_file()
assert task.name == "test task"
assert len(task.requirements()) == 0
return task


async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
task = build_structured_input_test_task(tmp_path)
a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
with pytest.raises(ValueError):
# structured input not provided as a dict
await a.invoke("a=1, b=2, c=3")
with pytest.raises(jsonschema.exceptions.ValidationError):
# invalid structured input
await a.invoke({"a": 1, "b": 2, "d": 3})

response = await a.invoke({"a": 2, "b": 2, "c": 2})
assert response is not None
assert isinstance(response, str)
assert "[[equilateral]]" in response


@pytest.mark.paid
async def test_structured_input_groq_llama_3_1_8b(tmp_path):
await run_structured_input_test(tmp_path, "llama_3_1_8b", "groq")


@pytest.mark.paid
@pytest.mark.ollama
async def test_all_built_in_models_structured_input(tmp_path):
for model in built_in_models:
for provider in model.providers:
try:
print(f"Running {model.name} {provider.name}")
await run_structured_input_test(tmp_path, model.name, provider.name)
except Exception as e:
raise RuntimeError(f"Error running {model.name} {provider}") from e
49 changes: 49 additions & 0 deletions libs/core/kiln_ai/datamodel/test_json_schema.py
@@ -73,3 +73,52 @@ def test_validate_schema_content():
validate_schema(0, json_joke_schema)
o = {"setup": "asdf", "punchline": "asdf"}
validate_schema(o, json_joke_schema)
o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
with pytest.raises(Exception):
validate_schema(o, json_joke_schema)


json_triangle_schema = """{
"type": "object",
"properties": {
"a": {
"description": "length of side a",
"title": "A",
"type": "integer"
},
"b": {
"description": "length of side b",
"title": "B",
"type": "integer"
},
"c": {
"description": "length of side c",
"title": "C",
"type": "integer"
}
},
"required": [
"a",
"b",
"c"
]
}
"""


def test_triangle_schema():
o = ExampleModel(x_schema=json_joke_schema)
parsed_schema = schema_from_json_str(o.x_schema)
assert parsed_schema is not None

o = ExampleModel(x_schema=json_triangle_schema)
schema = schema_from_json_str(o.x_schema)

assert schema is not None
assert schema["properties"]["a"]["type"] == "integer"
assert schema["properties"]["b"]["type"] == "integer"
assert schema["properties"]["c"]["type"] == "integer"
assert schema["required"] == ["a", "b", "c"]
validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
with pytest.raises(Exception):
validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)
