From 6344760f6e98c0972919e972a3ce474381137db2 Mon Sep 17 00:00:00 2001
From: scosman
Date: Mon, 9 Sep 2024 23:42:50 -0400
Subject: [PATCH] Add structured input and tests.

Structured input is passed to the model as JSON for now, but it is fully
validated against the task's input schema before the model is called.

---
 libs/core/kiln_ai/adapters/base_adapter.py    | 39 +++++++++----
 .../kiln_ai/adapters/langchain_adapters.py    |  8 +--
 libs/core/kiln_ai/adapters/prompt_builders.py |  3 -
 .../adapters/test_structured_output.py        | 58 ++++++++++++++++++-
 .../kiln_ai/datamodel/test_json_schema.py     | 49 ++++++++++++++++
 5 files changed, 137 insertions(+), 20 deletions(-)

diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py
index 58468d2..22a8d2a 100644
--- a/libs/core/kiln_ai/adapters/base_adapter.py
+++ b/libs/core/kiln_ai/adapters/base_adapter.py
@@ -1,3 +1,4 @@
+import json
 from abc import ABCMeta, abstractmethod
 from typing import Dict
 
@@ -8,18 +9,26 @@ class BaseAdapter(metaclass=ABCMeta):
     def __init__(self, kiln_task: Task):
         self.kiln_task = kiln_task
-        self._is_structured_output = self.kiln_task.output_json_schema is not None
 
-    async def invoke(self, input: str) -> Dict | str:
+        self.output_schema = self.kiln_task.output_json_schema
+
+        self.input_schema = self.kiln_task.input_json_schema
+
+    async def invoke(self, input: Dict | str) -> Dict | str:
+        # validate input
+        if self.input_schema is not None:
+            if not isinstance(input, dict):
+                raise ValueError(f"structured input is not a dict: {input}")
+            validate_schema(input, self.input_schema)
+
+        # Run
         result = await self._run(input)
-        if self._is_structured_output:
+
+        # validate output
+        if self.output_schema is not None:
             if not isinstance(result, dict):
                 raise RuntimeError(f"structured response is not a dict: {result}")
-            if self.kiln_task.output_json_schema is None:
-                raise ValueError(
-                    f"output_json_schema is not set for task {self.kiln_task.name}"
-                )
-            validate_schema(result, self.kiln_task.output_json_schema)
+            validate_schema(result, self.output_schema)
         else:
             if not isinstance(result, str):
                 raise RuntimeError(
@@ -27,8 +36,11 @@ async def invoke(self, input: str) -> Dict | str:
                 )
         return result
 
+    def has_structured_output(self) -> bool:
+        return self.output_schema is not None
+
     @abstractmethod
-    async def _run(self, input: str) -> Dict | str:
+    async def _run(self, input: Dict | str) -> Dict | str:
         pass
 
     # override for adapter specific instructions (e.g. tool calling, json format, etc)
@@ -45,6 +57,9 @@ def __init__(self, task: Task, adapter: BaseAdapter | None = None):
     def build_prompt(self) -> str:
         pass
 
-    @abstractmethod
-    def build_user_message(self, input: str) -> str:
-        pass
+    # Can be overridden to add more information to the user message
+    def build_user_message(self, input: Dict | str) -> str:
+        if isinstance(input, dict):
+            return f"The input is:\n{json.dumps(input, indent=2)}"
+
+        return f"The input is:\n{input}"
diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py
index cb70988..34e636c 100644
--- a/libs/core/kiln_ai/adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -28,7 +28,7 @@ def __init__(
             raise ValueError(
                 "model_name and provider must be provided if custom_model is not provided"
             )
-        if self._is_structured_output:
+        if self.has_structured_output():
             if not hasattr(self.model, "with_structured_output") or not callable(
                 getattr(self.model, "with_structured_output")
             ):
@@ -54,11 +54,11 @@ def __init__(
 
     def adapter_specific_instructions(self) -> str | None:
         # TODO: would be better to explicitly use bind_tools:tool_choice="task_response" here
-        if self._is_structured_output:
+        if self.has_structured_output():
             return "Always respond with a tool call. Never respond with a human readable message."
         return None
 
-    async def _run(self, input: str) -> Dict | str:
+    async def _run(self, input: Dict | str) -> Dict | str:
         prompt = self.prompt_builder.build_prompt()
         user_msg = self.prompt_builder.build_user_message(input)
         messages = [
@@ -66,7 +66,7 @@ async def _run(self, input: str) -> Dict | str:
             HumanMessage(content=user_msg),
         ]
         response = self.model.invoke(messages)
-        if self._is_structured_output:
+        if self.has_structured_output():
             if (
                 not isinstance(response, dict)
                 or "parsed" not in response
diff --git a/libs/core/kiln_ai/adapters/prompt_builders.py b/libs/core/kiln_ai/adapters/prompt_builders.py
index c4ddbbc..175ee0e 100644
--- a/libs/core/kiln_ai/adapters/prompt_builders.py
+++ b/libs/core/kiln_ai/adapters/prompt_builders.py
@@ -20,6 +20,3 @@ def build_prompt(self) -> str:
             base_prompt += f"\n\n{adapter_instructions}"
 
         return base_prompt
-
-    def build_user_message(self, input: str) -> str:
-        return f"The input is:\n{input}"
diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py
index 87d012d..b66f83e 100644
--- a/libs/core/kiln_ai/adapters/test_structured_output.py
+++ b/libs/core/kiln_ai/adapters/test_structured_output.py
@@ -1,6 +1,8 @@
 from pathlib import Path
 from typing import Dict
 
+import jsonschema
+import jsonschema.exceptions
 import kiln_ai.datamodel.models as models
 import pytest
 from kiln_ai.adapters.base_adapter import BaseAdapter
@@ -9,7 +11,7 @@
     built_in_models,
     ollama_online,
 )
-from kiln_ai.datamodel.test_models import json_joke_schema
+from kiln_ai.datamodel.test_json_schema import json_joke_schema, json_triangle_schema
 
 
 @pytest.mark.paid
@@ -143,3 +145,57 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider:
     rating = int(rating)
     assert rating >= 0
     assert rating <= 10
+
+
+def build_structured_input_test_task(tmp_path: Path):
+    project = models.Project(name="test", path=tmp_path / "test.kiln")
+    project.save_to_file()
+    task = models.Task(
+        parent=project,
+        name="test task",
+        instruction="You are an assistant which classifies a triangle given the lengths of its sides. If all sides are of equal length, the triangle is equilateral. If two sides are equal, the triangle is isosceles. Otherwise, it is scalene.\n\nAt the end of your response return the result in double square brackets. It should be plain text. It should be exactly one of the three following strings: '[[equilateral]]', or '[[isosceles]]', or '[[scalene]]'.",
+    )
+    task.input_json_schema = json_triangle_schema
+    schema = task.input_schema()
+    assert schema is not None
+    assert schema["properties"]["a"]["type"] == "integer"
+    assert schema["properties"]["b"]["type"] == "integer"
+    assert schema["properties"]["c"]["type"] == "integer"
+    assert schema["required"] == ["a", "b", "c"]
+    task.save_to_file()
+    assert task.name == "test task"
+    assert len(task.requirements()) == 0
+    return task
+
+
+async def run_structured_input_test(tmp_path: Path, model_name: str, provider: str):
+    task = build_structured_input_test_task(tmp_path)
+    a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
+    with pytest.raises(ValueError):
+        # structured input was not provided as a dict
+        await a.invoke("a=1, b=2, c=3")
+    with pytest.raises(jsonschema.exceptions.ValidationError):
+        # invalid structured input
+        await a.invoke({"a": 1, "b": 2, "d": 3})
+
+    response = await a.invoke({"a": 2, "b": 2, "c": 2})
+    assert response is not None
+    assert isinstance(response, str)
+    assert "[[equilateral]]" in response
+
+
+@pytest.mark.paid
+async def test_structured_input_groq(tmp_path):
+    await run_structured_input_test(tmp_path, "llama_3_1_8b", "groq")
+
+
+@pytest.mark.paid
+@pytest.mark.ollama
+async def test_all_built_in_models_structured_input(tmp_path):
+    for model in built_in_models:
+        for provider in model.providers:
+            try:
+                print(f"Running {model.name} {provider.name}")
+                await run_structured_input_test(tmp_path, model.name, provider.name)
+            except Exception as e:
+                raise RuntimeError(f"Error running {model.name} {provider.name}") from e
diff --git a/libs/core/kiln_ai/datamodel/test_json_schema.py b/libs/core/kiln_ai/datamodel/test_json_schema.py
index 685dbd2..05521a1 100644
--- a/libs/core/kiln_ai/datamodel/test_json_schema.py
+++ b/libs/core/kiln_ai/datamodel/test_json_schema.py
@@ -73,3 +73,52 @@ def test_validate_schema_content():
         validate_schema(0, json_joke_schema)
     o = {"setup": "asdf", "punchline": "asdf"}
     validate_schema(o, json_joke_schema)
+    o = {"setup": "asdf", "punchline": "asdf", "rating": "1"}
+    with pytest.raises(Exception):
+        validate_schema(o, json_joke_schema)
+
+
+json_triangle_schema = """{
+  "type": "object",
+  "properties": {
+    "a": {
+      "description": "length of side a",
+      "title": "A",
+      "type": "integer"
+    },
+    "b": {
+      "description": "length of side b",
+      "title": "B",
+      "type": "integer"
+    },
+    "c": {
+      "description": "length of side c",
+      "title": "C",
+      "type": "integer"
+    }
+  },
+  "required": [
+    "a",
+    "b",
+    "c"
+  ]
+}
+"""
+
+
+def test_triangle_schema():
+    o = ExampleModel(x_schema=json_joke_schema)
+    parsed_schema = schema_from_json_str(o.x_schema)
+    assert parsed_schema is not None
+
+    o = ExampleModel(x_schema=json_triangle_schema)
+    schema = schema_from_json_str(o.x_schema)
+
+    assert schema is not None
+    assert schema["properties"]["a"]["type"] == "integer"
+    assert schema["properties"]["b"]["type"] == "integer"
+    assert schema["properties"]["c"]["type"] == "integer"
+    assert schema["required"] == ["a", "b", "c"]
+    validate_schema({"a": 1, "b": 2, "c": 3}, json_triangle_schema)
+    with pytest.raises(Exception):
+        validate_schema({"a": 1, "b": 2, "c": "3"}, json_triangle_schema)
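
Not part of the patch itself: a minimal usage sketch of the new structured-input path, mirroring the test code above. The project/task setup, the triangle schema import, and the model/provider pair ("llama_3_1_8b" on "groq") are taken from the tests; the function name, working directory, and the shortened instruction text are illustrative assumptions, not documented API usage.

    import asyncio
    from pathlib import Path

    import kiln_ai.datamodel.models as models
    from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
    from kiln_ai.datamodel.test_json_schema import json_triangle_schema


    async def classify_triangle(working_dir: Path):
        # Build a task whose input is validated against a JSON schema before any model call.
        working_dir.mkdir(parents=True, exist_ok=True)
        project = models.Project(name="demo", path=working_dir / "demo.kiln")
        project.save_to_file()
        task = models.Task(
            parent=project,
            name="triangle classifier",
            # Shortened placeholder prompt; the tests use a longer instruction.
            instruction="Classify the triangle as equilateral, isosceles or scalene.",
        )
        task.input_json_schema = json_triangle_schema
        task.save_to_file()

        adapter = LangChainPromptAdapter(task, model_name="llama_3_1_8b", provider="groq")

        # invoke() raises ValueError for non-dict input and a jsonschema ValidationError
        # for a dict that does not match input_json_schema, before calling the model.
        return await adapter.invoke({"a": 3, "b": 4, "c": 5})


    if __name__ == "__main__":
        print(asyncio.run(classify_triangle(Path("/tmp/triangle_demo"))))

Validating the input up front keeps that check in BaseAdapter.invoke(), so provider-specific adapters like LangChainPromptAdapter only need to implement _run().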