From ee638fdbf8c6051575503c85990fd54e57889555 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 23 Sep 2024 14:29:38 -0700 Subject: [PATCH] langgraph[patch]: remove stringify of tool msg contnt (#1810) --- .../langgraph/langgraph/prebuilt/tool_node.py | 23 +++++++++++---- libs/langgraph/tests/test_prebuilt.py | 28 +++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index 52b80f75a..be87c0f0f 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -43,9 +43,20 @@ TOOL_CALL_ERROR_TEMPLATE = "Error: {error}\n Please fix your mistakes." -def str_output(output: Any) -> str: +def msg_content_output(output: Any) -> str | List[dict]: + recognized_content_block_types = ("image", "image_url", "text", "json") if isinstance(output, str): return output + elif all( + [ + isinstance(x, dict) and x.get("type") in recognized_content_block_types + for x in output + ] + ): + return output + # Technically a list of strings is also valid message content but it's not currently + # well tested that all chat models support this. And for backwards compatibility + # we want to make sure we don't break any existing ToolNode usage. else: try: return json.dumps(output, ensure_ascii=False) @@ -138,8 +149,9 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage: tool_message: ToolMessage = self.tools_by_name[call["name"]].invoke( input, config ) - # TODO: handle this properly in core - tool_message.content = str_output(tool_message.content) + tool_message.content = cast( + Union[str, list], msg_content_output(tool_message.content) + ) return tool_message except Exception as e: if not self.handle_tool_errors: @@ -155,8 +167,9 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage tool_message: ToolMessage = await self.tools_by_name[call["name"]].ainvoke( input, config ) - # TODO: handle this properly in core - tool_message.content = str_output(tool_message.content) + tool_message.content = cast( + Union[str, list], msg_content_output(tool_message.content) + ) return tool_message except Exception as e: if not self.handle_tool_errors: diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 0fc1685e5..6e1cd081a 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -262,6 +262,12 @@ async def tool3(some_val: int, some_other_val: str) -> str: {"key_1": some_other_val, "key_2": "baz"}, ] + async def tool4(some_val: int, some_other_val: str) -> str: + """Tool 4 docstring.""" + return [ + {"type": "image_url", "image_url": {"url": "abdc"}}, + ] + result = ToolNode([tool1]).invoke( { "messages": [ @@ -397,6 +403,28 @@ async def tool3(some_val: int, some_other_val: str) -> str: ) assert tool_message.tool_call_id == "some 0" + # list of content blocks tool content + result4 = await ToolNode([tool4]).ainvoke( + { + "messages": [ + AIMessage( + "hi?", + tool_calls=[ + { + "name": "tool4", + "args": {"some_val": 2, "some_other_val": "bar"}, + "id": "some 0", + } + ], + ) + ] + } + ) + tool_message: ToolMessage = result4["messages"][-1] + assert tool_message.type == "tool" + assert tool_message.content == [{"type": "image_url", "image_url": {"url": "abdc"}}] + assert tool_message.tool_call_id == "some 0" + def my_function(some_val: int, some_other_val: str) -> str: return f"{some_val} - {some_other_val}"