diff --git a/app/desktop/studio_server/prompt_api.py b/app/desktop/studio_server/prompt_api.py
index 91c35cc..b7d8ff8 100644
--- a/app/desktop/studio_server/prompt_api.py
+++ b/app/desktop/studio_server/prompt_api.py
@@ -20,7 +20,7 @@ async def generate_prompt(
     try:
         prompt_builder_class = prompt_builder_from_ui_name(prompt_generator)
         prompt_builder = prompt_builder_class(task)
-        prompt = prompt_builder.build_prompt()
+        prompt = prompt_builder.build_prompt_for_ui()
     except Exception as e:
         raise HTTPException(status_code=400, detail=str(e))
diff --git a/app/web_ui/src/lib/api_schema.d.ts b/app/web_ui/src/lib/api_schema.d.ts
index fd1534b..487ac61 100644
--- a/app/web_ui/src/lib/api_schema.d.ts
+++ b/app/web_ui/src/lib/api_schema.d.ts
@@ -442,7 +442,7 @@ export interface components {
      * Where models have instruct and raw versions, instruct is default and raw is specified.
      * @enum {string}
      */
-    ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
+    ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_haiku" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
     /** OllamaConnection */
     OllamaConnection: {
       /** Message */
@@ -480,7 +480,10 @@ export interface components {
       created_at?: string;
       /** Created By */
       created_by?: string;
-      /** Name */
+      /**
+       * Name
+       * @description A name for this entity.
+       */
       name: string;
       /**
        * Description
@@ -512,7 +515,10 @@ export interface components {
       created_at?: string;
       /** Created By */
       created_by?: string;
-      /** Name */
+      /**
+       * Name
+       * @description A name for this entity.
+       */
       name: string;
       /**
        * Description
@@ -606,7 +612,10 @@ export interface components {
       created_at?: string;
       /** Created By */
       created_by?: string;
-      /** Name */
+      /**
+       * Name
+       * @description A name for this entity.
+       */
       name: string;
       /**
        * Description
@@ -628,6 +637,11 @@ export interface components {
       output_json_schema?: string | null;
       /** Input Json Schema */
       input_json_schema?: string | null;
+      /**
+       * Thinking Instruction
+       * @description Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.
+       */
+      thinking_instruction?: string | null;
       /** Model Type */
       readonly model_type: string;
     };
@@ -807,7 +821,10 @@ export interface components {
     TaskRequirement: {
       /** Id */
       id?: string | null;
-      /** Name */
+      /**
+       * Name
+       * @description A name for this entity
+       */
       name: string;
       /** Description */
       description?: string | null;
@@ -856,6 +873,13 @@ export interface components {
       repair_instructions?: string | null;
       /** @description An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
       repaired_output?: components["schemas"]["TaskOutput-Input"] | null;
+      /**
+       * Intermediate Outputs
+       * @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
+       */
+      intermediate_outputs?: {
+        [key: string]: string;
+      } | null;
     };
     /**
      * TaskRun
@@ -897,6 +921,13 @@ export interface components {
       repair_instructions?: string | null;
       /** @description An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
       repaired_output?: components["schemas"]["TaskOutput-Output"] | null;
+      /**
+       * Intermediate Outputs
+       * @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
+       */
+      intermediate_outputs?: {
+        [key: string]: string;
+      } | null;
       /** Model Type */
       readonly model_type: string;
     };
diff --git a/app/web_ui/src/lib/utils/form_element.svelte b/app/web_ui/src/lib/utils/form_element.svelte
index e22f36a..5010700 100644
--- a/app/web_ui/src/lib/utils/form_element.svelte
+++ b/app/web_ui/src/lib/utils/form_element.svelte
@@ -70,7 +70,7 @@
     for={id}
     class="text-sm font-medium text-left flex flex-col gap-1 pb-[4px]"
   >
-
+
       {label}
diff --git a/app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte b/app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte
index dbf4afb..75e9cb8 100644
--- a/app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte
+++ b/app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte
@@ -14,5 +14,8 @@
     ["few_shot", "Few Shot"],
     ["many_shot", "Many Shot"],
     ["repairs", "Repair Multi Shot"],
+    ["simple_chain_of_thought", "Basic Chain of Thought"],
+    ["few_shot_chain_of_thought", "Chain of Thought - Few Shot"],
+    ["multi_shot_chain_of_thought", "Chain of Thought - Many Shot"],
   ]}
 />
diff --git a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte
index 8605779..4bbf71f 100644
--- a/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte
+++ b/app/web_ui/src/routes/(fullscreen)/setup/(setup)/create_task/edit_task.svelte
@@ -68,6 +68,7 @@
       description: task.description,
       instruction: task.instruction,
       requirements: task.requirements,
+      thinking_instruction: task.thinking_instruction,
     }
     // Can only set schemas when creating a new task
     if (creating) {
@@ -145,6 +146,7 @@
       !!task.name ||
       !!task.description ||
       !!task.instruction ||
+      !!task.thinking_instruction ||
       has_edited_requirements ||
       !!inputSchemaSection.get_schema_string() ||
       !!outputSchemaSection.get_schema_string()
@@ -268,6 +270,16 @@
       bind:value={task.instruction}
     />
+
+
       Part 2: Requirements
diff --git a/libs/core/kiln_ai/adapters/base_adapter.py b/libs/core/kiln_ai/adapters/base_adapter.py
index 05bd079..4038f56 100644
--- a/libs/core/kiln_ai/adapters/base_adapter.py
+++ b/libs/core/kiln_ai/adapters/base_adapter.py
@@ -24,6 +24,12 @@ class AdapterInfo:
     prompt_builder_name: str


+@dataclass
+class RunOutput:
+    output: Dict | str
+    intermediate_outputs: Dict[str, str] | None
+
+
 class BaseAdapter(metaclass=ABCMeta):
     """Base class for AI model adapters that handle task execution.

@@ -36,22 +42,6 @@ class BaseAdapter(metaclass=ABCMeta):
         kiln_task (Task): The task configuration and metadata
         output_schema (dict | None): JSON schema for validating structured outputs
         input_schema (dict | None): JSON schema for validating structured inputs
-
-    Example:
-    ```python
-    class CustomAdapter(BaseAdapter):
-        async def _run(self, input: Dict | str) -> Dict | str:
-            # Implementation for specific model
-            pass
-
-        def adapter_info(self) -> AdapterInfo:
-            return AdapterInfo(
-                adapter_name="custom",
-                model_name="model-1",
-                model_provider="provider",
-                prompt_builder_name="simple"
-            )
-    ```
     """

     def __init__(
@@ -85,21 +75,23 @@ async def invoke(
             validate_schema(input, self.input_schema)

         # Run
-        result = await self._run(input)
+        run_output = await self._run(input)

         # validate output
         if self.output_schema is not None:
-            if not isinstance(result, dict):
-                raise RuntimeError(f"structured response is not a dict: {result}")
-            validate_schema(result, self.output_schema)
+            if not isinstance(run_output.output, dict):
+                raise RuntimeError(
+                    f"structured response is not a dict: {run_output.output}"
+                )
+            validate_schema(run_output.output, self.output_schema)
         else:
-            if not isinstance(result, str):
+            if not isinstance(run_output.output, str):
                 raise RuntimeError(
-                    f"response is not a string for non-structured task: {result}"
+                    f"response is not a string for non-structured task: {run_output.output}"
                 )

         # Generate the run and output
-        run = self.generate_run(input, input_source, result)
+        run = self.generate_run(input, input_source, run_output)

         # Save the run if configured to do so, and we have a path to save to
         if Config.shared().autosave_runs and self.kiln_task.path is not None:
@@ -118,27 +110,23 @@ def adapter_info(self) -> AdapterInfo:
         pass

     @abstractmethod
-    async def _run(self, input: Dict | str) -> Dict | str:
+    async def _run(self, input: Dict | str) -> RunOutput:
         pass

     def build_prompt(self) -> str:
-        prompt = self.prompt_builder.build_prompt()
-        adapter_instructions = self.adapter_specific_instructions()
-        if adapter_instructions is not None:
-            prompt += f"# Format Instructions\n\n{adapter_instructions}\n\n"
-        return prompt
-
-    # override for adapter specific instructions (e.g. tool calling, json format, etc)
-    def adapter_specific_instructions(self) -> str | None:
-        return None
+        return self.prompt_builder.build_prompt()

     # create a run and task output
     def generate_run(
-        self, input: Dict | str, input_source: DataSource | None, output: Dict | str
+        self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
     ) -> TaskRun:
         # Convert input and output to JSON strings if they are dictionaries
         input_str = json.dumps(input) if isinstance(input, dict) else input
-        output_str = json.dumps(output) if isinstance(output, dict) else output
+        output_str = (
+            json.dumps(run_output.output)
+            if isinstance(run_output.output, dict)
+            else run_output.output
+        )

         # If no input source is provided, use the human data source
         if input_source is None:
@@ -159,6 +147,7 @@ def generate_run(
                     properties=self._properties_for_task_output(),
                 ),
             ),
+            intermediate_outputs=run_output.intermediate_outputs,
         )

         exclude_fields = {
diff --git a/libs/core/kiln_ai/adapters/langchain_adapters.py b/libs/core/kiln_ai/adapters/langchain_adapters.py
index eef41ae..455385e 100644
--- a/libs/core/kiln_ai/adapters/langchain_adapters.py
+++ b/libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -2,14 +2,14 @@
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_core.messages.base import BaseMessage
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel

 import kiln_ai.datamodel as datamodel

-from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
+from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
 from .ml_model_list import langchain_model_from

 LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -84,15 +84,37 @@ async def model(self) -> LangChainModelType:
             )
         return self._model

-    async def _run(self, input: Dict | str) -> Dict | str:
+    async def _run(self, input: Dict | str) -> RunOutput:
+        model = await self.model()
+        chain = model
+        intermediate_outputs = {}
+
         prompt = self.build_prompt()
         user_msg = self.prompt_builder.build_user_message(input)
         messages = [
             SystemMessage(content=prompt),
             HumanMessage(content=user_msg),
         ]
-        model = await self.model()
-        response = model.invoke(messages)
+
+        cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+        if cot_prompt:
+            # Base model (without structured output) used for COT
+            base_model = await langchain_model_from(
+                self.model_name, self.model_provider
+            )
+            messages.append(
+                SystemMessage(content=cot_prompt),
+            )
+
+            cot_messages = [*messages]
+            cot_response = base_model.invoke(cot_messages)
+            intermediate_outputs["chain_of_thought"] = cot_response.content
+            messages.append(AIMessage(content=cot_response.content))
+            messages.append(
+                SystemMessage(content="Considering the above, return a final result.")
+            )
+
+        response = chain.invoke(messages)

         if self.has_structured_output():
             if (
@@ -102,14 +124,20 @@ async def _run(self, input: Dict | str) -> Dict | str:
             ):
                 raise RuntimeError(f"structured response not returned: {response}")
             structured_response = response["parsed"]
-            return self._munge_response(structured_response)
+            return RunOutput(
+                output=self._munge_response(structured_response),
+                intermediate_outputs=intermediate_outputs,
+            )
         else:
             if not isinstance(response, BaseMessage):
                 raise RuntimeError(f"response is not a BaseMessage: {response}")
             text_content = response.content
             if not isinstance(text_content, str):
                 raise RuntimeError(f"response is not a string: {text_content}")
-            return text_content
+            return RunOutput(
+                output=text_content,
+                intermediate_outputs=intermediate_outputs,
+            )

     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
diff --git a/libs/core/kiln_ai/adapters/prompt_builders.py b/libs/core/kiln_ai/adapters/prompt_builders.py
index cd95267..5b527cf 100644
--- a/libs/core/kiln_ai/adapters/prompt_builders.py
+++ b/libs/core/kiln_ai/adapters/prompt_builders.py
@@ -54,6 +54,28 @@ def build_user_message(self, input: Dict | str) -> str:

         return f"The input is:\n{input}"

+    def chain_of_thought_prompt(self) -> str | None:
+        """Build and return the chain of thought prompt string.
+
+        Returns:
+            str: The constructed chain of thought prompt.
+        """
+        return None
+
+    def build_prompt_for_ui(self) -> str:
+        """Build a prompt for the UI. It includes additional instructions (like chain of thought), even if they are passed to the model in stages.
+
+        Designed for end-user consumption, not for model consumption.
+
+        Returns:
+            str: The constructed prompt string.
+        """
+        base_prompt = self.build_prompt()
+        cot_prompt = self.chain_of_thought_prompt()
+        if cot_prompt:
+            base_prompt += "\n# Thinking Instructions\n\n" + cot_prompt
+        return base_prompt
+

 class SimplePromptBuilder(BasePromptBuilder):
     """A basic prompt builder that combines task instruction with requirements."""
@@ -187,11 +209,48 @@ def prompt_section_for_example(self, index: int, example: TaskRun) -> str:
         return prompt_section


+def chain_of_thought_prompt(task: Task) -> str | None:
+    """Standard implementation to build and return the chain of thought prompt string.
+
+    Returns:
+        str: The constructed chain of thought prompt.
+    """
+
+    if task.thinking_instruction:
+        return task.thinking_instruction
+
+    return "Think step by step, explaining your reasoning, before responding with an answer."
+
+
+class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the simple prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class FewShotChainOfThoughtPromptBuilder(FewShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the few shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
+class MultiShotChainOfThoughtPromptBuilder(MultiShotPromptBuilder):
+    """A prompt builder that includes a chain of thought prompt on top of the multi shot prompt."""
+
+    def chain_of_thought_prompt(self) -> str | None:
+        return chain_of_thought_prompt(self.task)
+
+
 prompt_builder_registry = {
     "simple_prompt_builder": SimplePromptBuilder,
     "multi_shot_prompt_builder": MultiShotPromptBuilder,
     "few_shot_prompt_builder": FewShotPromptBuilder,
     "repairs_prompt_builder": RepairsPromptBuilder,
+    "simple_chain_of_thought_prompt_builder": SimpleChainOfThoughtPromptBuilder,
+    "few_shot_chain_of_thought_prompt_builder": FewShotChainOfThoughtPromptBuilder,
+    "multi_shot_chain_of_thought_prompt_builder": MultiShotChainOfThoughtPromptBuilder,
 }
@@ -217,5 +276,11 @@ def prompt_builder_from_ui_name(ui_name: str) -> type[BasePromptBuilder]:
             return MultiShotPromptBuilder
         case "repairs":
             return RepairsPromptBuilder
+        case "simple_chain_of_thought":
+            return SimpleChainOfThoughtPromptBuilder
+        case "few_shot_chain_of_thought":
+            return FewShotChainOfThoughtPromptBuilder
+        case "multi_shot_chain_of_thought":
+            return MultiShotChainOfThoughtPromptBuilder
         case _:
             raise ValueError(f"Unknown prompt builder: {ui_name}")
diff --git a/libs/core/kiln_ai/adapters/repair/test_repair_task.py b/libs/core/kiln_ai/adapters/repair/test_repair_task.py
index 837337d..33297a2 100644
--- a/libs/core/kiln_ai/adapters/repair/test_repair_task.py
+++ b/libs/core/kiln_ai/adapters/repair/test_repair_task.py
@@ -5,6 +5,7 @@
 import pytest
 from pydantic import ValidationError

+from kiln_ai.adapters.base_adapter import RunOutput
 from kiln_ai.adapters.langchain_adapters import (
     LangChainPromptAdapter,
 )
@@ -222,7 +223,9 @@ async def test_mocked_repair_task_run(sample_task, sample_task_run, sample_repai
     with patch.object(
         LangChainPromptAdapter, "_run", new_callable=AsyncMock
     ) as mock_run:
-        mock_run.return_value = mocked_output
+        mock_run.return_value = RunOutput(
+            output=mocked_output, intermediate_outputs=None
+        )

         adapter = LangChainPromptAdapter(
             repair_task, model_name="llama_3_1_8b", provider="groq"
diff --git a/libs/core/kiln_ai/adapters/test_langchain_adapter.py b/libs/core/kiln_ai/adapters/test_langchain_adapter.py
index da8ca78..2a2c57d 100644
--- a/libs/core/kiln_ai/adapters/test_langchain_adapter.py
+++ b/libs/core/kiln_ai/adapters/test_langchain_adapter.py
@@ -1,6 +1,10 @@
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 from langchain_groq import ChatGroq

 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
+from kiln_ai.adapters.prompt_builders import SimpleChainOfThoughtPromptBuilder
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
@@ -49,3 +53,72 @@ def test_langchain_adapter_info(tmp_path):
     assert model_info.adapter_name == "kiln_langchain_adapter"
     assert model_info.model_name == "llama_3_1_8b"
     assert model_info.model_provider == "ollama"
+
+
+async def test_langchain_adapter_with_cot(tmp_path):
+    task = build_test_task(tmp_path)
+    lca = LangChainPromptAdapter(
+        kiln_task=task,
+        model_name="llama_3_1_8b",
+        provider="ollama",
+        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),
+    )
+
+    # Mock the base model and its invoke method
+    mock_base_model = MagicMock()
+    mock_base_model.invoke.return_value = AIMessage(
+        content="Chain of thought reasoning..."
+    )
+
+    # Create a separate mock for self.model()
+    mock_model_instance = MagicMock()
+    mock_model_instance.invoke.return_value = AIMessage(content="Final response...")
+
+    # Mock the langchain_model_from function to return the base model
+    mock_model_from = AsyncMock(return_value=mock_base_model)
+
+    # Patch both the langchain_model_from function and self.model()
+    with (
+        patch(
+            "kiln_ai.adapters.langchain_adapters.langchain_model_from", mock_model_from
+        ),
+        patch.object(LangChainPromptAdapter, "model", return_value=mock_model_instance),
+    ):
+        response = await lca._run("test input")
+
+    # Verify the model was created with correct parameters
+    # mock_model_from.assert_awaited_once_with("llama_3_1_8b", "ollama")
+
+    # First 3 messages are the same for both calls
+    for invoke_args in [
+        mock_base_model.invoke.call_args[0][0],
+        mock_model_instance.invoke.call_args[0][0],
+    ]:
+        assert isinstance(
+            invoke_args[0], SystemMessage
+        )  # First message should be system prompt
+        assert (
+            "You are an assistant which performs math tasks provided in plain text."
+            in invoke_args[0].content
+        )
+        assert isinstance(invoke_args[1], HumanMessage)
+        assert "test input" in invoke_args[1].content
+        assert isinstance(invoke_args[2], SystemMessage)
+        assert "step by step" in invoke_args[2].content
+
+    # the COT should only have 3 messages
+    assert len(mock_base_model.invoke.call_args[0][0]) == 3
+    assert len(mock_model_instance.invoke.call_args[0][0]) == 5
+
+    # the final response should have the COT content and the final instructions
+    invoke_args = mock_model_instance.invoke.call_args[0][0]
+    assert isinstance(invoke_args[3], AIMessage)
+    assert "Chain of thought reasoning..." in invoke_args[3].content
+    assert isinstance(invoke_args[4], SystemMessage)
+    assert "Considering the above, return a final result." in invoke_args[4].content
+
+    assert (
+        response.intermediate_outputs["chain_of_thought"]
+        == "Chain of thought reasoning..."
+    )
+    assert response.output == "Final response..."
diff --git a/libs/core/kiln_ai/adapters/test_prompt_builders.py b/libs/core/kiln_ai/adapters/test_prompt_builders.py
index 0e96273..92a52a0 100644
--- a/libs/core/kiln_ai/adapters/test_prompt_builders.py
+++ b/libs/core/kiln_ai/adapters/test_prompt_builders.py
@@ -4,10 +4,14 @@
 from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
 from kiln_ai.adapters.prompt_builders import (
+    FewShotChainOfThoughtPromptBuilder,
     FewShotPromptBuilder,
+    MultiShotChainOfThoughtPromptBuilder,
     MultiShotPromptBuilder,
     RepairsPromptBuilder,
+    SimpleChainOfThoughtPromptBuilder,
     SimplePromptBuilder,
+    chain_of_thought_prompt,
     prompt_builder_from_ui_name,
 )
 from kiln_ai.adapters.test_prompt_adaptors import build_test_task
@@ -43,9 +47,6 @@ def test_simple_prompt_builder(tmp_path):


 class MockAdapter(BaseAdapter):
-    def adapter_specific_instructions(self) -> str | None:
-        return "You are a mock, send me the response!"
-
     def _run(self, input: str) -> str:
         return "mock response"
@@ -64,10 +65,6 @@ def test_simple_prompt_builder_structured_output(tmp_path):
     prompt = builder.build_prompt()
     assert "You are an assistant which tells a joke, given a subject." in prompt

-    # check adapter instructions are included
-    run_adapter = MockAdapter(task, prompt_builder=builder)
-    assert "You are a mock, send me the response!" in run_adapter.build_prompt()
-
     user_msg = builder.build_user_message(input)
     assert input in user_msg
     assert input not in prompt
@@ -313,6 +310,18 @@ def test_prompt_builder_from_ui_name():
     assert prompt_builder_from_ui_name("few_shot") == FewShotPromptBuilder
     assert prompt_builder_from_ui_name("many_shot") == MultiShotPromptBuilder
     assert prompt_builder_from_ui_name("repairs") == RepairsPromptBuilder
+    assert (
+        prompt_builder_from_ui_name("simple_chain_of_thought")
+        == SimpleChainOfThoughtPromptBuilder
+    )
+    assert (
+        prompt_builder_from_ui_name("few_shot_chain_of_thought")
+        == FewShotChainOfThoughtPromptBuilder
+    )
+    assert (
+        prompt_builder_from_ui_name("multi_shot_chain_of_thought")
+        == MultiShotChainOfThoughtPromptBuilder
+    )

     with pytest.raises(ValueError, match="Unknown prompt builder: invalid_name"):
         prompt_builder_from_ui_name("invalid_name")
@@ -336,3 +345,84 @@ def test_repair_multi_shot_prompt_builder(task_with_examples):
         'Initial Output Which Was Insufficient: {"joke": "Moo I am a cow joke."}'
         in prompt
     )
+
+
+def test_chain_of_thought_prompt(tmp_path):
+    # Test with default thinking instruction
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=None,
+        thinking_instruction=None,
+    )
+    assert (
+        chain_of_thought_prompt(task)
+        == "Think step by step, explaining your reasoning, before responding with an answer."
+    )
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then break it down into steps."
+    task = Task(
+        name="Test Task",
+        instruction="Test instruction",
+        parent=None,
+        thinking_instruction=custom_instruction,
+    )
+    assert chain_of_thought_prompt(task) == custom_instruction
+
+
+@pytest.mark.parametrize(
+    "builder_class",
+    [
+        SimpleChainOfThoughtPromptBuilder,
+        FewShotChainOfThoughtPromptBuilder,
+        MultiShotChainOfThoughtPromptBuilder,
+    ],
+)
+def test_chain_of_thought_prompt_builders(builder_class, task_with_examples):
+    # Test with default thinking instruction
+    builder = builder_class(task=task_with_examples)
+    assert (
+        builder.chain_of_thought_prompt()
+        == "Think step by step, explaining your reasoning, before responding with an answer."
+    )
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then break it down into steps."
+    task_with_custom = task_with_examples.model_copy(
+        update={"thinking_instruction": custom_instruction}
+    )
+    builder = builder_class(task=task_with_custom)
+    assert builder.chain_of_thought_prompt() == custom_instruction
+
+
+def test_build_prompt_for_ui(tmp_path):
+    # Test regular prompt builder
+    task = build_test_task(tmp_path)
+    simple_builder = SimplePromptBuilder(task=task)
+    ui_prompt = simple_builder.build_prompt_for_ui()
+
+    # Should match regular prompt since no chain of thought
+    assert ui_prompt == simple_builder.build_prompt()
+    assert "# Thinking Instructions" not in ui_prompt
+
+    # Test chain of thought prompt builder
+    cot_builder = SimpleChainOfThoughtPromptBuilder(task=task)
+    ui_prompt_cot = cot_builder.build_prompt_for_ui()
+
+    # Should include both base prompt and thinking instructions
+    assert cot_builder.build_prompt() in ui_prompt_cot
+    assert "# Thinking Instructions" in ui_prompt_cot
+    assert "Think step by step" in ui_prompt_cot
+
+    # Test with custom thinking instruction
+    custom_instruction = "First analyze the problem, then solve it."
+    task_with_custom = task.model_copy(
+        update={"thinking_instruction": custom_instruction}
+    )
+    custom_cot_builder = SimpleChainOfThoughtPromptBuilder(task=task_with_custom)
+    ui_prompt_custom = custom_cot_builder.build_prompt_for_ui()
+
+    assert custom_cot_builder.build_prompt() in ui_prompt_custom
+    assert "# Thinking Instructions" in ui_prompt_custom
+    assert custom_instruction in ui_prompt_custom
diff --git a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
index 6aebec3..a83c1fe 100644
--- a/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
+++ b/libs/core/kiln_ai/adapters/test_saving_adapter_results.py
@@ -2,7 +2,7 @@
 import pytest

-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
 from kiln_ai.datamodel import (
     DataSource,
     DataSourceType,
@@ -14,7 +14,7 @@ class MockAdapter(BaseAdapter):
     async def _run(self, input: dict | str) -> dict | str:
-        return "Test output"
+        return RunOutput(output="Test output", intermediate_outputs=None)

     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
@@ -42,9 +42,13 @@ def test_save_run_isolation(test_task):
     adapter = MockAdapter(test_task)
     input_data = "Test input"
     output_data = "Test output"
+    run_output = RunOutput(
+        output=output_data,
+        intermediate_outputs={"chain_of_thought": "Test chain of thought"},
+    )

     task_run = adapter.generate_run(
-        input=input_data, input_source=None, output=output_data
+        input=input_data, input_source=None, run_output=run_output
     )
     task_run.save_to_file()
@@ -52,6 +56,9 @@ def test_save_run_isolation(test_task):
     assert task_run.parent == test_task
     assert task_run.input == input_data
     assert task_run.input_source.type == DataSourceType.human
+    assert task_run.intermediate_outputs == {
+        "chain_of_thought": "Test chain of thought"
+    }
     created_by = Config.shared().user_id
     if created_by and created_by != "":
         assert task_run.input_source.properties["created_by"] == created_by
@@ -86,13 +93,16 @@ def test_save_run_isolation(test_task):
     )

     # Run again, with same input and different output. Should create a new TaskRun.
-    task_output = adapter.generate_run(input_data, None, "Different output")
+    different_run_output = RunOutput(
+        output="Different output", intermediate_outputs=None
+    )
+    task_output = adapter.generate_run(input_data, None, different_run_output)
     task_output.save_to_file()
     assert len(test_task.runs()) == 2
     assert "Different output" in set(run.output.output for run in test_task.runs())

     # run again with same input and same output. Should not create a new TaskRun.
-    task_output = adapter.generate_run(input_data, None, output_data)
+    task_output = adapter.generate_run(input_data, None, run_output)
     task_output.save_to_file()
     assert len(test_task.runs()) == 2
     assert "Different output" in set(run.output.output for run in test_task.runs())
@@ -110,7 +120,7 @@ def test_save_run_isolation(test_task):
                 "adapter_name": "mock_adapter",
             },
         ),
-        output_data,
+        run_output,
     )
     task_output.save_to_file()
     assert len(test_task.runs()) == 3
diff --git a/libs/core/kiln_ai/adapters/test_structured_output.py b/libs/core/kiln_ai/adapters/test_structured_output.py
index 7df1808..37c53fe 100644
--- a/libs/core/kiln_ai/adapters/test_structured_output.py
+++ b/libs/core/kiln_ai/adapters/test_structured_output.py
@@ -6,7 +6,7 @@
 import pytest

 import kiln_ai.datamodel as datamodel
-from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter
+from kiln_ai.adapters.base_adapter import AdapterInfo, BaseAdapter, RunOutput
 from kiln_ai.adapters.langchain_adapters import LangChainPromptAdapter
 from kiln_ai.adapters.ml_model_list import (
     built_in_models,
@@ -59,8 +59,8 @@ def __init__(self, kiln_task: datamodel.Task, response: Dict | str | None):
         super().__init__(kiln_task)
         self.response = response

-    async def _run(self, input: str) -> Dict | str:
-        return self.response
+    async def _run(self, input: str) -> RunOutput:
+        return RunOutput(output=self.response, intermediate_outputs=None)

     def adapter_info(self) -> AdapterInfo:
         return AdapterInfo(
diff --git a/libs/core/kiln_ai/datamodel/__init__.py b/libs/core/kiln_ai/datamodel/__init__.py
index f7da859..ca2bd7a 100644
--- a/libs/core/kiln_ai/datamodel/__init__.py
+++ b/libs/core/kiln_ai/datamodel/__init__.py
@@ -48,8 +48,18 @@
 # Filename compatible names
 NAME_REGEX = r"^[A-Za-z0-9 _-]+$"
-NAME_FIELD = Field(min_length=1, max_length=120, pattern=NAME_REGEX)
-SHORT_NAME_FIELD = Field(min_length=1, max_length=20, pattern=NAME_REGEX)
+NAME_FIELD = Field(
+    min_length=1,
+    max_length=120,
+    pattern=NAME_REGEX,
+    description="A name for this entity.",
+)
+SHORT_NAME_FIELD = Field(
+    min_length=1,
+    max_length=20,
+    pattern=NAME_REGEX,
+    description="A name for this entity",
+)


 class Priority(IntEnum):
@@ -280,6 +290,10 @@ class TaskRun(KilnParentedModel):
         default=None,
         description="An version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field.",
     )
+    intermediate_outputs: Dict[str, str] | None = Field(
+        default=None,
+        description="Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.",
+    )

     def parent_task(self) -> Task | None:
         if not isinstance(self.parent, Task):
@@ -380,6 +394,10 @@ class Task(
     # TODO: make this required, or formalize the default message output schema
     output_json_schema: JsonObjectSchema | None = None
     input_json_schema: JsonObjectSchema | None = None
+    thinking_instruction: str | None = Field(
+        default=None,
+        description="Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.",
+    )

     def output_schema(self) -> Dict | None:
         if self.output_json_schema is None:
diff --git a/libs/core/kiln_ai/datamodel/test_models.py b/libs/core/kiln_ai/datamodel/test_models.py
index f706ee3..61eba1d 100644
--- a/libs/core/kiln_ai/datamodel/test_models.py
+++ b/libs/core/kiln_ai/datamodel/test_models.py
@@ -3,7 +3,16 @@
 import pytest
 from pydantic import ValidationError

-from kiln_ai.datamodel import Priority, Project, Task, TaskDeterminism
+from kiln_ai.datamodel import (
+    DataSource,
+    DataSourceType,
+    Priority,
+    Project,
+    Task,
+    TaskDeterminism,
+    TaskOutput,
+    TaskRun,
+)
 from kiln_ai.datamodel.test_json_schema import json_joke_schema
@@ -76,6 +85,7 @@ def test_task_serialization(test_project_file):
         determinism=TaskDeterminism.semantic_match,
         priority=Priority.p0,
         instruction="Test Base Task Instruction",
+        thinking_instruction="Test Thinking Instruction",
     )
     task.save_to_file()
@@ -84,6 +94,7 @@
     assert parsed_task.name == "Test Task"
     assert parsed_task.description == "Test Description"
     assert parsed_task.instruction == "Test Base Task Instruction"
+    assert parsed_task.thinking_instruction == "Test Thinking Instruction"
     assert parsed_task.determinism == TaskDeterminism.semantic_match
     assert parsed_task.priority == Priority.p0
@@ -189,3 +200,36 @@ def test_task_output_schema(tmp_path):
         task = Task(name="Test Task", output_json_schema="{'asdf':{}}", path=path)
     with pytest.raises(ValidationError):
         task = Task(name="Test Task", input_json_schema="{asdf", path=path)
+
+
+def test_task_run_intermediate_outputs():
+    # Create a basic task output
+    output = TaskOutput(
+        output="test output",
+        source=DataSource(
+            type=DataSourceType.synthetic,
+            properties={
+                "model_name": "test-model",
+                "model_provider": "test-provider",
+                "adapter_name": "test-adapter",
+            },
+        ),
+    )
+
+    # Test valid intermediate outputs
+    task_run = TaskRun(
+        input="test input",
+        input_source=DataSource(
+            type=DataSourceType.human,
+            properties={"created_by": "test-user"},
+        ),
+        output=output,
+        intermediate_outputs={
+            "cot": "chain of thought output",
+            "draft": "draft output",
+        },
+    )
+    assert task_run.intermediate_outputs == {
+        "cot": "chain of thought output",
+        "draft": "draft output",
+    }
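
Usage sketch (editor's illustration, not part of the patch above). It exercises only APIs introduced or covered in this diff: Task.thinking_instruction, SimpleChainOfThoughtPromptBuilder, build_prompt_for_ui(), and the new prompt_builder_from_ui_name() entries. The example task itself is hypothetical.

from kiln_ai.adapters.prompt_builders import (
    SimpleChainOfThoughtPromptBuilder,
    prompt_builder_from_ui_name,
)
from kiln_ai.datamodel import Task

# Hypothetical task; any Task works, and thinking_instruction is optional.
task = Task(
    name="Math Helper",
    instruction="Solve the word problem provided in plain text.",
    parent=None,
    thinking_instruction="List the known quantities before computing the answer.",
)

builder = SimpleChainOfThoughtPromptBuilder(task)

# A custom thinking_instruction is returned verbatim; with thinking_instruction=None
# the default "Think step by step..." string is used instead.
assert builder.chain_of_thought_prompt() == task.thinking_instruction

# The UI prompt appends the thinking instructions under "# Thinking Instructions",
# even though the adapter sends them to the model as a separate message stage.
print(builder.build_prompt_for_ui())

# The new UI names map to the chain-of-thought builders.
assert (
    prompt_builder_from_ui_name("simple_chain_of_thought")
    == SimpleChainOfThoughtPromptBuilder
)

Runs produced through LangChainPromptAdapter with one of these builders also carry the first-stage reasoning on the saved TaskRun under intermediate_outputs["chain_of_thought"], as exercised in test_langchain_adapter.py above.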