Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chain of thought #25

Merged
merged 10 commits into from
Nov 7, 2024
2 changes: 1 addition & 1 deletion app/desktop/studio_server/prompt_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def generate_prompt(
try:
prompt_builder_class = prompt_builder_from_ui_name(prompt_generator)
prompt_builder = prompt_builder_class(task)
prompt = prompt_builder.build_prompt()
prompt = prompt_builder.build_prompt_for_ui()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

Expand Down
41 changes: 36 additions & 5 deletions app/web_ui/src/lib/api_schema.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ export interface components {
* Where models have instruct and raw versions, instruct is default and raw is specified.
* @enum {string}
*/
ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_haiku" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
/** OllamaConnection */
OllamaConnection: {
/** Message */
Expand Down Expand Up @@ -480,7 +480,10 @@ export interface components {
created_at?: string;
/** Created By */
created_by?: string;
/** Name */
/**
* Name
* @description A name for this entity.
*/
name: string;
/**
* Description
Expand Down Expand Up @@ -512,7 +515,10 @@ export interface components {
created_at?: string;
/** Created By */
created_by?: string;
/** Name */
/**
* Name
* @description A name for this entity.
*/
name: string;
/**
* Description
Expand Down Expand Up @@ -606,7 +612,10 @@ export interface components {
created_at?: string;
/** Created By */
created_by?: string;
/** Name */
/**
* Name
* @description A name for this entity.
*/
name: string;
/**
* Description
Expand All @@ -628,6 +637,11 @@ export interface components {
output_json_schema?: string | null;
/** Input Json Schema */
input_json_schema?: string | null;
/**
* Thinking Instruction
* @description Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.
*/
thinking_instruction?: string | null;
/** Model Type */
readonly model_type: string;
};
Expand Down Expand Up @@ -807,7 +821,10 @@ export interface components {
TaskRequirement: {
/** Id */
id?: string | null;
/** Name */
/**
* Name
* @description A name for this entity.
*/
name: string;
/** Description */
description?: string | null;
Expand Down Expand Up @@ -856,6 +873,13 @@ export interface components {
repair_instructions?: string | null;
/** @description A version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
repaired_output?: components["schemas"]["TaskOutput-Input"] | null;
/**
* Intermediate Outputs
* @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
*/
intermediate_outputs?: {
[key: string]: string;
} | null;
};
/**
* TaskRun
Expand Down Expand Up @@ -897,6 +921,13 @@ export interface components {
repair_instructions?: string | null;
/** @description A version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
repaired_output?: components["schemas"]["TaskOutput-Output"] | null;
/**
* Intermediate Outputs
* @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
*/
intermediate_outputs?: {
[key: string]: string;
} | null;
/** Model Type */
readonly model_type: string;
};
Expand Down
2 changes: 1 addition & 1 deletion app/web_ui/src/lib/utils/form_element.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
for={id}
class="text-sm font-medium text-left flex flex-col gap-1 pb-[4px]"
>
<div class="flex flex-row">
<div class="flex flex-row items-center">
<span class="grow {light_label ? 'text-xs text-gray-500' : ''}"
>{label}</span
>
Expand Down
3 changes: 3 additions & 0 deletions app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@
["few_shot", "Few Shot"],
["many_shot", "Many Shot"],
["repairs", "Repair Multi Shot"],
["simple_chain_of_thought", "Basic Chain of Thought"],
["few_shot_chain_of_thought", "Chain of Thought - Few Shot"],
["multi_shot_chain_of_thought", "Chain of Thought - Many Shot"],
]}
/>
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
description: task.description,
instruction: task.instruction,
requirements: task.requirements,
thinking_instruction: task.thinking_instruction,
}
// Can only set schemas when creating a new task
if (creating) {
Expand Down Expand Up @@ -145,6 +146,7 @@
!!task.name ||
!!task.description ||
!!task.instruction ||
!!task.thinking_instruction ||
has_edited_requirements ||
!!inputSchemaSection.get_schema_string() ||
!!outputSchemaSection.get_schema_string()
Expand Down Expand Up @@ -268,6 +270,16 @@
bind:value={task.instruction}
/>

<FormElement
label="'Thinking' Instructions"
inputType="textarea"
id="thinking_instructions"
optional={true}
description="Instructions for how the model should 'think' about the task prior to answering. Used for chain of thought style prompting."
info_description="Used when running a 'Chain of Thought' prompt. If left blank, a default 'think step by step' prompt will be used. Optionally customize this with your own instructions to better fit this task."
bind:value={task.thinking_instruction}
/>

<div class="text-sm font-medium text-left pt-6 flex flex-col gap-1">
<div class="text-xl font-bold" id="requirements_part">
Part 2: Requirements
Expand Down
59 changes: 24 additions & 35 deletions libs/core/kiln_ai/adapters/base_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ class AdapterInfo:
prompt_builder_name: str


@dataclass
class RunOutput:
    """Result of a single adapter run: the final model output plus any
    intermediate per-step outputs produced along the way."""

    # Final output: a dict for structured (JSON-schema) tasks, a str otherwise
    # (invoke() validates this against the task's output schema).
    output: Dict | str
    # Keyed by step name (e.g. "chain_of_thought"); None when the run
    # produced no intermediate steps.
    intermediate_outputs: Dict[str, str] | None


class BaseAdapter(metaclass=ABCMeta):
"""Base class for AI model adapters that handle task execution.

Expand All @@ -36,22 +42,6 @@ class BaseAdapter(metaclass=ABCMeta):
kiln_task (Task): The task configuration and metadata
output_schema (dict | None): JSON schema for validating structured outputs
input_schema (dict | None): JSON schema for validating structured inputs

Example:
```python
class CustomAdapter(BaseAdapter):
async def _run(self, input: Dict | str) -> Dict | str:
# Implementation for specific model
pass

def adapter_info(self) -> AdapterInfo:
return AdapterInfo(
adapter_name="custom",
model_name="model-1",
model_provider="provider",
prompt_builder_name="simple"
)
```
"""

def __init__(
Expand Down Expand Up @@ -85,21 +75,23 @@ async def invoke(
validate_schema(input, self.input_schema)

# Run
result = await self._run(input)
run_output = await self._run(input)

# validate output
if self.output_schema is not None:
if not isinstance(result, dict):
raise RuntimeError(f"structured response is not a dict: {result}")
validate_schema(result, self.output_schema)
if not isinstance(run_output.output, dict):
raise RuntimeError(
f"structured response is not a dict: {run_output.output}"
)
validate_schema(run_output.output, self.output_schema)
else:
if not isinstance(result, str):
if not isinstance(run_output.output, str):
raise RuntimeError(
f"response is not a string for non-structured task: {result}"
f"response is not a string for non-structured task: {run_output.output}"
)

# Generate the run and output
run = self.generate_run(input, input_source, result)
run = self.generate_run(input, input_source, run_output)

# Save the run if configured to do so, and we have a path to save to
if Config.shared().autosave_runs and self.kiln_task.path is not None:
Expand All @@ -118,27 +110,23 @@ def adapter_info(self) -> AdapterInfo:
pass

@abstractmethod
async def _run(self, input: Dict | str) -> Dict | str:
async def _run(self, input: Dict | str) -> RunOutput:
pass

def build_prompt(self) -> str:
prompt = self.prompt_builder.build_prompt()
adapter_instructions = self.adapter_specific_instructions()
if adapter_instructions is not None:
prompt += f"# Format Instructions\n\n{adapter_instructions}\n\n"
return prompt

# override for adapter specific instructions (e.g. tool calling, json format, etc)
def adapter_specific_instructions(self) -> str | None:
return None
return self.prompt_builder.build_prompt()

# create a run and task output
def generate_run(
self, input: Dict | str, input_source: DataSource | None, output: Dict | str
self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
) -> TaskRun:
# Convert input and output to JSON strings if they are dictionaries
input_str = json.dumps(input) if isinstance(input, dict) else input
output_str = json.dumps(output) if isinstance(output, dict) else output
output_str = (
json.dumps(run_output.output)
if isinstance(run_output.output, dict)
else run_output.output
)

# If no input source is provided, use the human data source
if input_source is None:
Expand All @@ -159,6 +147,7 @@ def generate_run(
properties=self._properties_for_task_output(),
),
),
intermediate_outputs=run_output.intermediate_outputs,
)

exclude_fields = {
Expand Down
42 changes: 35 additions & 7 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel

from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import langchain_model_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
Expand Down Expand Up @@ -84,15 +84,37 @@ async def model(self) -> LangChainModelType:
)
return self._model

async def _run(self, input: Dict | str) -> Dict | str:
async def _run(self, input: Dict | str) -> RunOutput:
model = await self.model()
chain = model
intermediate_outputs = {}

prompt = self.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
SystemMessage(content=prompt),
HumanMessage(content=user_msg),
]
model = await self.model()
response = model.invoke(messages)

cot_prompt = self.prompt_builder.chain_of_thought_prompt()
if cot_prompt:
# Base model (without structured output) used for COT
base_model = await langchain_model_from(
self.model_name, self.model_provider
)
messages.append(
SystemMessage(content=cot_prompt),
)

cot_messages = [*messages]
cot_response = base_model.invoke(cot_messages)
intermediate_outputs["chain_of_thought"] = cot_response.content
messages.append(AIMessage(content=cot_response.content))
messages.append(
SystemMessage(content="Considering the above, return a final result.")
)

response = chain.invoke(messages)

if self.has_structured_output():
if (
Expand All @@ -102,14 +124,20 @@ async def _run(self, input: Dict | str) -> Dict | str:
):
raise RuntimeError(f"structured response not returned: {response}")
structured_response = response["parsed"]
return self._munge_response(structured_response)
return RunOutput(
output=self._munge_response(structured_response),
intermediate_outputs=intermediate_outputs,
)
else:
if not isinstance(response, BaseMessage):
raise RuntimeError(f"response is not a BaseMessage: {response}")
text_content = response.content
if not isinstance(text_content, str):
raise RuntimeError(f"response is not a string: {text_content}")
return text_content
return RunOutput(
output=text_content,
intermediate_outputs=intermediate_outputs,
)

def adapter_info(self) -> AdapterInfo:
return AdapterInfo(
Expand Down
Loading
Loading