diff --git a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py index da70f9e87..1b4797b80 100644 --- a/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py +++ b/libs/langgraph/langgraph/prebuilt/chat_agent_executor.py @@ -554,6 +554,10 @@ class Agent,Tools otherClass ) model_runnable = preprocessor | model + # If any of the tools are configured to return_directly after running, + # our graph needs to check if these were called + should_return_direct = {t.name for t in tool_classes if t.return_direct} + # Define the function that calls the model def call_model(state: AgentState, config: RunnableConfig) -> AgentState: _validate_chat_history(state["messages"]) @@ -673,10 +677,6 @@ def should_continue(state: AgentState) -> Literal["tools", "__end__"]: should_continue, ) - # If any of the tools are configured to return_directly after running, - # our graph needs to check if these were called - should_return_direct = {t.name for t in tool_classes if t.return_direct} - def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]: for m in reversed(state["messages"]): if not isinstance(m, ToolMessage): diff --git a/libs/langgraph/tests/test_prebuilt.py b/libs/langgraph/tests/test_prebuilt.py index ea44c4f0e..3186cac9e 100644 --- a/libs/langgraph/tests/test_prebuilt.py +++ b/libs/langgraph/tests/test_prebuilt.py @@ -1,4 +1,5 @@ import dataclasses +import inspect import json from functools import partial from typing import ( @@ -2040,3 +2041,9 @@ def foo(a: str, b: int) -> float: return 0.0 assert _get_state_args(foo) == {"a": None, "b": "bar"} + + +def test_inspect_react() -> None: + model = FakeToolCallingModel(tool_calls=[]) + agent = create_react_agent(model, []) + inspect.getclosurevars(agent.nodes["agent"].bound.func)