From 3798416dc53986583657fe43e4345f23fb3e57d7 Mon Sep 17 00:00:00 2001 From: adlymousa Date: Thu, 14 Nov 2024 12:36:47 +0000 Subject: [PATCH] Support funciton pre-execute system prompt --- .../aggregators/openai_llm_context.py | 8 +++++ src/pipecat/services/ai_services.py | 15 ++++++++-- src/pipecat/services/openai.py | 29 +++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index d70f0e25b..16f1a1981 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -100,10 +100,18 @@ def messages(self) -> List[ChatCompletionMessageParam]: def tools(self) -> List[ChatCompletionToolParam] | NotGiven: return self._tools + @tools.setter + def tools(self, value): + self._tools = value + @property def tool_choice(self) -> ChatCompletionToolChoiceOptionParam | NotGiven: return self._tool_choice + @tool_choice.setter + def tool_choice(self, value): + self._tool_choice = value + def add_message(self, message: ChatCompletionMessageParam): self._messages.append(message) diff --git a/src/pipecat/services/ai_services.py b/src/pipecat/services/ai_services.py index f7a08802c..5ec9e70b0 100644 --- a/src/pipecat/services/ai_services.py +++ b/src/pipecat/services/ai_services.py @@ -134,20 +134,31 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self._callbacks = {} self._start_callbacks = {} + self._pre_execute_prompts = {} # TODO-CB: callback function type - def register_function(self, function_name: str | None, callback, start_callback=None): + def register_function( + self, + function_name: str | None, + callback, + start_callback=None, + pre_execute_prompt: str | None = None + ): # Registering a function with the function_name set to None will run that callback # for all functions self._callbacks[function_name] = callback # QUESTION FOR CB: maybe this isn't needed anymore? if start_callback: self._start_callbacks[function_name] = start_callback + if pre_execute_prompt: + self._pre_execute_prompts[function_name] = pre_execute_prompt def unregister_function(self, function_name: str | None): del self._callbacks[function_name] - if self._start_callbacks[function_name]: + if function_name in self._start_callbacks: del self._start_callbacks[function_name] + if function_name in self._pre_execute_prompts: + del self._pre_execute_prompts[function_name] def has_function(self, function_name: str): if None in self._callbacks.keys(): diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py index b6927e8dc..881bdd461 100644 --- a/src/pipecat/services/openai.py +++ b/src/pipecat/services/openai.py @@ -193,6 +193,34 @@ async def _stream_chat_completions( chunks = await self.get_chat_completions(context, messages) return chunks + + async def _handle_pre_execute_prompt(self, context: OpenAILLMContext, function_name: str): + """Handle pre-execute prompt for a function if one exists.""" + pre_execute_prompt = self._pre_execute_prompts.get(function_name) + if not pre_execute_prompt: + return + + logger.debug(f"Handling pre_execute_prompt for function: {function_name}") + + # Add the pre-execute prompt as a system message to the context + context.add_message({"role": "system", "content": pre_execute_prompt}) + + # Temporarily disable function calling to prevent recursion + original_tools = context.tools + original_tool_choice = context.tool_choice + context.tools = NOT_GIVEN + context.tool_choice = NOT_GIVEN + + # Process the context normally + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self._process_context(context) + await self.stop_processing_metrics() + await self.push_frame(LLMFullResponseEndFrame()) + + # Restore function calling capability + context.tools = original_tools + context.tool_choice = original_tool_choice async def _process_context(self, context: OpenAILLMContext): functions_list = [] @@ -250,6 +278,7 @@ async def _process_context(self, context: OpenAILLMContext): if tool_call.function and tool_call.function.name: function_name += tool_call.function.name tool_call_id = tool_call.id + await self._handle_pre_execute_prompt(context, function_name) await self.call_start_function(context, function_name) if tool_call.function and tool_call.function.arguments: # Keep iterating through the response to collect all the argument fragments