Check the returned data matches the required schema before returning it
scosman committed Sep 4, 2024
1 parent 4b6d6bb commit f4b5269
Showing 9 changed files with 109 additions and 32 deletions.
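
The heart of the change is a validate-before-return step: structured results are now checked against the task's output JSON schema (via the jsonschema library, as the diffs below show) before invoke() hands them to the caller. A standalone sketch of that check, using an illustrative joke schema rather than the repo's exact test schema:

    import json

    import jsonschema

    # Illustrative schema string; Kiln stores schemas as JSON strings on the Task model.
    schema_str = """
    {
        "type": "object",
        "properties": {
            "setup": {"type": "string"},
            "punchline": {"type": "string"}
        },
        "required": ["setup", "punchline"]
    }
    """

    schema = json.loads(schema_str)
    jsonschema.Draft202012Validator.check_schema(schema)  # reject malformed schemas early
    jsonschema.validate(instance={"setup": "a", "punchline": "b"}, schema=schema)  # passes
    jsonschema.validate(instance={"setup": "a"}, schema=schema)  # raises ValidationError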
25 changes: 24 additions & 1 deletion libs/core/kiln_ai/adapters/base_adapter.py
@@ -1,11 +1,34 @@
 from abc import ABCMeta, abstractmethod
+from typing import Dict
 
+from kiln_ai.datamodel.json_schema import validate_schema
 from kiln_ai.datamodel.models import Task
 
 
 class BaseAdapter(metaclass=ABCMeta):
     def __init__(self, kiln_task: Task):
         self.kiln_task = kiln_task
+        self._is_structured = self.kiln_task.output_json_schema is not None
+
+    async def invoke(self, input: str) -> Dict | str:
+        result = await self._run(input)
+        if self._is_structured:
+            if not isinstance(result, dict):
+                raise RuntimeError(f"structured response is not a dict: {result}")
+            if self.kiln_task.output_json_schema is None:
+                raise ValueError(
+                    f"output_json_schema is not set for task {self.kiln_task.name}"
+                )
+            validate_schema(result, self.kiln_task.output_json_schema)
+        else:
+            if not isinstance(result, str):
+                raise RuntimeError(
+                    f"response is not a string for non-structured task: {result}"
+                )
+        return result
 
     @abstractmethod
-    async def run(self, input: str) -> str:
+    async def _run(self, input: str) -> Dict | str:
         pass


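With this split, concrete adapters implement the protected _run() while every caller goes through invoke(), which owns validation. A minimal sketch of a conforming subclass, mirroring the MockAdapter added to the tests below (CannedAdapter is an illustrative name, not part of the commit):

    from typing import Dict

    from kiln_ai.adapters.base_adapter import BaseAdapter
    from kiln_ai.datamodel.models import Task

    class CannedAdapter(BaseAdapter):
        # Returns a fixed response; invoke() still schema-checks it on the way out.
        def __init__(self, kiln_task: Task, response: Dict | str):
            super().__init__(kiln_task)
            self.response = response

        async def _run(self, input: str) -> Dict | str:
            return self.response

    # For a task with an output schema, invoke() returns the dict untouched;
    # a response missing a required key raises instead of being returned:
    #     await CannedAdapter(task, {"setup": "s", "punchline": "p"}).invoke("hi")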
17 changes: 7 additions & 10 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -1,4 +1,5 @@
 import json
+from typing import Dict
 
 import kiln_ai.datamodel.models as models
 from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
@@ -18,8 +19,7 @@ def __init__(
         provider: str | None = None,
         prompt_builder: BasePromptBuilder | None = None,
     ):
-        self.kiln_task = kiln_task
-        self.__is_structured = False
+        super().__init__(kiln_task)
         if custom_model is not None:
             self.model = custom_model
         elif model_name is not None:
@@ -28,45 +28,42 @@ def __init__(
             raise ValueError(
                 "model_name and provider must be provided if custom_model is not provided"
             )
-        if self.kiln_task.output_json_schema is not None:
+        if self._is_structured:
             if not hasattr(self.model, "with_structured_output") or not callable(
                 getattr(self.model, "with_structured_output")
             ):
                 raise ValueError(
                     f"model {self.model} does not support structured output, cannot use output_json_schema"
                 )
-            # Langchain expects title/description to be at top level, on top of json schema
             output_schema = self.kiln_task.output_schema()
             if output_schema is None:
                 raise ValueError(
                     f"output_json_schema is not valid json: {self.kiln_task.output_json_schema}"
                 )
+            # Langchain expects title/description to be at top level, on top of json schema
             output_schema["title"] = "task_response"
             output_schema["description"] = "A response from the task"
             self.model = self.model.with_structured_output(
                 output_schema, include_raw=True
             )
-            self.__is_structured = True
         if prompt_builder is None:
             self.prompt_builder = SimplePromptBuilder(kiln_task)
         else:
             self.prompt_builder = prompt_builder
 
     # TODO: don't just append input to prompt
-    async def run(self, input: str) -> str:
+    async def _run(self, input: str) -> Dict | str:
         # TODO cleanup
         prompt = self.prompt_builder.build_prompt(input)
         response = self.model.invoke(prompt)
-        if self.__is_structured:
+        if self._is_structured:
             if (
                 not isinstance(response, dict)
                 or "parsed" not in response
                 or not isinstance(response["parsed"], dict)
             ):
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
-            # TODO: not JSON, use a dict here
-            return json.dumps(structured_response)
+            return structured_response
         else:
             if not isinstance(response, BaseMessage):
                 raise RuntimeError(f"response is not a BaseMessage: {response}")
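For context on the "parsed" lookup above: a LangChain model wrapped with with_structured_output(..., include_raw=True) returns a dict from invoke() rather than a message. Per LangChain's documented behavior, the shape is roughly:

    # Shape of `response` inside _run once structured output is enabled:
    # {
    #     "raw": BaseMessage,             # the unparsed model message
    #     "parsed": dict | None,          # data matching the bound schema, or None on parse failure
    #     "parsing_error": Exception | None,
    # }
    #
    # _run now returns response["parsed"] as a plain dict; BaseAdapter.invoke()
    # applies the task's JSON-schema check, where run() previously json.dumps()'d it.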
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/prompt_builders.py
@@ -14,6 +14,6 @@ def build_prompt(self, input: str) -> str:
         for i, requirement in enumerate(self.task.requirements()):
             base_prompt += f"{i+1}) {requirement.instruction}\n"
 
-        # TODO: should be another message
+        # TODO: should be another message, not just appended to prompt
         base_prompt += f"\n\nThe input is:\n{input}"
         return base_prompt
Empty file.
4 changes: 2 additions & 2 deletions libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -63,7 +63,7 @@ async def test_mock(tmp_path):
     task = build_test_task(tmp_path)
     mockChatModel = FakeListChatModel(responses=["mock response"])
     adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
-    answer = await adapter.run("You are a mock, send me the response!")
+    answer = await adapter.invoke("You are a mock, send me the response!")
     assert "mock response" in answer


@@ -118,7 +118,7 @@ async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None
 async def run_simple_task(task: models.Task, model_name: str, provider: str):
     adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
 
-    answer = await adapter.run(
+    answer = await adapter.invoke(
         "You should answer the following question: four plus six times 10"
     )
     assert "64" in answer
63 changes: 50 additions & 13 deletions libs/core/kiln_ai/adapters/test_structured_output.py
@@ -1,8 +1,9 @@
-import json
 from pathlib import Path
+from typing import Dict
 
 import kiln_ai.datamodel.models as models
 import pytest
+from kiln_ai.adapters.base_adapter import BaseAdapter
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
@@ -42,30 +43,65 @@ async def test_structured_output_ollama_llama(tmp_path):
     await run_structured_output_test(tmp_path, "llama_3_1_8b", "ollama")
 
 
+class MockAdapter(BaseAdapter):
+    def __init__(self, kiln_task: models.Task, response: Dict | str | None):
+        super().__init__(kiln_task)
+        self.response = response
+
+    async def _run(self, input: str) -> Dict | str:
+        return self.response
+
+
+async def test_mock_unstructred_response(tmp_path):
+    task = build_structured_output_test_task(tmp_path)
+
+    # don't error on valid response
+    adapter = MockAdapter(task, response={"setup": "asdf", "punchline": "asdf"})
+    answer = await adapter.invoke("You are a mock, send me the response!")
+    assert answer["setup"] == "asdf"
+    assert answer["punchline"] == "asdf"
+
+    # error on response that doesn't match schema
+    adapter = MockAdapter(task, response={"setup": "asdf"})
+    with pytest.raises(Exception):
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+    adapter = MockAdapter(task, response="string instead of dict")
+    with pytest.raises(RuntimeError):
+        # Not a structured response so should error
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+    # Should error, expecting a string, not a dict
+    project = models.Project(name="test", path=tmp_path / "test.kiln")
+    task = models.Task(parent=project, name="test task")
+    task.instruction = (
+        "You are an assistant which performs math tasks provided in plain text."
+    )
+    adapter = MockAdapter(task, response={"dict": "value"})
+    with pytest.raises(RuntimeError):
+        answer = await adapter.invoke("You are a mock, send me the response!")
+
+
 @pytest.mark.paid
 @pytest.mark.ollama
 async def test_all_built_in_models_structured_output(tmp_path):
     for model in built_in_models:
         if not model.supports_structured_output:
             print(
-                f"Skipping {model.model_name} because it does not support structured output"
+                f"Skipping {model.name} because it does not support structured output"
             )
             continue
         for provider in model.providers:
             if not provider.supports_structured_output:
                 print(
-                    f"Skipping {model.model_name} {provider.name} because it does not support structured output"
+                    f"Skipping {model.name} {provider.name} because it does not support structured output"
                 )
                 continue
             try:
-                print(f"Running {model.model_name} {provider}")
-                await run_structured_output_test(
-                    tmp_path, model.model_name, provider.name
-                )
+                print(f"Running {model.name} {provider}")
+                await run_structured_output_test(tmp_path, model.name, provider.name)
             except Exception as e:
-                raise RuntimeError(
-                    f"Error running {model.model_name} {provider}"
-                ) from e
+                raise RuntimeError(f"Error running {model.name} {provider}") from e
 
 
 def build_structured_output_test_task(tmp_path: Path):
@@ -90,11 +126,12 @@ def build_structured_output_test_task(tmp_path: Path):
 async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
     task = build_structured_output_test_task(tmp_path)
     a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
-    result = await a.run("Cows")  # a joke about cows
-    parsed = json.loads(result)
+    parsed = await a.invoke("Cows")  # a joke about cows
     if parsed is None or not isinstance(parsed, Dict):
         raise RuntimeError(f"structured response is not a dict: {parsed}")
     assert parsed["setup"] is not None
     assert parsed["punchline"] is not None
-    if "rating" in parsed:
+    if "rating" in parsed and parsed["rating"] is not None:
         rating = parsed["rating"]
         # Note: really should be an int according to json schema, but mistral returns a string
         if isinstance(rating, str):
10 changes: 6 additions & 4 deletions libs/core/kiln_ai/datamodel/json_schema.py
@@ -17,10 +17,12 @@ def _check_json_schema(v: str) -> str:
     return v
 
 
-def schema_from_json_str(v: str | None) -> Dict | None:
-    if v is None:
-        # Allowing None for now, may make this required later
-        return None
+def validate_schema(instance: Dict, schema_str: str) -> None:
+    schema = schema_from_json_str(schema_str)
+    jsonschema.validate(instance=instance, schema=schema)
+
+
+def schema_from_json_str(v: str) -> Dict:
     try:
         parsed = json.loads(v)
         jsonschema.Draft202012Validator.check_schema(parsed)
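A usage sketch for the reworked pair: schema_from_json_str parses and sanity-checks the schema string, and validate_schema applies it to an instance, raising on mismatch (the schema below is illustrative):

    from kiln_ai.datamodel.json_schema import schema_from_json_str, validate_schema

    schema_str = '{"type": "object", "properties": {"x": {"type": "integer"}}, "required": ["x"]}'
    schema_from_json_str(schema_str)  # -> parsed dict; raises if the string is not a valid schema
    validate_schema({"x": 1}, schema_str)  # returns None: the instance matches
    validate_schema({}, schema_str)  # raises jsonschema.ValidationError: "x" is required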
4 changes: 4 additions & 0 deletions libs/core/kiln_ai/datamodel/models.py
@@ -55,9 +55,13 @@ class Task(KilnParentedModel):
     input_json_schema: JsonObjectSchema | None = None
 
     def output_schema(self) -> Dict | None:
+        if self.output_json_schema is None:
+            return None
         return schema_from_json_str(self.output_json_schema)
 
     def input_schema(self) -> Dict | None:
+        if self.input_json_schema is None:
+            return None
         return schema_from_json_str(self.input_json_schema)
 
     @classmethod
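The None handling moves with it: the Task accessors now return None when no schema is set, so schema_from_json_str can require a real string. A sketch of the resulting behavior, set up the same way as the tests above:

    # Assuming `project` exists, as in models.Project(name="test", path=...).
    task = models.Task(parent=project, name="no schema")
    assert task.output_schema() is None  # output_json_schema is unset

    task.output_json_schema = '{"type": "object", "properties": {}}'
    task.output_schema()  # -> parsed dict, schema-checked by schema_from_json_str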
16 changes: 15 additions & 1 deletion libs/core/kiln_ai/datamodel/test_json_schema.py
@@ -1,5 +1,9 @@
 import pytest
-from kiln_ai.datamodel.json_schema import JsonObjectSchema, schema_from_json_str
+from kiln_ai.datamodel.json_schema import (
+    JsonObjectSchema,
+    schema_from_json_str,
+    validate_schema,
+)
 from pydantic import BaseModel

@@ -59,3 +63,13 @@ def test_json_schema():
     o = ExampleModel(x_schema="{'asdf':{}}")
     with pytest.raises(ValueError):
         o = ExampleModel(x_schema="{asdf")
+
+
+def test_validate_schema_content():
+    o = {"setup": "asdf", "punchline": "asdf", "rating": 1}
+    validate_schema(o, json_joke_schema)
+    o = {"setup": "asdf"}
+    with pytest.raises(Exception):
+        validate_schema(o, json_joke_schema)
+    o = {"setup": "asdf", "punchline": "asdf"}
+    validate_schema(o, json_joke_schema)