diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index d54f8b8b2..70c5e60d9 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -1,3 +1,4 @@ +from collections import Counter from typing import Any, Iterator, Literal, Mapping, Optional, Sequence, TypeVar, Union from uuid import UUID @@ -181,12 +182,27 @@ def map_output_updates( (task.name, value) for chan, value in writes if chan == output_channels ) elif any(chan in output_channels for chan, _ in writes): - updated.append( - ( - task.name, - {chan: value for chan, value in writes if chan in output_channels}, + counts = Counter(chan for chan, _ in writes) + if any(counts[chan] > 1 for chan in output_channels): + updated.extend( + ( + task.name, + {chan: value}, + ) + for chan, value in writes + if chan in output_channels + ) + else: + updated.append( + ( + task.name, + { + chan: value + for chan, value in writes + if chan in output_channels + }, + ) ) - ) grouped: dict[str, list[Any]] = {t.name: [] for t, _ in output_tasks} for node, value in updated: grouped[node].append(value) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index b08cc87a7..c069293a8 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -5242,3 +5242,54 @@ def second_node(state: State): # Verify the error was recorded in checkpoint failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error) assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error + + +def test_multiple_updates_root() -> None: + def node_a(state): + return [Command(update="a1"), Command(update="a2")] + + def node_b(state): + return "b" + + graph = ( + StateGraph(Annotated[str, operator.add]) + .add_sequence([node_a, node_b]) + .add_edge(START, "node_a") + .compile() + ) + + assert graph.invoke("") == "a1a2b" + + # only streams the last update from node_a + assert [c for c in graph.stream("", stream_mode="updates")] == [ + {"node_a": ["a1", "a2"]}, + {"node_b": "b"}, + ] + + +def test_multiple_updates() -> None: + class State(TypedDict): + foo: Annotated[str, operator.add] + + def node_a(state): + return [Command(update={"foo": "a1"}), Command(update={"foo": "a2"})] + + def node_b(state): + return {"foo": "b"} + + graph = ( + StateGraph(State) + .add_sequence([node_a, node_b]) + .add_edge(START, "node_a") + .compile() + ) + + assert graph.invoke({"foo": ""}) == { + "foo": "a1a2b", + } + + # only streams the last update from node_a + assert [c for c in graph.stream({"foo": ""}, stream_mode="updates")] == [ + {"node_a": [{"foo": "a1"}, {"foo": "a2"}]}, + {"node_b": {"foo": "b"}}, + ] diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index bb3e9ba08..38cbef2a8 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -6618,3 +6618,54 @@ async def second_node(state: State): # Verify the error was recorded in checkpoint failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error) assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error + + +async def test_multiple_updates_root() -> None: + def node_a(state): + return [Command(update="a1"), Command(update="a2")] + + def node_b(state): + return "b" + + graph = ( + StateGraph(Annotated[str, operator.add]) + .add_sequence([node_a, node_b]) + .add_edge(START, "node_a") + .compile() + ) + + assert await graph.ainvoke("") == "a1a2b" + + # only streams the last update from node_a + assert [c async for c in graph.astream("", stream_mode="updates")] == [ + {"node_a": ["a1", "a2"]}, + {"node_b": "b"}, + ] + + +async def test_multiple_updates() -> None: + class State(TypedDict): + foo: Annotated[str, operator.add] + + def node_a(state): + return [Command(update={"foo": "a1"}), Command(update={"foo": "a2"})] + + def node_b(state): + return {"foo": "b"} + + graph = ( + StateGraph(State) + .add_sequence([node_a, node_b]) + .add_edge(START, "node_a") + .compile() + ) + + assert await graph.ainvoke({"foo": ""}) == { + "foo": "a1a2b", + } + + # only streams the last update from node_a + assert [c async for c in graph.astream({"foo": ""}, stream_mode="updates")] == [ + {"node_a": [{"foo": "a1"}, {"foo": "a2"}]}, + {"node_b": {"foo": "b"}}, + ]