Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
scosman committed Nov 7, 2024
1 parent 4cf094a commit 83fe045
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
3 changes: 2 additions & 1 deletion libs/core/kiln_ai/adapters/langchain_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ async def _run(self, input: Dict | str) -> RunOutput:
SystemMessage(content=cot_prompt),
)

cot_response = base_model.invoke(messages)
cot_messages = [*messages]
cot_response = base_model.invoke(cot_messages)
intermediate_outputs["chain_of_thought"] = cot_response.content
messages.append(AIMessage(content=cot_response.content))
messages.append(tool_call_message)
Expand Down
73 changes: 73 additions & 0 deletions libs/core/kiln_ai/adapters/test_langchain_adapter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from unittest.mock import AsyncMock, MagicMock, patch

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_groq import ChatGroq

from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
from kiln_ai.adapters.test_prompt_adaptors import build_test_task


Expand Down Expand Up @@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
assert model_info.adapter_name == "kiln_langchain_adapter"
assert model_info.model_name == "llama_3_1_8b"
assert model_info.model_provider == "ollama"


async def test_langchain_adapter_with_cot(tmp_path):
    """Verify the two-phase chain-of-thought flow in LangChainPromptAdapter._run.

    The adapter should first invoke the base model with exactly the three
    setup messages (system prompt, user input, COT instruction), capture the
    reasoning as an intermediate output, then invoke the structured-output
    model with those three messages plus the AI reasoning message and the
    final tool-call instruction.
    """
    task = build_test_task(tmp_path)
    lca = LangChainPromptAdapter(
        kiln_task=task,
        model_name="llama_3_1_8b",
        provider="ollama",
        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
    )

    # Base model: handles the chain-of-thought (reasoning) phase.
    mock_base_model = MagicMock()
    mock_base_model.invoke.return_value = AIMessage(
        content="Chain of thought reasoning..."
    )

    # Structured model (self.model()): handles the final-response phase.
    mock_model_instance = MagicMock()
    mock_model_instance.invoke.return_value = AIMessage(content="Final response...")

    # langchain_model_from is awaited by the adapter, so it needs AsyncMock.
    mock_model_from = AsyncMock(return_value=mock_base_model)

    # Patch both model factories so no real model/provider is contacted.
    with (
        patch(
            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
        ),
        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
    ):
        response = await lca._run("test input")

    # Both phases must share the same first three setup messages.
    for invoke_args in [
        mock_base_model.invoke.call_args[0][0],
        mock_model_instance.invoke.call_args[0][0],
    ]:
        assert isinstance(
            invoke_args[0], SystemMessage
        )  # First message should be system prompt
        assert (
            "You are an assistant which performs math tasks provided in plain text."
            in invoke_args[0].content
        )
        assert isinstance(invoke_args[1], HumanMessage)
        assert "test input" in invoke_args[1].content
        assert isinstance(invoke_args[2], SystemMessage)
        assert "step by step" in invoke_args[2].content

    # COT phase gets only the 3 setup messages; the mutation of the shared
    # message list after the COT call must not leak into the first invoke.
    assert len(mock_base_model.invoke.call_args[0][0]) == 3
    # Final phase gets the 3 setup messages + AI reasoning + tool-call prompt.
    assert len(mock_model_instance.invoke.call_args[0][0]) == 5

    # The final call must append the COT content and the tool-call instruction.
    invoke_args = mock_model_instance.invoke.call_args[0][0]
    assert isinstance(invoke_args[3], AIMessage)
    assert "Chain of thought reasoning..." in invoke_args[3].content
    assert isinstance(invoke_args[4], SystemMessage)
    assert "Always respond with a tool call" in invoke_args[4].content

    # Reasoning is surfaced as an intermediate output; final text is the output.
    assert (
        response.intermediate_outputs["chain_of_thought"]
        == "Chain of thought reasoning..."
    )
    assert response.output == "Final response..."
8 changes: 7 additions & 1 deletion libs/core/kiln_ai/adapters/test_saving_adapter_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def test_save_run_isolation(test_task):
adapter = MockAdapter(test_task)
input_data = "Test input"
output_data = "Test output"
run_output = RunOutput(output=output_data, intermediate_outputs=None)
run_output = RunOutput(
output=output_data,
intermediate_outputs={"chain_of_thought": "Test chain of thought"},
)

task_run = adapter.generate_run(
input=input_data, input_source=None, run_output=run_output
Expand All @@ -53,6 +56,9 @@ def test_save_run_isolation(test_task):
assert task_run.parent == test_task
assert task_run.input == input_data
assert task_run.input_source.type == DataSourceType.human
assert task_run.intermediate_outputs == {
"chain_of_thought": "Test chain of thought"
}
created_by = Config.shared().user_id
if created_by and created_by != "":
assert task_run.input_source.properties["created_by"] == created_by
Expand Down

0 comments on commit 83fe045

Please sign in to comment.