From f4b526995331d0fdfe5a43f1a5f0aa3c2c1fc6fc Mon Sep 17 00:00:00 2001 From: scosman Date: Tue, 3 Sep 2024 22:19:41 -0400 Subject: [PATCH] Check the returned data matches the required schema before returning it --- libs/core/kiln_ai/adapters/base_adapter.py | 25 +++++++- .../kiln_ai/adapters/langchain_adapters.py | 17 +++-- libs/core/kiln_ai/adapters/prompt_builders.py | 2 +- .../kiln_ai/adapters/test_base_adapter.py | 0 .../kiln_ai/adapters/test_prompt_adaptors.py | 4 +- .../adapters/test_structured_output.py | 63 +++++++++++++++---- libs/core/kiln_ai/datamodel/json_schema.py | 10 +-- libs/core/kiln_ai/datamodel/models.py | 4 ++ .../kiln_ai/datamodel/test_json_schema.py | 16 ++++- 9 files changed, 109 insertions(+), 32 deletions(-) create mode 100644 libs/core/kiln_ai/adapters/test_base_adapter.py diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py index aaf23e5..02bc13c 100644 --- a/libs/core/kiln_ai/adapters/base_adapter.py +++ b/libs/core/kiln_ai/adapters/base_adapter.py @@ -1,11 +1,34 @@ from abc import ABCMeta, abstractmethod +from typing import Dict +from kiln_ai.datamodel.json_schema import validate_schema from kiln_ai.datamodel.models import Task class BaseAdapter(metaclass=ABCMeta): + def __init__(self, kiln_task: Task): + self.kiln_task = kiln_task + self._is_structured = self.kiln_task.output_json_schema is not None + + async def invoke(self, input: str) -> Dict | str: + result = await self._run(input) + if self._is_structured: + if not isinstance(result, dict): + raise RuntimeError(f"structured response is not a dict: {result}") + if self.kiln_task.output_json_schema is None: + raise ValueError( + f"output_json_schema is not set for task {self.kiln_task.name}" + ) + validate_schema(result, self.kiln_task.output_json_schema) + else: + if not isinstance(result, str): + raise RuntimeError( + f"response is not a string for non-structured task: {result}" + ) + return result + @abstractmethod - async def 
run(self, input: str) -> str: + async def _run(self, input: str) -> Dict | str: pass diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py index 76e0403..58570a0 100644 --- a/libs/core/kiln_ai/adapters/langchain_adapters.py +++ b/libs/core/kiln_ai/adapters/langchain_adapters.py @@ -1,4 +1,5 @@ import json +from typing import Dict import kiln_ai.datamodel.models as models from kiln_ai.adapters.prompt_builders import SimplePromptBuilder @@ -18,8 +19,7 @@ def __init__( provider: str | None = None, prompt_builder: BasePromptBuilder | None = None, ): - self.kiln_task = kiln_task - self.__is_structured = False + super().__init__(kiln_task) if custom_model is not None: self.model = custom_model elif model_name is not None: @@ -28,36 +28,34 @@ def __init__( raise ValueError( "model_name and provider must be provided if custom_model is not provided" ) - if self.kiln_task.output_json_schema is not None: + if self._is_structured: if not hasattr(self.model, "with_structured_output") or not callable( getattr(self.model, "with_structured_output") ): raise ValueError( f"model {self.model} does not support structured output, cannot use output_json_schema" ) + # Langchain expects title/description to be at top level, on top of json schema output_schema = self.kiln_task.output_schema() if output_schema is None: raise ValueError( f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}" ) - # Langchain expects title/description to be at top level, on top of json schema output_schema["title"] = "task_response" output_schema["description"] = "A response from the task" self.model = self.model.with_structured_output( output_schema, include_raw=True ) - self.__is_structured = True if prompt_builder is None: self.prompt_builder = SimplePromptBuilder(kiln_task) else: self.prompt_builder = prompt_builder - # TODO: don't just append input to prompt - async def run(self, input: str) -> str: + async def _run(self, 
input: str) -> Dict | str: # TODO cleanup prompt = self.prompt_builder.build_prompt(input) response = self.model.invoke(prompt) - if self.__is_structured: + if self._is_structured: if ( not isinstance(response, dict) or "parsed" not in response @@ -65,8 +63,7 @@ async def run(self, input: str) -> str: ): raise RuntimeError(f"structured response not returned: {response}") structured_response = response["parsed"] - # TODO: not JSON, use a dict here - return json.dumps(structured_response) + return structured_response else: if not isinstance(response, BaseMessage): raise RuntimeError(f"response is not a BaseMessage: {response}") diff --git a/libs/core/kiln_ai/adapters/prompt_builders.py b/libs/core/kiln_ai/adapters/prompt_builders.py index 27205e3..8cce52c 100644 --- a/libs/core/kiln_ai/adapters/prompt_builders.py +++ b/libs/core/kiln_ai/adapters/prompt_builders.py @@ -14,6 +14,6 @@ def build_prompt(self, input: str) -> str: for i, requirement in enumerate(self.task.requirements()): base_prompt += f"{i+1}) {requirement.instruction}\n" - # TODO: should be another message + # TODO: should be another message, not just appended to prompt base_prompt += f"\n\nThe input is:\n{input}" return base_prompt diff --git a/libs/core/kiln_ai/adapters/test_base_adapter.py b/libs/core/kiln_ai/adapters/test_base_adapter.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py index 3b7e903..50cefcd 100644 --- a/libs/core/kiln_ai/adapters/test_prompt_adaptors.py +++ b/libs/core/kiln_ai/adapters/test_prompt_adaptors.py @@ -63,7 +63,7 @@ async def test_mock(tmp_path): task = build_test_task(tmp_path) mockChatModel = FakeListChatModel(responses=["mock response"]) adapter = LangChainPromptAdapter(task, custom_model=mockChatModel) - answer = await adapter.run("You are a mock, send me the response!") + answer = await adapter.invoke("You are a mock, send me the response!") assert "mock 
response" in answer @@ -118,7 +118,7 @@ async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None async def run_simple_task(task: models.Task, model_name: str, provider: str): adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider) - answer = await adapter.run( + answer = await adapter.invoke( "You should answer the following question: four plus six times 10" ) assert "64" in answer diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py index 1d51dec..2f148cc 100644 --- a/libs/core/kiln_ai/adapters/test_structured_output.py +++ b/libs/core/kiln_ai/adapters/test_structured_output.py @@ -1,8 +1,9 @@ -import json from pathlib import Path +from typing import Dict import kiln_ai.datamodel.models as models import pytest +from kiln_ai.adapters.base_adapter import BaseAdapter from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter from kiln_ai.adapters.ml_model_list import ( built_in_models, @@ -42,30 +43,65 @@ async def test_structured_output_ollama_llama(tmp_path): await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama") +class MockAdapter(BaseAdapter): + def __init__(self, kiln_task: models.Task, response: Dict | str | None): + super().__init__(kiln_task) + self.response = response + + async def _run(self, input: str) -> Dict | str: + return self.response + + +async def test_mock_unstructured_response(tmp_path): + task = build_structured_output_test_task(tmp_path) + + # don't error on valid response + adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"}) + answer = await adapter.invoke("You are a mock, send me the response!") + assert answer["setup"] == "asdf" + assert answer["punchline"] == "asdf" + + # error on response that doesn't match schema + adapter = MockAdapter(task, response={"setup": "asdf"}) + with pytest.raises(Exception): + answer = await adapter.invoke("You are a mock, send me the response!") + 
+ adapter = MockAdapter(task, response="string instead of dict") + with pytest.raises(RuntimeError): + # Not a structured response so should error + answer = await adapter.invoke("You are a mock, send me the response!") + + # Should error, expecting a string, not a dict + project = models.Project(name="test", path=tmp_path / "test.kiln") + task = models.Task(parent=project, name="test task") + task.instruction = ( + "You are an assistant which performs math tasks provided in plain text." + ) + adapter = MockAdapter(task, response={"dict": "value"}) + with pytest.raises(RuntimeError): + answer = await adapter.invoke("You are a mock, send me the response!") + + @pytest.mark.paid @pytest.mark.ollama async def test_all_built_in_models_structured_output(tmp_path): for model in built_in_models: if not model.supports_structured_output: print( - f"Skipping {model.model_name} because it does not support structured output" + f"Skipping {model.name} because it does not support structured output" ) continue for provider in model.providers: if not provider.supports_structured_output: print( - f"Skipping {model.model_name} {provider.name} because it does not support structured output" + f"Skipping {model.name} {provider.name} because it does not support structured output" ) continue try: - print(f"Running {model.model_name} {provider}") - await run_structured_output_test( - tmp_path, model.model_name, provider.name - ) + print(f"Running {model.name} {provider}") + await run_structured_output_test(tmp_path, model.name, provider.name) except Exception as e: - raise RuntimeError( - f"Error running {model.model_name} {provider}" - ) from e + raise RuntimeError(f"Error running {model.name} {provider}") from e def build_structured_output_test_task(tmp_path: Path): @@ -90,11 +126,12 @@ async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str): task = build_structured_output_test_task(tmp_path) a = 
LangChainPromptAdapter(task, model_name=model_name, provider=provider) - result = await a.run("Cows") # a joke about cows - parsed = json.loads(result) + parsed = await a.invoke("Cows") # a joke about cows + if parsed is None or not isinstance(parsed, Dict): + raise RuntimeError(f"structured response is not a dict: {parsed}") assert parsed["setup"] is not None assert parsed["punchline"] is not None - if "rating" in parsed: + if "rating" in parsed and parsed["rating"] is not None: rating = parsed["rating"] # Note: really should be an int according to json schema, but mistral returns a string if isinstance(rating, str): diff --git a/libs/core/kiln_ai/datamodel/json_schema.py b/libs/core/kiln_ai/datamodel/json_schema.py index 9fc1cfe..50dbfa5 100644 --- a/libs/core/kiln_ai/datamodel/json_schema.py +++ b/libs/core/kiln_ai/datamodel/json_schema.py @@ -17,10 +17,12 @@ def _check_json_schema(v: str) -> str: return v -def schema_from_json_str(v: str | None) -> Dict | None: - if v is None: - # Allowing None for now, may make this required later - return None +def validate_schema(instance: Dict, schema_str: str) -> None: + schema = schema_from_json_str(schema_str) + jsonschema.validate(instance=instance, schema=schema) + + +def schema_from_json_str(v: str) -> Dict: try: parsed = json.loads(v) jsonschema.Draft202012Validator.check_schema(parsed) diff --git a/libs/core/kiln_ai/datamodel/models.py b/libs/core/kiln_ai/datamodel/models.py index 3bdf524..aca7bb7 100644 --- a/libs/core/kiln_ai/datamodel/models.py +++ b/libs/core/kiln_ai/datamodel/models.py @@ -55,9 +55,13 @@ class Task(KilnParentedModel): input_json_schema: JsonObjectSchema | None = None def output_schema(self) -> Dict | None: + if self.output_json_schema is None: + return None return schema_from_json_str(self.output_json_schema) def input_schema(self) -> Dict | None: + if self.input_json_schema is None: + return None return schema_from_json_str(self.input_json_schema) @classmethod diff --git 
a/libs/core/kiln_ai/datamodel/test_json_schema.py b/libs/core/kiln_ai/datamodel/test_json_schema.py index ed3185e..685dbd2 100644 --- a/libs/core/kiln_ai/datamodel/test_json_schema.py +++ b/libs/core/kiln_ai/datamodel/test_json_schema.py @@ -1,5 +1,9 @@ import pytest -from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str +from kiln_ai.datamodel.json_schema import ( + JsonObjectSchema, + schema_from_json_str, + validate_schema, +) from pydantic import BaseModel @@ -59,3 +63,13 @@ def test_json_schema(): o = ExampleModel(x_schema="{'asdf':{}}") with pytest.raises(ValueError): o = ExampleModel(x_schema="{asdf") + + +def test_validate_schema_content(): + o = {"setup": "asdf", "punchline": "asdf", "rating": 1} + validate_schema(o, json_joke_schema) + o = {"setup": "asdf"} + with pytest.raises(Exception): + validate_schema(o, json_joke_schema) + o = {"setup": "asdf", "punchline": "asdf"} + validate_schema(o, json_joke_schema)