Commit
Separate prompt builder concept from adapter.
scosman committed Sep 3, 2024
1 parent 0a5453d commit 4e56611
Showing 6 changed files with 48 additions and 30 deletions.
5 changes: 3 additions & 2 deletions .vscode/settings.json
@@ -4,7 +4,8 @@
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
}
},
"editor.tabSize": 4,
},
"eslint.validate": ["javascript", "javascriptreact", "svelte"],
// Svelte, JS, TS files
@@ -24,5 +25,5 @@
"**/.ruff_cache": true,
"**/dist": true,
"**/node_modules": true
}
},
}
11 changes: 11 additions & 0 deletions libs/core/kiln_ai/adapters/base_adapter.py
@@ -1,7 +1,18 @@
from abc import ABCMeta, abstractmethod

from kiln_ai.datamodel.models import Task


class BaseAdapter(metaclass=ABCMeta):
    @abstractmethod
    async def run(self, input: str) -> str:
        pass


class BasePromptBuilder(metaclass=ABCMeta):
    def __init__(self, task: Task):
        self.task = task

    @abstractmethod
    def build_prompt(self, input: str) -> str:
        pass
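
With prompt building split into its own abstraction, a new prompt strategy can be added by subclassing BasePromptBuilder instead of the adapter. Below is a minimal sketch of such a subclass, assuming only the Task members this commit already uses (instruction, requirements(), and requirement.instruction); the BulletListPromptBuilder name is illustrative and not part of this change:

from kiln_ai.adapters.base_adapter import BasePromptBuilder


class BulletListPromptBuilder(BasePromptBuilder):
    # Hypothetical builder: renders requirements as a bulleted list instead of
    # the numbered list that SimplePromptBuilder produces.
    def build_prompt(self, input: str) -> str:
        prompt = self.task.instruction
        requirements = self.task.requirements()
        if len(requirements) > 0:
            prompt += "\n\nYour response should respect the following requirements:\n"
            for requirement in requirements:
                prompt += f"- {requirement.instruction}\n"
        prompt += f"\n\nThe input is:\n{input}"
        return prompt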
32 changes: 9 additions & 23 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -1,21 +1,22 @@
import json
from abc import ABCMeta, abstractmethod

import kiln_ai.datamodel.models as models
from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages.base import BaseMessage

from .base_adapter import BaseAdapter
from .base_adapter import BaseAdapter, BasePromptBuilder
from .ml_model_list import langchain_model_from


class BaseLangChainPromptAdapter(BaseAdapter, metaclass=ABCMeta):
class LangChainPromptAdapter(BaseAdapter):
    def __init__(
        self,
        kiln_task: models.Task,
        custom_model: BaseChatModel | None = None,
        model_name: str | None = None,
        provider: str | None = None,
        prompt_builder: BasePromptBuilder | None = None,
    ):
        self.kiln_task = kiln_task
        self.__is_structured = False
@@ -46,16 +47,15 @@ def __init__(
                output_schema, include_raw=True
            )
            self.__is_structured = True

    @abstractmethod
    def build_prompt(self) -> str:
        pass
        if prompt_builder is None:
            self.prompt_builder = SimplePromptBuilder(kiln_task)
        else:
            self.prompt_builder = prompt_builder

    # TODO: don't just append input to prompt
    async def run(self, input: str) -> str:
        # TODO cleanup
        prompt = self.build_prompt()
        prompt += f"\n\n{input}"
        prompt = self.prompt_builder.build_prompt(input)
        response = self.model.invoke(prompt)
        if self.__is_structured:
            if (
@@ -74,17 +74,3 @@ async def run(self, input: str) -> str:
        if not isinstance(text_content, str):
            raise RuntimeError(f"response is not a string: {text_content}")
        return text_content


class SimplePromptAdapter(BaseLangChainPromptAdapter):
    def build_prompt(self) -> str:
        base_prompt = self.kiln_task.instruction

        # TODO: this is just a quick version. Formatting and best practices TBD
        if len(self.kiln_task.requirements()) > 0:
            base_prompt += "\n\nYou should requect the following requirements:\n"
            # iterate requirements, formatting them in numbereed list like 1) task.instruction\n2)...
            for i, requirement in enumerate(self.kiln_task.requirements()):
                base_prompt += f"{i+1}) {requirement.instruction}\n"

        return base_prompt
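
The adapter now delegates prompt construction entirely: run() calls prompt_builder.build_prompt(input), and the constructor falls back to SimplePromptBuilder when no builder is passed. A rough usage sketch of the new constructor, assuming a models.Task has already been built (task setup is elided here; FakeListChatModel stands in for a real model, as in the tests below):

from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
from langchain_core.language_models.fake_chat_models import FakeListChatModel


async def demo(task):
    mock_model = FakeListChatModel(responses=["mock response"])

    # Default: the adapter constructs a SimplePromptBuilder(task) internally.
    adapter = LangChainPromptAdapter(task, custom_model=mock_model)

    # Explicit: any BasePromptBuilder subclass can be injected instead.
    adapter = LangChainPromptAdapter(
        task,
        custom_model=mock_model,
        prompt_builder=SimplePromptBuilder(task),
    )
    return await adapter.run("You are a mock, send me the response!")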
19 changes: 19 additions & 0 deletions libs/core/kiln_ai/adapters/prompt_builders.py
@@ -0,0 +1,19 @@
from kiln_ai.adapters.base_adapter import BasePromptBuilder


class SimplePromptBuilder(BasePromptBuilder):
    def build_prompt(self, input: str) -> str:
        base_prompt = self.task.instruction

        # TODO: this is just a quick version. Formatting and best practices TBD
        if len(self.task.requirements()) > 0:
            base_prompt += (
                "\n\nYour response should respect the following requirements:\n"
            )
            # iterate requirements, formatting them in a numbered list like 1) task.instruction\n2)...
            for i, requirement in enumerate(self.task.requirements()):
                base_prompt += f"{i+1}) {requirement.instruction}\n"

        # TODO: should be another message
        base_prompt += f"\n\nThe input is:\n{input}"
        return base_prompt
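
To make the builder's output concrete, here is a small sketch that drives SimplePromptBuilder with hypothetical stand-ins mimicking the two Task members build_prompt reads (instruction and requirements()); a real caller would pass a kiln_ai Task instead:

from kiln_ai.adapters.prompt_builders import SimplePromptBuilder


class _FakeRequirement:
    # Stand-in for a task requirement; only .instruction is read by the builder.
    def __init__(self, instruction: str):
        self.instruction = instruction


class _FakeTask:
    # Stand-in for kiln_ai.datamodel.models.Task, for illustration only.
    instruction = "Answer the question in one sentence."

    def requirements(self):
        return [_FakeRequirement("Be concise"), _FakeRequirement("Use plain language")]


prompt = SimplePromptBuilder(_FakeTask()).build_prompt("What is 2 + 2?")
print(prompt)
# Prints roughly:
#   Answer the question in one sentence.
#
#   Your response should respect the following requirements:
#   1) Be concise
#   2) Use plain language
#
#   The input is:
#   What is 2 + 2?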
7 changes: 4 additions & 3 deletions libs/core/kiln_ai/adapters/test_prompt_adaptors.py
@@ -3,8 +3,8 @@

import kiln_ai.datamodel.models as models
import pytest
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.ml_model_list import built_in_models, ollama_online
from kiln_ai.adapters.prompt_adapters import SimplePromptAdapter
from langchain_core.language_models.fake_chat_models import FakeListChatModel


@@ -62,7 +62,7 @@ async def test_amazon_bedrock(tmp_path):
async def test_mock(tmp_path):
    task = build_test_task(tmp_path)
    mockChatModel = FakeListChatModel(responses=["mock response"])
    adapter = SimplePromptAdapter(task, custom_model=mockChatModel)
    adapter = LangChainPromptAdapter(task, custom_model=mockChatModel)
    answer = await adapter.run("You are a mock, send me the response!")
    assert "mock response" in answer

@@ -116,7 +116,8 @@ async def run_simple_test(tmp_path: Path, model_name: str, provider: str | None


async def run_simple_task(task: models.Task, model_name: str, provider: str):
    adapter = SimplePromptAdapter(task, model_name=model_name, provider=provider)
    adapter = LangChainPromptAdapter(task, model_name=model_name, provider=provider)

    answer = await adapter.run(
        "You should answer the following question: four plus six times 10"
    )
4 changes: 2 additions & 2 deletions libs/core/kiln_ai/adapters/test_structured_output.py
@@ -3,11 +3,11 @@

import kiln_ai.datamodel.models as models
import pytest
from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.ml_model_list import (
    built_in_models,
    ollama_online,
)
from kiln_ai.adapters.prompt_adapters import SimplePromptAdapter
from kiln_ai.datamodel.test_models import json_joke_schema


@@ -89,7 +89,7 @@ def build_structured_output_test_task(tmp_path: Path):

async def run_structured_output_test(tmp_path: Path, model_name: str, provider: str):
    task = build_structured_output_test_task(tmp_path)
    a = SimplePromptAdapter(task, model_name=model_name, provider=provider)
    a = LangChainPromptAdapter(task, model_name=model_name, provider=provider)
    result = await a.run("Cows")  # a joke about cows
    parsed = json.loads(result)
    assert parsed["setup"] is not None
