Skip to content

Commit

Permalink
Move to user message vs system message
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Sep 10, 2024
1 parent 174076e commit 6416c78
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 16 deletions.
6 changes: 5 additions & 1 deletion libs/core/kiln_ai/adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,9 @@ def __init__(self, task: Task, adapter: BaseAdapter | None = None):
self.adapter = adapter

@abstractmethod
def build_prompt(self, input: str) -> str:
def build_prompt(self) -> str:
pass

@abstractmethod
def build_user_message(self, input: str) -> str:
pass
11 changes: 8 additions & 3 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import kiln_ai.datamodel.models as models
from kiln_ai.adapters.prompt_builders import SimplePromptBuilder
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage

from .base_adapter import BaseAdapter, BasePromptBuilder
Expand Down Expand Up @@ -58,9 +59,13 @@ def adapter_specific_instructions(self) -> str | None:
return None

async def _run(self, input: str) -> Dict | str:
# TODO cleanup
prompt = self.prompt_builder.build_prompt(input)
response = self.model.invoke(prompt)
prompt = self.prompt_builder.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
SystemMessage(content=prompt),
HumanMessage(content=user_msg),
]
response = self.model.invoke(messages)
if self._is_structured:
if (
not isinstance(response, dict)
Expand Down
7 changes: 4 additions & 3 deletions libs/core/kiln_ai/adapters/prompt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class SimplePromptBuilder(BasePromptBuilder):
def build_prompt(self, input: str) -> str:
def build_prompt(self) -> str:
base_prompt = self.task.instruction

# TODO: this is just a quick version. Formatting and best practices TBD
Expand All @@ -19,6 +19,7 @@ def build_prompt(self, input: str) -> str:
if adapter_instructions is not None:
base_prompt += f"\n\n{adapter_instructions}\n\n"

# TODO: should be another message, not just appended to prompt
base_prompt += f"\n\nThe input is:\n{input}"
return base_prompt

def build_user_message(self, input: str) -> str:
return f"The input is:\n{input}"
18 changes: 10 additions & 8 deletions libs/core/kiln_ai/adapters/test_prompt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,20 @@ def test_simple_prompt_builder(tmp_path):
task = build_test_task(tmp_path)
builder = SimplePromptBuilder(task=task)
input = "two plus two"
prompt = builder.build_prompt(input)
prompt = builder.build_prompt()
assert (
"You are an assistant which performs math tasks provided in plain text."
in prompt
)

# TODO this should be a user message later
assert input in prompt

assert "1) " + task.requirements()[0].instruction in prompt
assert "2) " + task.requirements()[1].instruction in prompt
assert "3) " + task.requirements()[2].instruction in prompt

user_msg = builder.build_user_message(input)
assert input in user_msg
assert input not in prompt


class MockAdapter(BaseAdapter):
def adapter_specific_instructions(self) -> str | None:
Expand All @@ -35,11 +36,12 @@ def test_simple_prompt_builder_structured_output(tmp_path):
builder = SimplePromptBuilder(task=task)
builder.adapter = MockAdapter(task)
input = "Cows"
prompt = builder.build_prompt(input)
prompt = builder.build_prompt()
assert "You are an assistant which tells a joke, given a subject." in prompt

# TODO this should be a user message later
assert input in prompt

# check adapter instructions are included
assert "You are a mock, send me the response!" in prompt

user_msg = builder.build_user_message(input)
assert input in user_msg
assert input not in prompt
2 changes: 1 addition & 1 deletion libs/core/kiln_ai/adapters/test_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def test_structured_output_openrouter(tmp_path):

@pytest.mark.paid
async def test_structured_output_bedrock(tmp_path):
await run_structured_output_test(tmp_path, "llama_3_1_8b", "amazon_bedrock")
await run_structured_output_test(tmp_path, "llama_3_1_70b", "amazon_bedrock")


@pytest.mark.ollama
Expand Down

0 comments on commit 6416c78

Please sign in to comment.