Merge pull request #25 from Kiln-AI/cot
Chain of thought
scosman authored Nov 7, 2024
2 parents 72a0f1b + 5066807 commit ff2e91c
Showing 15 changed files with 435 additions and 69 deletions.
2 changes: 1 addition & 1 deletion app/desktop/studio_server/prompt_api.py
@@ -20,7 +20,7 @@ async def generate_prompt(
try:
prompt_builder_class = prompt_builder_from_ui_name(prompt_generator)
prompt_builder = prompt_builder_class(task)
- prompt = prompt_builder.build_prompt()
+ prompt = prompt_builder.build_prompt_for_ui()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

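The `build_prompt_for_ui` implementation is not among the hunks shown on this page. A minimal sketch of what such a method might do, assuming the UI preview appends the chain-of-thought instruction (when one exists) to the base prompt; only the method name and `chain_of_thought_prompt()` come from this diff, the body is an assumption:

```python
# Hypothetical sketch; only build_prompt() and chain_of_thought_prompt()
# are confirmed elsewhere in this diff. The body below is an assumption,
# not the actual Kiln implementation.
def build_prompt_for_ui(self) -> str:
    prompt = self.build_prompt()
    cot_instruction = self.chain_of_thought_prompt()
    if cot_instruction:
        prompt += f"\n\n# Thinking Instructions\n\n{cot_instruction}"
    return prompt
```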
41 changes: 36 additions & 5 deletions app/web_ui/src/lib/api_schema.d.ts
@@ -442,7 +442,7 @@ export interface components {
* Where models have instruct and raw versions, instruct is default and raw is specified.
* @enum {string}
*/
- ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
+ ModelName: "llama_3_1_8b" | "llama_3_1_70b" | "llama_3_1_405b" | "llama_3_2_3b" | "llama_3_2_11b" | "llama_3_2_90b" | "gpt_4o_mini" | "gpt_4o" | "phi_3_5" | "mistral_large" | "mistral_nemo" | "gemma_2_2b" | "gemma_2_9b" | "gemma_2_27b" | "claude_3_5_haiku" | "claude_3_5_sonnet" | "gemini_1_5_flash" | "gemini_1_5_flash_8b" | "gemini_1_5_pro" | "nemotron_70b";
/** OllamaConnection */
OllamaConnection: {
/** Message */
@@ -480,7 +480,10 @@
created_at?: string;
/** Created By */
created_by?: string;
- /** Name */
+ /**
+ * Name
+ * @description A name for this entity.
+ */
name: string;
/**
* Description
@@ -512,7 +515,10 @@
created_at?: string;
/** Created By */
created_by?: string;
- /** Name */
+ /**
+ * Name
+ * @description A name for this entity.
+ */
name: string;
/**
* Description
@@ -606,7 +612,10 @@
created_at?: string;
/** Created By */
created_by?: string;
- /** Name */
+ /**
+ * Name
+ * @description A name for this entity.
+ */
name: string;
/**
* Description
@@ -628,6 +637,11 @@
output_json_schema?: string | null;
/** Input Json Schema */
input_json_schema?: string | null;
+ /**
+ * Thinking Instruction
+ * @description Instructions for the model 'thinking' about the requirement prior to answering. Used for chain of thought style prompting.
+ */
+ thinking_instruction?: string | null;
/** Model Type */
readonly model_type: string;
};
@@ -807,7 +821,10 @@
TaskRequirement: {
/** Id */
id?: string | null;
- /** Name */
+ /**
+ * Name
+ * @description A name for this entity
+ */
name: string;
/** Description */
description?: string | null;
@@ -856,6 +873,13 @@
repair_instructions?: string | null;
/** @description A version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
repaired_output?: components["schemas"]["TaskOutput-Input"] | null;
+ /**
+ * Intermediate Outputs
+ * @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
+ */
+ intermediate_outputs?: {
+ [key: string]: string;
+ } | null;
};
/**
* TaskRun
@@ -897,6 +921,13 @@
repair_instructions?: string | null;
/** @description A version of the output with issues fixed. This must be a 'fixed' version of the existing output, and not an entirely new output. If you wish to generate an ideal curatorial output for this task unrelated to this output, generate a new TaskOutput with type 'human' instead of using this field. */
repaired_output?: components["schemas"]["TaskOutput-Output"] | null;
+ /**
+ * Intermediate Outputs
+ * @description Intermediate outputs from the task run. Keys are the names of the intermediate output steps (cot=chain of thought, etc), values are the output data.
+ */
+ intermediate_outputs?: {
+ [key: string]: string;
+ } | null;
/** Model Type */
readonly model_type: string;
};
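To make the new field concrete: `intermediate_outputs` is a flat string-to-string map keyed by step name. A hypothetical payload, where the key matches the one the LangChain adapter writes later in this diff and the value is invented:

```python
# Illustrative only; the value is invented.
intermediate_outputs = {
    "chain_of_thought": "First, identify what the input is asking for. Then ...",
}
```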
2 changes: 1 addition & 1 deletion app/web_ui/src/lib/utils/form_element.svelte
@@ -70,7 +70,7 @@
for={id}
class="text-sm font-medium text-left flex flex-col gap-1 pb-[4px]"
>
<div class="flex flex-row">
<div class="flex flex-row items-center">
<span class="grow {light_label ? 'text-xs text-gray-500' : ''}"
>{label}</span
>
3 changes: 3 additions & 0 deletions app/web_ui/src/routes/(app)/run/prompt_type_selector.svelte
@@ -14,5 +14,8 @@
["few_shot", "Few Shot"],
["many_shot", "Many Shot"],
["repairs", "Repair Multi Shot"],
["simple_chain_of_thought", "Basic Chain of Thought"],
["few_shot_chain_of_thought", "Chain of Thought - Few Shot"],
["multi_shot_chain_of_thought", "Chain of Thought - Many Shot"],
]}
/>
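The prompt builders backing these three options are defined in a file that did not load in this view. Judging from the `chain_of_thought_prompt()` call in `langchain_adapters.py` below and the fallback described in the task form, a chain-of-thought builder might look roughly like the following sketch; the class names and fallback wording are assumptions:

```python
# Hypothetical sketch; the real builders are not shown in the loaded
# portion of this diff.
class SimpleChainOfThoughtPromptBuilder(SimplePromptBuilder):
    def chain_of_thought_prompt(self) -> str | None:
        # A custom thinking_instruction on the task overrides a generic
        # "think step by step" default (see the task form below).
        if getattr(self.task, "thinking_instruction", None):
            return self.task.thinking_instruction
        return "Think step by step, explaining your reasoning."
```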
@@ -68,6 +68,7 @@
description: task.description,
instruction: task.instruction,
requirements: task.requirements,
+ thinking_instruction: task.thinking_instruction,
}
// Can only set schemas when creating a new task
if (creating) {
@@ -145,6 +146,7 @@
!!task.name ||
!!task.description ||
!!task.instruction ||
+ !!task.thinking_instruction ||
has_edited_requirements ||
!!inputSchemaSection.get_schema_string() ||
!!outputSchemaSection.get_schema_string()
@@ -268,6 +270,16 @@
bind:value={task.instruction}
/>

+ <FormElement
+ label="'Thinking' Instructions"
+ inputType="textarea"
+ id="thinking_instructions"
+ optional={true}
+ description="Instructions for how the model should 'think' about the task prior to answering. Used for chain of thought style prompting."
+ info_description="Used when running a 'Chain of Thought' prompt. If left blank, a default 'think step by step' prompt will be used. Optionally customize this with your own instructions to better fit this task."
+ bind:value={task.thinking_instruction}
+ />

<div class="text-sm font-medium text-left pt-6 flex flex-col gap-1">
<div class="text-xl font-bold" id="requirements_part">
Part 2: Requirements
59 changes: 24 additions & 35 deletions libs/core/kiln_ai/adapters/base_adapter.py
@@ -24,6 +24,12 @@ class AdapterInfo:
prompt_builder_name: str


+ @dataclass
+ class RunOutput:
+ output: Dict | str
+ intermediate_outputs: Dict[str, str] | None


class BaseAdapter(metaclass=ABCMeta):
"""Base class for AI model adapters that handle task execution.
@@ -36,22 +42,6 @@ class BaseAdapter(metaclass=ABCMeta):
kiln_task (Task): The task configuration and metadata
output_schema (dict | None): JSON schema for validating structured outputs
input_schema (dict | None): JSON schema for validating structured inputs
- Example:
- ```python
- class CustomAdapter(BaseAdapter):
- async def _run(self, input: Dict | str) -> Dict | str:
- # Implementation for specific model
- pass
- def adapter_info(self) -> AdapterInfo:
- return AdapterInfo(
- adapter_name="custom",
- model_name="model-1",
- model_provider="provider",
- prompt_builder_name="simple"
- )
- ```
"""

def __init__(
@@ -85,21 +75,23 @@ async def invoke(
validate_schema(input, self.input_schema)

# Run
- result = await self._run(input)
+ run_output = await self._run(input)

# validate output
if self.output_schema is not None:
- if not isinstance(result, dict):
- raise RuntimeError(f"structured response is not a dict: {result}")
- validate_schema(result, self.output_schema)
+ if not isinstance(run_output.output, dict):
+ raise RuntimeError(
+ f"structured response is not a dict: {run_output.output}"
+ )
+ validate_schema(run_output.output, self.output_schema)
else:
- if not isinstance(result, str):
+ if not isinstance(run_output.output, str):
raise RuntimeError(
- f"response is not a string for non-structured task: {result}"
+ f"response is not a string for non-structured task: {run_output.output}"
)

# Generate the run and output
- run = self.generate_run(input, input_source, result)
+ run = self.generate_run(input, input_source, run_output)

# Save the run if configured to do so, and we have a path to save to
if Config.shared().autosave_runs and self.kiln_task.path is not None:
@@ -118,27 +110,23 @@ def adapter_info(self) -> AdapterInfo:
pass

@abstractmethod
- async def _run(self, input: Dict | str) -> Dict | str:
+ async def _run(self, input: Dict | str) -> RunOutput:
pass

def build_prompt(self) -> str:
- prompt = self.prompt_builder.build_prompt()
- adapter_instructions = self.adapter_specific_instructions()
- if adapter_instructions is not None:
- prompt += f"# Format Instructions\n\n{adapter_instructions}\n\n"
- return prompt

- # override for adapter specific instructions (e.g. tool calling, json format, etc)
- def adapter_specific_instructions(self) -> str | None:
- return None
+ return self.prompt_builder.build_prompt()

# create a run and task output
def generate_run(
- self, input: Dict | str, input_source: DataSource | None, output: Dict | str
+ self, input: Dict | str, input_source: DataSource | None, run_output: RunOutput
) -> TaskRun:
# Convert input and output to JSON strings if they are dictionaries
input_str = json.dumps(input) if isinstance(input, dict) else input
- output_str = json.dumps(output) if isinstance(output, dict) else output
+ output_str = (
+ json.dumps(run_output.output)
+ if isinstance(run_output.output, dict)
+ else run_output.output
+ )

# If no input source is provided, use the human data source
if input_source is None:
@@ -159,6 +147,7 @@ def generate_run(
properties=self._properties_for_task_output(),
),
),
+ intermediate_outputs=run_output.intermediate_outputs,
)

exclude_fields = {
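The removed docstring example above still showed `_run` returning `Dict | str`. An updated sketch of a custom adapter under the new contract, reusing the placeholder names from the old example:

```python
# Sketch of a custom adapter against the new RunOutput contract.
class CustomAdapter(BaseAdapter):
    async def _run(self, input: Dict | str) -> RunOutput:
        # A real adapter would call a model here; this one just echoes
        # and records a stand-in chain-of-thought intermediate output.
        return RunOutput(
            output=f"echo: {input}",
            intermediate_outputs={"chain_of_thought": "No reasoning needed."},
        )

    def adapter_info(self) -> AdapterInfo:
        return AdapterInfo(
            adapter_name="custom",
            model_name="model-1",
            model_provider="provider",
            prompt_builder_name="simple",
        )
```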
42 changes: 35 additions & 7 deletions libs/core/kiln_ai/adapters/langchain_adapters.py
@@ -2,14 +2,14 @@

from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
- from langchain_core.messages import HumanMessage, SystemMessage
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.runnables import Runnable
from pydantic import BaseModel

import kiln_ai.datamodel as datamodel

- from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder
+ from .base_adapter import AdapterInfo, BaseAdapter, BasePromptBuilder, RunOutput
from .ml_model_list import langchain_model_from

LangChainModelType = BaseChatModel | Runnable[LanguageModelInput, Dict | BaseModel]
@@ -84,15 +84,37 @@ async def model(self) -> LangChainModelType:
)
return self._model

- async def _run(self, input: Dict | str) -> Dict | str:
+ async def _run(self, input: Dict | str) -> RunOutput:
+ model = await self.model()
+ chain = model
+ intermediate_outputs = {}

prompt = self.build_prompt()
user_msg = self.prompt_builder.build_user_message(input)
messages = [
SystemMessage(content=prompt),
HumanMessage(content=user_msg),
]
- model = await self.model()
- response = model.invoke(messages)

+ cot_prompt = self.prompt_builder.chain_of_thought_prompt()
+ if cot_prompt:
+ # Base model (without structured output) used for COT
+ base_model = await langchain_model_from(
+ self.model_name, self.model_provider
+ )
+ messages.append(
+ SystemMessage(content=cot_prompt),
+ )

+ cot_messages = [*messages]
+ cot_response = base_model.invoke(cot_messages)
+ intermediate_outputs["chain_of_thought"] = cot_response.content
+ messages.append(AIMessage(content=cot_response.content))
+ messages.append(
+ SystemMessage(content="Considering the above, return a final result.")
+ )

+ response = chain.invoke(messages)

if self.has_structured_output():
if (
@@ -102,14 +124,20 @@ async def _run(self, input: Dict | str) -> RunOutput:
):
raise RuntimeError(f"structured response not returned: {response}")
structured_response = response["parsed"]
- return self._munge_response(structured_response)
+ return RunOutput(
+ output=self._munge_response(structured_response),
+ intermediate_outputs=intermediate_outputs,
+ )
else:
if not isinstance(response, BaseMessage):
raise RuntimeError(f"response is not a BaseMessage: {response}")
text_content = response.content
if not isinstance(text_content, str):
raise RuntimeError(f"response is not a string: {text_content}")
- return text_content
+ return RunOutput(
+ output=text_content,
+ intermediate_outputs=intermediate_outputs,
+ )

def adapter_info(self) -> AdapterInfo:
return AdapterInfo(
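Net effect: when a chain-of-thought prompt is present, `_run` makes two model calls, an unstructured pass that produces the reasoning and a final (possibly structured) pass, with the reasoning persisted under `intermediate_outputs["chain_of_thought"]` on the saved `TaskRun`. A hedged usage sketch; the adapter class name and constructor arguments are assumptions, since the class definition is outside the hunks shown:

```python
# Hypothetical usage; the class name and constructor are assumptions.
import asyncio

async def main() -> None:
    adapter = LangChainPromptAdapter(      # assumed class name
        task,                              # an existing kiln_ai Task
        model_name="llama_3_1_8b",         # from the ModelName enum above
        provider="ollama",
        prompt_builder=SimpleChainOfThoughtPromptBuilder(task),  # sketched earlier
    )
    run = await adapter.invoke({"question": "What is 2 + 2?"})
    print(run.output.output)                             # final answer
    print(run.intermediate_outputs["chain_of_thought"])  # reasoning pass

asyncio.run(main())
```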