From 5a99127dc43fce3b63197d627678a4793624642c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Henrique=20Luckmann?= Date: Tue, 22 Oct 2024 10:46:52 -0300 Subject: [PATCH 1/2] Add callback suport to Function Tool --- .../tools/function_tool_callback.ipynb | 162 ++++++++++++++++++ .../llama_index/core/tools/function_tool.py | 32 +++- 2 files changed, 186 insertions(+), 8 deletions(-) create mode 100644 docs/docs/examples/tools/function_tool_callback.ipynb diff --git a/docs/docs/examples/tools/function_tool_callback.ipynb b/docs/docs/examples/tools/function_tool_callback.ipynb new file mode 100644 index 0000000000000..d7dbc3ffed588 --- /dev/null +++ b/docs/docs/examples/tools/function_tool_callback.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Function call with callback\n", + "\n", + "This is a feature that allows applying some human-in-the-loop concepts in FunctionTool.\n", + "\n", + "Basically, a callback function is added that enables the developer to request user input in the middle of an agent interaction, as well as allowing any programmatic action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-openai\n", + "%pip install llama-index-agents-openai" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.agent.openai import OpenAIAgent\n", + "from llama_index.llms.openai import OpenAI\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function to display to the user the data produced for function calling and request their input to return to the interaction." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def callback(message):\n", + "\n", + " confirmation = input(f\"{message[1]}\\nDo you approve of sending this greeting?\\nInput(Y/N):\")\n", + "\n", + " if confirmation.lower() == \"y\": \n", + " # Here you can trigger an action such as sending an email, message, api call, etc. \n", + " return \"Greeting sent successfully.\"\n", + " else:\n", + " return \"Greeting has not been approved, talk a bit about how to improve\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Simple function that only requires a recipient and a greeting message." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def send_hello(destination:str, message:str)->str:\n", + " \"\"\"\n", + " Say hello with a rhyme \n", + " destination: str - Name of recipient\n", + " message: str - Greeting message with a rhyme to the recipient's name\n", + " \"\"\" \n", + "\n", + " return destination, message\n", + "\n", + "hello_tool = FunctionTool.from_defaults(fn=send_hello, callback=callback)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI()\n", + "agent = OpenAIAgent.from_tools([hello_tool])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I attempted to send a hello message to Karen, but it seems the greeting has not been approved. Let's try to come up with a different greeting that might be more suitable. How about \"Hello Karen, your smile shines like the sun\"? Let's send this message instead.\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Send hello to Karen\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I have successfully sent a hello message to Joe with the greeting \"Hello Joe, you're a pro!\"\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Send hello to Joe\")\n", + "print(str(response))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-core/llama_index/core/tools/function_tool.py b/llama-index-core/llama_index/core/tools/function_tool.py index 23cf7a5f8a48d..af8b01f81dfa4 100644 --- a/llama-index-core/llama_index/core/tools/function_tool.py +++ b/llama-index-core/llama_index/core/tools/function_tool.py @@ -12,7 +12,6 @@ AsyncCallable = Callable[..., Awaitable[Any]] - def sync_to_async(fn: Callable[..., Any]) -> AsyncCallable: """Sync to async.""" @@ -44,6 +43,7 @@ def __init__( fn: Optional[Callable[..., Any]] = None, metadata: Optional[ToolMetadata] = None, async_fn: Optional[AsyncCallable] = None, + callback: Optional[Callable[[Any], Any]] = None, ) -> None: if fn is None and async_fn is None: raise ValueError("fn or async_fn must be provided.") @@ -62,6 +62,13 @@ def __init__( raise ValueError("metadata must be provided.") self._metadata = metadata + self._callback = callback + + def _run_callback(self, result: Any) -> Any: + """Executes the callback if provided and returns its result.""" + if self._callback: + return self._callback(result) + return "" @classmethod def from_defaults( @@ -73,6 +80,7 @@ def from_defaults( fn_schema: Optional[Type[BaseModel]] = None, async_fn: Optional[AsyncCallable] = None, tool_metadata: Optional[ToolMetadata] = None, + callback: Optional[Callable[[Any], Any]] = None, ) -> "FunctionTool": if tool_metadata is None: fn_to_parse = fn or async_fn @@ -90,7 +98,7 @@ def from_defaults( fn_schema=fn_schema, return_direct=return_direct, ) - return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) + return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn, callback=callback) @property def metadata(self) -> ToolMetadata: @@ -109,19 +117,27 @@ def async_fn(self) -> AsyncCallable: def call(self, *args: Any, **kwargs: Any) -> ToolOutput: """Call.""" - tool_output = self._fn(*args, **kwargs) + tool_output = self._fn(*args, **kwargs) + final_output_content = str(tool_output) + callback_output = self._run_callback(tool_output) + if callback_output: + final_output_content += f" Callback: {callback_output}" return ToolOutput( - content=str(tool_output), + content=final_output_content, tool_name=self.metadata.name, raw_input={"args": args, "kwargs": kwargs}, raw_output=tool_output, ) async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" - tool_output = await self._async_fn(*args, **kwargs) + """Async Call.""" + tool_output = self._fn(*args, **kwargs) + final_output_content = str(tool_output) + callback_output = self._run_callback(tool_output) + if callback_output: + final_output_content += f" Callback: {callback_output}" return ToolOutput( - content=str(tool_output), + content=final_output_content, tool_name=self.metadata.name, raw_input={"args": args, "kwargs": kwargs}, raw_output=tool_output, @@ -157,4 +173,4 @@ def to_langchain_structured_tool( func=self.fn, coroutine=self.async_fn, **langchain_tool_kwargs, - ) + ) \ No newline at end of file From a420154b4b0d86536a94e6efe2f0ae6a67f7bdc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Henrique=20Luckmann?= Date: Tue, 22 Oct 2024 11:13:10 -0300 Subject: [PATCH 2/2] Adjustment to remove callback from langchain function tool --- llama-index-core/llama_index/core/tools/types.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama-index-core/llama_index/core/tools/types.py b/llama-index-core/llama_index/core/tools/types.py index 2355669858cf1..00dc75f6289e3 100644 --- a/llama-index-core/llama_index/core/tools/types.py +++ b/llama-index-core/llama_index/core/tools/types.py @@ -122,6 +122,9 @@ def _process_langchain_tool_kwargs( langchain_tool_kwargs["description"] = self.metadata.description if "fn_schema" not in langchain_tool_kwargs: langchain_tool_kwargs["args_schema"] = self.metadata.fn_schema + #Callback dont exist on langchain + if "callback" in langchain_tool_kwargs: + del langchain_tool_kwargs["callback"] return langchain_tool_kwargs def to_langchain_tool(