From b331b4cf6f52a081f1e70773e0e97cfc462ce869 Mon Sep 17 00:00:00 2001 From: Andy Brenneke Date: Tue, 3 Sep 2024 16:35:02 -0700 Subject: [PATCH] Anthropic prompt caching --- .../components/editors/DefaultNodeEditor.tsx | 4 + packages/core/src/model/DataValue.ts | 13 ++ packages/core/src/model/nodes/PromptNode.ts | 16 ++ .../core/src/plugins/anthropic/anthropic.ts | 149 +++++++++++------- .../anthropic/nodes/ChatAnthropicNode.ts | 71 +++++++-- 5 files changed, 182 insertions(+), 71 deletions(-) diff --git a/packages/app/src/components/editors/DefaultNodeEditor.tsx b/packages/app/src/components/editors/DefaultNodeEditor.tsx index fa44b1d5b..6392165b9 100644 --- a/packages/app/src/components/editors/DefaultNodeEditor.tsx +++ b/packages/app/src/components/editors/DefaultNodeEditor.tsx @@ -72,6 +72,10 @@ export const defaultEditorContainerStyles = css` min-width: 75px; } + label:nth-child(2) { + min-width: 32px; + } + &.use-input-toggle label:first-child { min-width: unset; } diff --git a/packages/core/src/model/DataValue.ts b/packages/core/src/model/DataValue.ts index fc5a86ad4..904aeb96a 100644 --- a/packages/core/src/model/DataValue.ts +++ b/packages/core/src/model/DataValue.ts @@ -14,11 +14,17 @@ export type BoolDataValue = DataValueDef<'boolean', boolean>; export type SystemChatMessage = { type: 'system'; message: ChatMessageMessagePart | ChatMessageMessagePart[]; + + /** If true, this message marks a breakpoint when used with prompt caching (as of right now, Anthropic-only). */ + isCacheBreakpoint?: boolean; }; export type UserChatMessage = { type: 'user'; message: ChatMessageMessagePart | ChatMessageMessagePart[]; + + /** If true, this message marks a breakpoint when used with prompt caching (as of right now, Anthropic-only). */ + isCacheBreakpoint?: boolean; }; export type AssistantChatMessageFunctionCall = { @@ -41,12 +47,18 @@ export type AssistantChatMessage = { function_call: AssistantChatMessageFunctionCall | undefined; function_calls: AssistantChatMessageFunctionCall[] | undefined; + + /** If true, this message marks a breakpoint when used with prompt caching (as of right now, Anthropic-only). */ + isCacheBreakpoint?: boolean; }; export type FunctionResponseChatMessage = { type: 'function'; message: ChatMessageMessagePart | ChatMessageMessagePart[]; name: string; + + /** If true, this message marks a breakpoint when used with prompt caching (as of right now, Anthropic-only). */ + isCacheBreakpoint?: boolean; }; export type ChatMessage = SystemChatMessage | UserChatMessage | AssistantChatMessage | FunctionResponseChatMessage; @@ -389,6 +401,7 @@ export const scalarDefaults: { [P in ScalarDataType]: Extract { @@ -160,6 +163,14 @@ export class PromptNodeImpl extends NodeImpl { label: 'Compute Token Count', dataKey: 'computeTokenCount', }, + { + type: 'toggle', + label: 'Is Cache Breakpoint', + dataKey: 'isCacheBreakpoint', + helperMessage: + 'For Anthropic, marks this message as a cache breakpoint - this message and every message before it will be cached using Prompt Caching.', + useInputToggleDataKey: 'useIsCacheBreakpointInput', + }, { type: 'code', label: 'Prompt Text', @@ -208,6 +219,7 @@ export class PromptNodeImpl extends NodeImpl { const outputValue = interpolate(this.chartNode.data.promptText, inputMap); const type = getInputOrData(this.data, inputs, 'type', 'string'); + const isCacheBreakpoint = getInputOrData(this.data, inputs, 'isCacheBreakpoint', 'boolean'); if (['assistant', 'system', 'user', 'function'].includes(type) === false) { throw new Error(`Invalid type: ${type}`); @@ -219,6 +231,7 @@ export class PromptNodeImpl extends NodeImpl { (type): ChatMessage => ({ type, message: outputValue, + isCacheBreakpoint, }), ) .with( @@ -226,6 +239,7 @@ export class PromptNodeImpl extends NodeImpl { (type): ChatMessage => ({ type, message: outputValue, + isCacheBreakpoint, }), ) .with('assistant', (type): ChatMessage => { @@ -248,6 +262,7 @@ export class PromptNodeImpl extends NodeImpl { message: outputValue, function_call: functionCall as AssistantChatMessageFunctionCall, function_calls: functionCall ? [functionCall as AssistantChatMessageFunctionCall] : undefined, + isCacheBreakpoint, }; }) .with( @@ -256,6 +271,7 @@ export class PromptNodeImpl extends NodeImpl { type, message: outputValue, name: getInputOrData(this.data, inputs, 'name', 'string'), + isCacheBreakpoint, }), ) .otherwise(() => { diff --git a/packages/core/src/plugins/anthropic/anthropic.ts b/packages/core/src/plugins/anthropic/anthropic.ts index e78d9b79e..e05d21ac3 100644 --- a/packages/core/src/plugins/anthropic/anthropic.ts +++ b/packages/core/src/plugins/anthropic/anthropic.ts @@ -86,11 +86,12 @@ export const anthropicModelOptions = Object.entries(anthropicModels).map(([id, { export type Claude3ChatMessage = { role: 'user' | 'assistant'; content: string | Claude3ChatMessageContentPart[]; -} +}; export type Claude3ChatMessageTextContentPart = { type: 'text'; text: string; + cache_control: CacheControl; }; export type Claude3ChatMessageImageContentPart = { @@ -100,12 +101,14 @@ export type Claude3ChatMessageImageContentPart = { media_type: string; data: string; }; + cache_control: CacheControl; }; export type Claude3ChatMessageToolResultContentPart = { type: 'tool_result'; tool_use_id: string; - content: string | { type: 'text'; text: string; }[]; + content: string | { type: 'text'; text: string }[]; + cache_control: CacheControl; }; export type Claude3ChatMessageToolUseContentPart = { @@ -113,9 +116,10 @@ export type Claude3ChatMessageToolUseContentPart = { id: string; name: string; input: object; -} + cache_control: CacheControl; +}; -export type Claude3ChatMessageContentPart = +export type Claude3ChatMessageContentPart = | Claude3ChatMessageTextContentPart | Claude3ChatMessageImageContentPart | Claude3ChatMessageToolResultContentPart @@ -125,7 +129,7 @@ export type ChatMessageOptions = { apiKey: string; model: AnthropicModels; messages: Claude3ChatMessage[]; - system?: string; + system?: SystemPrompt; max_tokens: number; stop_sequences?: string[]; temperature?: number; @@ -138,6 +142,7 @@ export type ChatMessageOptions = { description: string; input_schema: object; }[]; + beta?: string; }; export type ChatCompletionOptions = { @@ -159,62 +164,83 @@ export type ChatCompletionChunk = { model: string; }; -export type ChatMessageChunk = { - type: 'message_start'; - message: { - id: string; - type: string; - role: string; - content: { - type: 'text'; - text: string; - }[]; - model: AnthropicModels; - stop_reason: string | null; - stop_sequence: string | null; - usage: { - input_tokens: number; - output_tokens: number; - }; - }; -} | { - type: 'content_block_start'; - index: number; - content_block: { - type: 'text'; - text: string; - }; -} | { - type: 'ping'; -} | { - type: 'content_block_delta'; - index: number; - delta: { - type: 'text_delta'; - text: string; - } -} | { - type: 'message_delta'; - delta: { - stop_reason: string | null; - stop_sequence: string | null; - usage: { - output_tokens: number; - } - } -} | { - type: 'message_stop'; +export type CacheControl = null | { + type: 'ephemeral'; }; +export type SystemPrompt = string | SystemPromptMessage[]; + +export type SystemPromptMessage = { + cache_control: CacheControl; + type: 'text'; + text: string; +}; + +export type ChatMessageChunk = + | { + type: 'message_start'; + message: { + id: string; + type: string; + role: string; + content: { + type: 'text'; + text: string; + }[]; + model: AnthropicModels; + stop_reason: string | null; + stop_sequence: string | null; + usage: { + input_tokens: number; + output_tokens: number; + }; + }; + } + | { + type: 'content_block_start'; + index: number; + content_block: { + type: 'text'; + text: string; + }; + } + | { + type: 'ping'; + } + | { + type: 'content_block_delta'; + index: number; + delta: { + type: 'text_delta'; + text: string; + }; + } + | { + type: 'message_delta'; + delta: { + stop_reason: string | null; + stop_sequence: string | null; + usage: { + output_tokens: number; + }; + }; + } + | { + type: 'message_stop'; + }; + export type ChatMessageResponse = { id: string; - content: ({ - text: string; - } | { - id: string; - name: string; - input: object; - })[]; + content: ( + | { + text: string; + } + | { + id: string; + name: string; + input: object; + } + )[]; model: string; stop_reason: 'end_turn'; stop_sequence: string; @@ -230,7 +256,7 @@ export async function* streamChatCompletions({ ...rest }: ChatCompletionOptions): AsyncGenerator { const defaultSignal = new AbortController().signal; - const response = await fetchEventSource('https://api.anthropic.com/v1/complete', { + const response = await fetchEventSource('https://api.anthropic.com/v1/completions', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -278,6 +304,7 @@ export async function callMessageApi({ apiKey, signal, tools, + beta, ...rest }: ChatMessageOptions): Promise { const defaultSignal = new AbortController().signal; @@ -287,7 +314,7 @@ export async function callMessageApi({ 'Content-Type': 'application/json', 'x-api-key': apiKey, 'anthropic-version': '2023-06-01', - 'anthropic-beta': tools ? 'tools-2024-04-04' : 'messages-2023-12-15', + ...(beta ? { 'anthropic-beta': beta } : {}), }, body: JSON.stringify({ ...rest, @@ -306,17 +333,19 @@ export async function callMessageApi({ export async function* streamMessageApi({ apiKey, signal, + beta, ...rest }: ChatMessageOptions): AsyncGenerator { // Use the Messages API for Claude 3 models const defaultSignal = new AbortController().signal; + console.dir({ rest }, { depth: null }); const response = await fetchEventSource('https://api.anthropic.com/v1/messages', { method: 'POST', headers: { 'Content-Type': 'application/json', 'x-api-key': apiKey, 'anthropic-version': '2023-06-01', - 'anthropic-beta': 'messages-2023-12-15', + ...(beta ? { 'anthropic-beta': beta } : {}), }, body: JSON.stringify({ ...rest, diff --git a/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts b/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts index 1cb39a292..c5c391dd4 100644 --- a/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts +++ b/packages/core/src/plugins/anthropic/nodes/ChatAnthropicNode.ts @@ -28,6 +28,7 @@ import { type ChatMessageOptions, callMessageApi, type Claude3ChatMessageTextContentPart, + type SystemPrompt, } from '../anthropic.js'; import { nanoid } from 'nanoid/non-secure'; import { dedent } from 'ts-dedent'; @@ -341,8 +342,10 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { const tools = data.enableToolUse ? coerceTypeOptional(inputs['tools' as PortId], 'gpt-function[]') ?? [] : undefined; + const rivetChatMessages = getChatMessages(inputs); const messages = await chatMessagesToClaude3ChatMessages(rivetChatMessages); + let prompt = messages.reduce((acc, message) => { const content = typeof message.content === 'string' @@ -363,12 +366,18 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { // Get the "System" prompt input for Claude 3 models const system = data.model.startsWith('claude-3') ? getSystemPrompt(inputs) : undefined; + const systemInput = inputs['system' as PortId]; + const includesCacheBreakpoint = + rivetChatMessages.some((m) => m.isCacheBreakpoint) || + (systemInput?.type === 'chat-message' && systemInput.value.isCacheBreakpoint); + let { maxTokens } = data; const tokenizerInfo: TokenizerCallInfo = { node: context.node, model, endpoint: undefined, }; + const tokenCountEstimate = await context.tokenizer.getTokenCountForString(prompt, tokenizerInfo); const modelInfo = anthropicModels[model] ?? { maxTokens: Number.MAX_SAFE_INTEGER, @@ -377,11 +386,13 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { completion: 0, }, }; + if (tokenCountEstimate >= modelInfo.maxTokens) { throw new Error( `The model ${model} can only handle ${modelInfo.maxTokens} tokens, but ${tokenCountEstimate} were provided in the prompts alone.`, ); } + if (tokenCountEstimate + maxTokens > modelInfo.maxTokens) { const message = `The model can only handle a maximum of ${ modelInfo.maxTokens @@ -391,6 +402,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { addWarning(output, message); maxTokens = Math.floor((modelInfo.maxTokens - tokenCountEstimate) * 0.95); // reduce max tokens by 5% to be safe, calculation is a little wrong. } + try { return await retry( async () => { @@ -430,6 +442,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { // Streaming is not supported with tool usage. const response = await callMessageApi({ apiKey: apiKey ?? '', + beta: includesCacheBreakpoint ? 'prompt-caching-2024-07-31' : undefined, ...messageOptions, }); const { input_tokens: requestTokens, output_tokens: responseTokens } = response.usage; @@ -488,6 +501,7 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { const chunks = streamMessageApi({ apiKey: apiKey ?? '', signal: context.signal, + beta: includesCacheBreakpoint ? 'prompt-caching-2024-07-31' : undefined, ...messageOptions, }); @@ -639,22 +653,43 @@ export const ChatAnthropicNodeImpl: PluginNodeImpl = { export const chatAnthropicNode = pluginNodeDefinition(ChatAnthropicNodeImpl, 'Chat'); -export function getSystemPrompt(inputs: Inputs) { - const system = coerceTypeOptional(inputs['system' as PortId], 'string'); +export function getSystemPrompt(inputs: Inputs): SystemPrompt | undefined { + const systemInput = inputs['system' as PortId]; + + const system = coerceTypeOptional(systemInput, 'string'); + if (system) { - return system; + return [ + { + type: 'text', + text: system, + cache_control: + systemInput?.type === 'chat-message' + ? systemInput.value.isCacheBreakpoint + ? { type: 'ephemeral' } + : null + : null, + }, + ]; } + const prompt = inputs['prompt' as PortId]; if (prompt && prompt.type === 'chat-message[]') { - const systemMessage = prompt.value.find((message) => message.type === 'system'); - if (systemMessage) { - if (typeof systemMessage.message === 'string') { - return systemMessage.message; - } else if (Array.isArray(systemMessage.message)) { - return systemMessage.message.filter((p) => typeof p === 'string').join(''); - } + const systemMessages = prompt.value.filter((message) => message.type === 'system'); + if (systemMessages.length) { + const converted = systemMessages.map((message) => { + return { + type: 'text' as const, + text: coerceType({ type: 'chat-message', value: message }, 'string'), + cache_control: message.isCacheBreakpoint ? { type: 'ephemeral' as const } : null, + }; + }); + + return converted; } } + + return undefined; } function getChatMessages(inputs: Inputs) { @@ -725,6 +760,7 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise chatMessageContentToClaude3ChatMessage(part))) : [await chatMessageContentToClaude3ChatMessage(message.message)]; if (message.type === 'assistant' && message.function_calls) { @@ -753,6 +790,7 @@ async function chatMessageToClaude3ChatMessage(message: ChatMessage): Promise