Skip to content

Commit

Permalink
langgraph: always raise NodeInterrupt in ToolNode if raised from a to…
Browse files Browse the repository at this point in the history
…ol (#2175)
  • Loading branch information
vbarda authored Oct 24, 2024
1 parent 91ad8b8 commit 6dacd1a
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 0 deletions.
17 changes: 17 additions & 0 deletions libs/langgraph/langgraph/prebuilt/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from langchain_core.tools import tool as create_tool
from typing_extensions import Annotated, get_args, get_origin

from langgraph.errors import GraphInterrupt
from langgraph.store.base import BaseStore
from langgraph.utils.runnable import RunnableCallable

Expand Down Expand Up @@ -267,6 +268,14 @@ def _run_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage:
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphInterrupt as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
Expand Down Expand Up @@ -300,6 +309,14 @@ async def _arun_one(self, call: ToolCall, config: RunnableConfig) -> ToolMessage
Union[str, list], msg_content_output(tool_message.content)
)
return tool_message
# GraphInterrupt is a special exception that will always be raised.
# It can be triggered in the following scenarios:
# (1) a NodeInterrupt is raised inside a tool
# (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
# (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
# (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
except GraphInterrupt as e:
raise e
except Exception as e:
if isinstance(self.handle_tool_errors, tuple):
handled_types: tuple = self.handle_tool_errors
Expand Down
82 changes: 82 additions & 0 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from typing_extensions import TypedDict

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.memory import MemorySaver
from langgraph.errors import NodeInterrupt
from langgraph.graph import START, MessagesState, StateGraph, add_messages
from langgraph.prebuilt import (
ToolNode,
Expand All @@ -52,6 +54,7 @@
)
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore
from langgraph.types import Interrupt
from tests.conftest import (
ALL_CHECKPOINTERS_ASYNC,
ALL_CHECKPOINTERS_SYNC,
Expand Down Expand Up @@ -834,6 +837,85 @@ def test_tool_node_incorrect_tool_name():
assert tool_message.tool_call_id == "some 0"


def test_tool_node_node_interrupt():
def tool_normal(some_val: int) -> str:
"""Tool docstring."""
return "normal"

def tool_interrupt(some_val: int) -> str:
"""Tool docstring."""
raise NodeInterrupt("foo")

def handle(e: NodeInterrupt):
return "handled"

for handle_tool_errors in (True, (NodeInterrupt,), "handled", handle, False):
node = ToolNode([tool_interrupt], handle_tool_errors=handle_tool_errors)
with pytest.raises(NodeInterrupt) as exc_info:
node.invoke(
{
"messages": [
AIMessage(
"hi?",
tool_calls=[
{
"name": "tool_interrupt",
"args": {"some_val": 0},
"id": "some 0",
}
],
)
]
}
)
assert exc_info.value == "foo"

# test inside react agent
model = FakeToolCallingModel(
tool_calls=[
[
ToolCall(name="tool_interrupt", args={"some_val": 0}, id="1"),
ToolCall(name="tool_normal", args={"some_val": 1}, id="2"),
],
[],
]
)
checkpointer = MemorySaver()
config = {"configurable": {"thread_id": "1"}}
agent = create_react_agent(
model, [tool_interrupt, tool_normal], checkpointer=checkpointer
)
result = agent.invoke({"messages": [HumanMessage("hi?")]}, config)
assert result["messages"] == [
_AnyIdHumanMessage(
content="hi?",
),
AIMessage(
content="hi?",
id="0",
tool_calls=[
{
"name": "tool_interrupt",
"args": {"some_val": 0},
"id": "1",
"type": "tool_call",
},
{
"name": "tool_normal",
"args": {"some_val": 1},
"id": "2",
"type": "tool_call",
},
],
),
]
state = agent.get_state(config)
assert state.next == ("tools",)
task = state.tasks[0]
assert task.name == "tools"
assert task.interrupts == (Interrupt(value="foo", when="during"),)


def my_function(some_val: int, some_other_val: str) -> str:
return f"{some_val} - {some_other_val}"

Expand Down

0 comments on commit 6dacd1a

Please sign in to comment.