diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py
index e336a1a..d07b851 100644
--- a/libs/core/kiln_ai/adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -112,7 +112,8 @@ async def _run(self, input: Dict | str) -> RunOutput:
                 SystemMessage(content=cot_prompt),
             )
 
-            cot_response = base_model.invoke(messages)
+            cot_messages = [*messages]
+            cot_response = base_model.invoke(cot_messages)
             intermediate_outputs["chain_of_thought"] = cot_response.content
             messages.append(AIMessage(content=cot_response.content))
             messages.append(tool_call_message)
diff --git a/libs/core/kiln_ai/adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/test_langchain_adapter.py
index da8ca78..f210a42 100644
--- a/libs/core/kiln_ai/adapters/test_langchain_adapter.py
+++ b/libs/core/kiln_ai/adapters/test_langchain_adapter.py
@@ -1,6 +1,10 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq
 
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
 
 
@@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
     assert model_info.adapter_name == "kiln_langchain_adapter"
     assert model_info.model_name == "llama_3_1_8b"
     assert model_info.model_provider == "ollama"
+
+
+async def test_langchain_adapter_with_cot(tmp_path):
+    task = build_test_task(tmp_path)
+    lca = LangChainPromptAdapter(
+        kiln_task=task,
+        model_name="llama_3_1_8b",
+        provider="ollama",
+        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
+    )
+
+    # Mock the base model and its invoke method
+    mock_base_model = MagicMock()
+    mock_base_model.invoke.return_value = AIMessage(
+        content="Chain of thought reasoning..."
+    )
+
+    # Create a separate mock for self.model()
+    mock_model_instance = MagicMock()
+    mock_model_instance.invoke.return_value = AIMessage(content="Final response...")
+
+    # Mock the langchain_model_from function to return the base model
+    mock_model_from = AsyncMock(return_value=mock_base_model)
+
+    # Patch both the langchain_model_from function and self.model()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
+        ),
+        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+    ):
+        response = await lca._run("test input")
+
+    # Verify the model was created with correct parameters
+    # mock_model_from.assert_awaited_once_with("llama_3_1_8b", "ollama")
+
+    # First 3 messages are the same for both calls
+    for invoke_args in [
+        mock_base_model.invoke.call_args[0][0],
+        mock_model_instance.invoke.call_args[0][0],
+    ]:
+        assert isinstance(
+            invoke_args[0], SystemMessage
+        )  # First message should be system prompt
+        assert (
+            "You are an assistant which performs math tasks provided in plain text."
+            in invoke_args[0].content
+        )
+        assert isinstance(invoke_args[1], HumanMessage)
+        assert "test input" in invoke_args[1].content
+        assert isinstance(invoke_args[2], SystemMessage)
+        assert "step by step" in invoke_args[2].content
+
+    # the COT should only have 3 messages
+    assert len(mock_base_model.invoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+
+    # the final response should have the COT content and the final instructions
+    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    assert isinstance(invoke_args[3], AIMessage)
+    assert "Chain of thought reasoning..." in invoke_args[3].content
+    assert isinstance(invoke_args[4], SystemMessage)
+    assert "Always respond with a tool call" in invoke_args[4].content
+
+    assert (
+        response.intermediate_outputs["chain_of_thought"]
+        == "Chain of thought reasoning..."
+    )
+    assert response.output == "Final response..."
diff --git a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
index bf8601c..a83c1fe 100644
--- a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
+++ b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
@@ -42,7 +42,10 @@ def test_save_run_isolation(test_task):
     adapter = MockAdapter(test_task)
     input_data = "Test input"
     output_data = "Test output"
-    run_output = RunOutput(output=output_data, intermediate_outputs=None)
+    run_output = RunOutput(
+        output=output_data,
+        intermediate_outputs={"chain_of_thought": "Test chain of thought"},
+    )
 
     task_run = adapter.generate_run(
         input=input_data, input_source=None, run_output=run_output
@@ -53,6 +56,9 @@ def test_save_run_isolation(test_task):
     assert task_run.parent == test_task
     assert task_run.input == input_data
     assert task_run.input_source.type == DataSourceType.human
+    assert task_run.intermediate_outputs == {
+        "chain_of_thought": "Test chain of thought"
+    }
     created_by = Config.shared().user_id
     if created_by and created_by != "":
         assert task_run.input_source.properties["created_by"] == created_by