From 0b1d0eb40d7cb6cfaf85ebe536255a700d403fd6 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 16 Sep 2024 12:53:43 -0700 Subject: [PATCH] Retry subgraph starting at failing node (#1695) * Failing test * Ensure retried subgraphs resume from current point (if any) * Lint * Cleanup Test --------- Co-authored-by: Nuno Campos --- libs/langgraph/langgraph/pregel/retry.py | 14 +++++-- libs/langgraph/tests/test_pregel.py | 53 ++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/retry.py b/libs/langgraph/langgraph/pregel/retry.py index 486584809..cdf7b33d3 100644 --- a/libs/langgraph/langgraph/pregel/retry.py +++ b/libs/langgraph/langgraph/pregel/retry.py @@ -4,8 +4,10 @@ import time from typing import Optional +from langgraph.constants import CONFIG_KEY_RESUMING from langgraph.errors import GraphInterrupt from langgraph.pregel.types import PregelExecutableTask, RetryPolicy +from langgraph.utils.config import patch_configurable logger = logging.getLogger(__name__) @@ -18,12 +20,13 @@ def run_with_retry( retry_policy = task.retry_policy or retry_policy interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 + config = task.config while True: try: # clear any writes from previous attempts task.writes.clear() # run the task - task.proc.invoke(task.input, task.config) + task.proc.invoke(task.input, config) # if successful, end break except GraphInterrupt: @@ -56,6 +59,8 @@ def run_with_retry( f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) + # signal subgraphs to resume (if available) + config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) async def arun_with_retry( @@ -67,16 +72,17 @@ async def arun_with_retry( retry_policy = task.retry_policy or retry_policy interval = retry_policy.initial_interval if retry_policy else 0 attempts = 0 + config = task.config while True: try: # clear any writes from previous attempts task.writes.clear() # run the task if stream: - async for _ in task.proc.astream(task.input, task.config): + async for _ in task.proc.astream(task.input, config): pass else: - await task.proc.ainvoke(task.input, task.config) + await task.proc.ainvoke(task.input, config) # if successful, end break except GraphInterrupt: @@ -109,3 +115,5 @@ async def arun_with_retry( f"Retrying task {task.name} after {interval:.2f} seconds (attempt {attempts}) after {exc.__class__.__name__} {exc}", exc_info=exc, ) + # signal subgraphs to resume (if available) + config = patch_configurable(config, {CONFIG_KEY_RESUMING: True}) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index b59309081..8c88f8ad0 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -11023,3 +11023,56 @@ def _node(state: State): app = parent.compile() assert app.get_graph(xray=True).draw_mermaid() == snapshot + + +def test_subgraph_retries(): + class State(TypedDict): + count: int + + class ChildState(State): + some_list: Annotated[list, operator.add] + + called_times = 0 + + class RandomError(ValueError): + """This will be retried on.""" + + def parent_node(state: State): + return {"count": state["count"] + 1} + + def child_node_a(state: ChildState): + nonlocal called_times + # We want it to retry only on node_b + # NOT re-compute the whole graph. + assert not called_times + called_times += 1 + return {"some_list": ["val"]} + + def child_node_b(state: ChildState): + raise RandomError("First attempt fails") + + child = StateGraph(ChildState) + child.add_node(child_node_a) + child.add_node(child_node_b) + child.add_edge("__start__", "child_node_a") + child.add_edge("child_node_a", "child_node_b") + + parent = StateGraph(State) + parent.add_node("parent_node", parent_node) + parent.add_node( + "child_graph", + child.compile(), + retry=RetryPolicy( + max_attempts=3, + retry_on=(RandomError,), + backoff_factor=0.0001, + initial_interval=0.0001, + ), + ) + parent.add_edge("parent_node", "child_graph") + parent.set_entry_point("parent_node") + + checkpointer = MemorySaver() + app = parent.compile(checkpointer=checkpointer) + with pytest.raises(RandomError): + app.invoke({"count": 0}, {"configurable": {"thread_id": "foo"}})