diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index bde74c438..13d136d48 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -68,8 +68,6 @@ # denotes push-style tasks, ie. those created by Send objects PULL = "__pregel_pull" # denotes pull-style tasks, ie. those triggered by edges -RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__" -# placeholder for managed values replaced at runtime NS_SEP = "|" # for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph) NS_END = ":" @@ -103,7 +101,6 @@ # other constants PUSH, PULL, - RUNTIME_PLACEHOLDER, NS_SEP, NS_END, } diff --git a/libs/langgraph/langgraph/managed/base.py b/libs/langgraph/langgraph/managed/base.py index 3d4eb69f3..a8fc27c8e 100644 --- a/libs/langgraph/langgraph/managed/base.py +++ b/libs/langgraph/langgraph/managed/base.py @@ -16,16 +16,11 @@ from langchain_core.runnables import RunnableConfig from typing_extensions import Self, TypeGuard -from langgraph.constants import RUNTIME_PLACEHOLDER - V = TypeVar("V") U = TypeVar("U") class ManagedValue(ABC, Generic[V]): - runtime: bool = False - """Whether the managed value is always created at runtime, ie. never stored.""" - def __init__(self, config: RunnableConfig) -> None: self.config = config @@ -105,45 +100,4 @@ def is_writable_managed_value(value: Any) -> TypeGuard[Type[WritableManagedValue ChannelTypePlaceholder = object() -class ManagedValueMapping(dict[str, ManagedValue]): - def replace_runtime_values( - self, step: int, values: Union[dict[str, Any], Any] - ) -> None: - if not self or not values: - return - if all(not mv.runtime for mv in self.values()): - return - if isinstance(values, dict): - for key, value in values.items(): - for chan, mv in self.items(): - if mv.runtime and mv(step) is value: - values[key] = {RUNTIME_PLACEHOLDER: chan} - elif hasattr(values, "__dir__") and callable(values.__dir__): - for key in dir(values): - try: - value = getattr(values, key) - for chan, mv in self.items(): - if mv.runtime and mv(step) is value: - setattr(values, key, {RUNTIME_PLACEHOLDER: chan}) - except AttributeError: - pass - - def replace_runtime_placeholders( - self, step: int, values: Union[dict[str, Any], Any] - ) -> None: - if not self or not values: - return - if all(not mv.runtime for mv in self.values()): - return - if isinstance(values, dict): - for key, value in values.items(): - if isinstance(value, dict) and RUNTIME_PLACEHOLDER in value: - values[key] = self[value[RUNTIME_PLACEHOLDER]](step) - elif hasattr(values, "__dir__") and callable(values.__dir__): - for key in dir(values): - try: - value = getattr(values, key) - if isinstance(value, dict) and RUNTIME_PLACEHOLDER in value: - setattr(values, key, self[value[RUNTIME_PLACEHOLDER]](step)) - except AttributeError: - pass +ManagedValueMapping = dict[str, ManagedValue] diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index f78d072d6..687b3ebb3 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -852,11 +852,8 @@ def update_state( # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, - step + 1, writes.extend, - self.nodes, - channels, - managed, + self.nodes.keys(), ), CONFIG_KEY_READ: partial( local_read, @@ -1001,11 +998,8 @@ async def aupdate_state( # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, - step + 1, writes.extend, - self.nodes, - channels, - managed, + self.nodes.keys(), ), CONFIG_KEY_READ: partial( local_read, diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index f9b0096c8..9da421b95 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -155,11 +155,8 @@ def local_read( def local_write( - step: int, commit: Callable[[Sequence[tuple[str, Any]]], None], - processes: Mapping[str, PregelNode], - channels: Mapping[str, BaseChannel], - managed: ManagedValueMapping, + process_keys: Iterable[str], writes: Sequence[tuple[str, Any]], ) -> None: """Function injected under CONFIG_KEY_SEND in task config, to write to channels. @@ -167,15 +164,9 @@ def local_write( for chan, value in writes: if chan == TASKS: if not isinstance(value, Send): - raise InvalidUpdateError( - f"Invalid packet type, expected Packet, got {value}" - ) - if value.node not in processes: + raise InvalidUpdateError(f"Expected Send, got {value}") + if value.node not in process_keys: raise InvalidUpdateError(f"Invalid node name {value.node} in packet") - # replace any runtime values with placeholders - managed.replace_runtime_values(step, value.arg) - elif chan not in channels and chan not in managed: - logger.warning(f"Skipping write for channel '{chan}' which has no readers") commit(writes) @@ -411,7 +402,6 @@ def prepare_single_task( if for_execution: proc = processes[packet.node] if node := proc.node: - managed.replace_runtime_placeholders(step, packet.arg) if proc.metadata: metadata.update(proc.metadata) writes: deque[tuple[str, Any]] = deque() @@ -433,11 +423,8 @@ def prepare_single_task( # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, - step, writes.extend, - processes, - channels, - managed, + processes.keys(), ), CONFIG_KEY_READ: partial( local_read, @@ -543,11 +530,8 @@ def prepare_single_task( # deque.extend is thread-safe CONFIG_KEY_SEND: partial( local_write, - step, writes.extend, - processes, - channels, - managed, + processes.keys(), ), CONFIG_KEY_READ: partial( local_read, diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 605f00747..87a684ec7 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -4440,25 +4440,16 @@ def should_continue(data: AgentState) -> str: ), "nodes can pass extra data to their cond edges, which isn't saved in state" # Logic to decide whether to continue in the loop or exit if tool_calls := data["messages"][-1].tool_calls: - return [ - Send("tools", {"call": tool_call, "my_session": data["session"]}) - for tool_call in tool_calls - ] + return [Send("tools", tool_call) for tool_call in tool_calls] else: return END - class ToolInput(TypedDict): - call: ToolCall - my_session: httpx.Client - - def tools_node(input: ToolInput, config: RunnableConfig) -> AgentState: - assert isinstance(input["my_session"], httpx.Client) - tool_call = input["call"] - time.sleep(tool_call["args"].get("idx", 0) / 10) - output = tools_by_name[tool_call["name"]].invoke(tool_call["args"], config) + def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: + time.sleep(input["args"].get("idx", 0) / 10) + output = tools_by_name[input["name"]].invoke(input["args"], config) return { "messages": ToolMessage( - content=output, name=tool_call["name"], tool_call_id=tool_call["id"] + content=output, name=input["name"], tool_call_id=input["id"] ) } diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index d17925aca..0fe7210c8 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -4204,12 +4204,6 @@ def search_api(query: str) -> str: ] -# defined outside to allow deserializer to see it -class ToolInput(BaseModel, arbitrary_types_allowed=True): - call: ToolCall - my_session: httpx.AsyncClient - - @pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC) async def test_state_graph_packets(checkpointer_name: str) -> None: from langchain_core.language_models.fake_chat_models import ( @@ -4273,23 +4267,16 @@ def should_continue(data: AgentState) -> str: assert isinstance(data["session"], httpx.AsyncClient) # Logic to decide whether to continue in the loop or exit if tool_calls := data["messages"][-1].tool_calls: - return [ - Send("tools", ToolInput(call=tool_call, my_session=data["session"])) - for tool_call in tool_calls - ] + return [Send("tools", tool_call) for tool_call in tool_calls] else: return END - async def tools_node(input: ToolInput, config: RunnableConfig) -> AgentState: - assert isinstance(input.my_session, httpx.AsyncClient) - tool_call = input.call - await asyncio.sleep(tool_call["args"].get("idx", 0) / 10) - output = await tools_by_name[tool_call["name"]].ainvoke( - tool_call["args"], config - ) + async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState: + await asyncio.sleep(input["args"].get("idx", 0) / 10) + output = await tools_by_name[input["name"]].ainvoke(input["args"], config) return { "messages": ToolMessage( - content=output, name=tool_call["name"], tool_call_id=tool_call["id"] + content=output, name=input["name"], tool_call_id=input["id"] ) }