From d9b7aaa5cc6439d2745b97989a81acc577948db3 Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Mon, 16 Dec 2024 14:00:45 -0500 Subject: [PATCH] langgraph: relax constraints in ToolNode Command validation (#2778) --- .../langgraph/langgraph/prebuilt/tool_node.py | 24 ++-- libs/langgraph/tests/test_prebuilt.py | 110 ++++++++---------- 2 files changed, 54 insertions(+), 80 deletions(-) diff --git a/libs/langgraph/langgraph/prebuilt/tool_node.py b/libs/langgraph/langgraph/prebuilt/tool_node.py index e2ac50b8e..5692c750d 100644 --- a/libs/langgraph/langgraph/prebuilt/tool_node.py +++ b/libs/langgraph/langgraph/prebuilt/tool_node.py @@ -548,33 +548,25 @@ def _validate_tool_command( # convert to message objects if updates are in a dict format messages_update = convert_to_messages(messages_update) - have_seen_tool_messages = False + has_matching_tool_message = False for message in messages_update: if not isinstance(message, ToolMessage): continue - if have_seen_tool_messages: - raise ValueError( - f"Expected at most one ToolMessage in Command.update for tool '{call['name']}', got multiple: {messages_update}." - ) - - if message.tool_call_id != call["id"]: - raise ValueError( - f"ToolMessage.tool_call_id must match the tool call id. Expected: {call['id']}, got: {message.tool_call_id} for tool '{call['name']}'." - ) - - message.name = call["name"] - have_seen_tool_messages = True + if message.tool_call_id == call["id"]: + message.name = call["name"] + has_matching_tool_message = True - # validate that we always have exactly one ToolMessage in Command.update if command is sent to the CURRENT graph - if updated_command.graph is None and not have_seen_tool_messages: + # validate that we always have a ToolMessage matching the tool call in + # Command.update if command is sent to the CURRENT graph + if updated_command.graph is None and not has_matching_tool_message: example_update = ( '`Command(update={"messages": [ToolMessage("Success", tool_call_id=tool_call_id), ...]}, ...)`' if input_type == "dict" else '`Command(update=[ToolMessage("Success", tool_call_id=tool_call_id), ...], ...)`' ) raise ValueError( - f"Expected exactly one message (ToolMessage) in Command.update for tool '{call['name']}', got: {messages_update}. " + f"Expected to have a matching ToolMessage in Command.update for tool '{call['name']}', got: {messages_update}. " "Every tool call (LLM requesting to call a tool) in the message history MUST have a corresponding ToolMessage. " f"You can fix it by modifying the tool to return {example_update}." ) diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index 9c6541d9a..ea44c4f0e 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1249,6 +1249,33 @@ def no_update_tool(): } ) + # test validation (tool message with a wrong tool call ID) + with pytest.raises(ValueError): + + @dec_tool + def mismatching_tool_call_id_tool(): + """My tool""" + return Command( + update={"messages": [ToolMessage(content="foo", tool_call_id="2")]} + ) + + ToolNode([mismatching_tool_call_id_tool]).invoke( + { + "messages": [ + AIMessage( + "", + tool_calls=[ + { + "args": {}, + "id": "1", + "name": "mismatching_tool_call_id_tool", + } + ], + ) + ] + } + ) + # test validation (missing tool message in the update for parent graph is OK) @dec_tool def node_update_parent_tool(): @@ -1268,40 +1295,6 @@ def node_update_parent_tool(): } ) == [Command(update={"messages": []}, graph=Command.PARENT)] - # test validation (multiple tool messages) - with pytest.raises(ValueError): - for graph in (None, Command.PARENT): - - @dec_tool - def multiple_tool_messages_tool(): - """My tool""" - return Command( - update={ - "messages": [ - ToolMessage(content="foo", tool_call_id=""), - ToolMessage(content="bar", tool_call_id=""), - ] - }, - graph=graph, - ) - - ToolNode([multiple_tool_messages_tool]).invoke( - { - "messages": [ - AIMessage( - "", - tool_calls=[ - { - "args": {}, - "id": "1", - "name": "multiple_tool_messages_tool", - } - ], - ) - ] - } - ) - @pytest.mark.skipif( not IS_LANGCHAIN_CORE_030_OR_GREATER, @@ -1524,6 +1517,25 @@ def no_update_tool(): ] ) + # test validation (tool message with a wrong tool call ID) + with pytest.raises(ValueError): + + @dec_tool + def mismatching_tool_call_id_tool(): + """My tool""" + return Command(update=[ToolMessage(content="foo", tool_call_id="2")]) + + ToolNode([mismatching_tool_call_id_tool]).invoke( + [ + AIMessage( + "", + tool_calls=[ + {"args": {}, "id": "1", "name": "mismatching_tool_call_id_tool"} + ], + ) + ] + ) + # test validation (missing tool message in the update for parent graph is OK) @dec_tool def node_update_parent_tool(): @@ -1539,36 +1551,6 @@ def node_update_parent_tool(): ] ) == [Command(update=[], graph=Command.PARENT)] - # test validation (multiple tool messages) - with pytest.raises(ValueError): - for graph in (None, Command.PARENT): - - @dec_tool - def multiple_tool_messages_tool(): - """My tool""" - return Command( - update=[ - ToolMessage(content="foo", tool_call_id=""), - ToolMessage(content="bar", tool_call_id=""), - ], - graph=graph, - ) - - ToolNode([multiple_tool_messages_tool]).invoke( - [ - AIMessage( - "", - tool_calls=[ - { - "args": {}, - "id": "1", - "name": "multiple_tool_messages_tool", - } - ], - ) - ] - ) - @pytest.mark.skipif( not IS_LANGCHAIN_CORE_030_OR_GREATER,