Skip to content

Commit

Permalink
Fix stream_mode=updates for cases where one node returns multiple upd…
Browse files Browse the repository at this point in the history
…ates for same key (#2903)
  • Loading branch information
nfcampos authored Jan 2, 2025
2 parents c86f0af + dac8495 commit f70bfc6
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 5 deletions.
26 changes: 21 additions & 5 deletions libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import Counter
from typing import Any, Iterator, Literal, Mapping, Optional, Sequence, TypeVar, Union
from uuid import UUID

Expand Down Expand Up @@ -181,12 +182,27 @@ def map_output_updates(
(task.name, value) for chan, value in writes if chan == output_channels
)
elif any(chan in output_channels for chan, _ in writes):
updated.append(
(
task.name,
{chan: value for chan, value in writes if chan in output_channels},
counts = Counter(chan for chan, _ in writes)
if any(counts[chan] > 1 for chan in output_channels):
updated.extend(
(
task.name,
{chan: value},
)
for chan, value in writes
if chan in output_channels
)
else:
updated.append(
(
task.name,
{
chan: value
for chan, value in writes
if chan in output_channels
},
)
)
)
grouped: dict[str, list[Any]] = {t.name: [] for t, _ in output_tasks}
for node, value in updated:
grouped[node].append(value)
Expand Down
51 changes: 51 additions & 0 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5242,3 +5242,54 @@ def second_node(state: State):
# Verify the error was recorded in checkpoint
failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error)
assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error


def test_multiple_updates_root() -> None:
def node_a(state):
return [Command(update="a1"), Command(update="a2")]

def node_b(state):
return "b"

graph = (
StateGraph(Annotated[str, operator.add])
.add_sequence([node_a, node_b])
.add_edge(START, "node_a")
.compile()
)

assert graph.invoke("") == "a1a2b"

# only streams the last update from node_a
assert [c for c in graph.stream("", stream_mode="updates")] == [
{"node_a": ["a1", "a2"]},
{"node_b": "b"},
]


def test_multiple_updates() -> None:
class State(TypedDict):
foo: Annotated[str, operator.add]

def node_a(state):
return [Command(update={"foo": "a1"}), Command(update={"foo": "a2"})]

def node_b(state):
return {"foo": "b"}

graph = (
StateGraph(State)
.add_sequence([node_a, node_b])
.add_edge(START, "node_a")
.compile()
)

assert graph.invoke({"foo": ""}) == {
"foo": "a1a2b",
}

# only streams the last update from node_a
assert [c for c in graph.stream({"foo": ""}, stream_mode="updates")] == [
{"node_a": [{"foo": "a1"}, {"foo": "a2"}]},
{"node_b": {"foo": "b"}},
]
51 changes: 51 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -6618,3 +6618,54 @@ async def second_node(state: State):
# Verify the error was recorded in checkpoint
failed_checkpoint = next(c for c in history if c.tasks and c.tasks[0].error)
assert "RuntimeError('Simulated failure')" in failed_checkpoint.tasks[0].error


async def test_multiple_updates_root() -> None:
def node_a(state):
return [Command(update="a1"), Command(update="a2")]

def node_b(state):
return "b"

graph = (
StateGraph(Annotated[str, operator.add])
.add_sequence([node_a, node_b])
.add_edge(START, "node_a")
.compile()
)

assert await graph.ainvoke("") == "a1a2b"

# only streams the last update from node_a
assert [c async for c in graph.astream("", stream_mode="updates")] == [
{"node_a": ["a1", "a2"]},
{"node_b": "b"},
]


async def test_multiple_updates() -> None:
class State(TypedDict):
foo: Annotated[str, operator.add]

def node_a(state):
return [Command(update={"foo": "a1"}), Command(update={"foo": "a2"})]

def node_b(state):
return {"foo": "b"}

graph = (
StateGraph(State)
.add_sequence([node_a, node_b])
.add_edge(START, "node_a")
.compile()
)

assert await graph.ainvoke({"foo": ""}) == {
"foo": "a1a2b",
}

# only streams the last update from node_a
assert [c async for c in graph.astream({"foo": ""}, stream_mode="updates")] == [
{"node_a": [{"foo": "a1"}, {"foo": "a2"}]},
{"node_b": {"foo": "b"}},
]

0 comments on commit f70bfc6

Please sign in to comment.