From 76199701b0079fc75a8c92234494f52175bedd32 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 30 Dec 2024 18:56:56 +0000 Subject: [PATCH 1/4] Add more tests for async cancellation --- libs/langgraph/langgraph/pregel/loop.py | 4 +- libs/langgraph/tests/test_pregel_async.py | 256 ++++++++++++++++++++++ 2 files changed, 257 insertions(+), 3 deletions(-) diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index cf0716e7c..89ecc4420 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -1032,6 +1032,4 @@ async def __aexit__( traceback: Optional[TracebackType], ) -> Optional[bool]: # unwind stack - return await asyncio.shield( - self.stack.__aexit__(exc_type, exc_value, traceback) - ) + return await self.stack.__aexit__(exc_type, exc_value, traceback) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 02b985c62..5122d9cec 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -180,6 +180,262 @@ def logic(inp: str) -> str: pass +async def test_py_async_with_cancel_behavior() -> None: + """This test confirms that in all versions of Python we support, __aexit__ + is not cancelled when the coroutine containing the async with block is cancelled.""" + + logs: list[str] = [] + + class MyContextManager: + async def __aenter__(self): + logs.append("Entering") + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + logs.append("Starting exit") + try: + # Simulate some cleanup work + await asyncio.sleep(2) + logs.append("Cleanup completed") + except asyncio.CancelledError: + logs.append("Cleanup was cancelled!") + raise + logs.append("Exit finished") + + async def main(): + try: + async with MyContextManager(): + logs.append("In context") + await asyncio.sleep(1) + logs.append("This won't print if cancelled") + except asyncio.CancelledError: + logs.append("Context was cancelled") + raise + + # create task + t = asyncio.create_task(main()) + # cancel after 0.2 seconds + await asyncio.sleep(0.2) + t.cancel() + # check logs before cancellation is handled + assert logs == [ + "Entering", + "In context", + ], "Cancelled before cleanup started" + # wait for task to finish + try: + await t + except asyncio.CancelledError: + # check logs after cancellation is handled + assert logs == [ + "Entering", + "In context", + "Starting exit", + "Cleanup completed", + "Exit finished", + "Context was cancelled", + ], "Cleanup started and finished after cancellation" + else: + assert False, "Task should be cancelled" + + +async def test_checkpoint_put_after_cancellation() -> None: + logs: list[str] = [] + + class LongPutCheckpointer(MemorySaver): + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + logs.append("checkpoint.aput.start") + try: + await asyncio.sleep(1) + return await super().aput(config, checkpoint, metadata, new_versions) + finally: + logs.append("checkpoint.aput.end") + + inner_task_cancelled = False + + async def awhile(input: Any) -> None: + logs.append("awhile.start") + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + nonlocal inner_task_cancelled + inner_task_cancelled = True + raise + finally: + logs.append("awhile.end") + + builder = Graph() + builder.add_node("agent", awhile) + builder.set_entry_point("agent") + builder.set_finish_point("agent") + + graph = builder.compile(checkpointer=LongPutCheckpointer()) + thread1 = {"configurable": {"thread_id": "1"}} + + # start the task + t = asyncio.create_task(graph.ainvoke(1, thread1)) + # cancel after 0.2 seconds + await asyncio.sleep(0.2) + t.cancel() + # check logs before cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + ], "Cancelled before checkpoint put started" + # wait for task to finish + try: + await t + except asyncio.CancelledError: + # check logs after cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + "awhile.end", + "checkpoint.aput.end", + ], "Checkpoint put is not cancelled" + else: + assert False, "Task should be cancelled" + + +async def test_checkpoint_put_after_cancellation_stream_anext() -> None: + logs: list[str] = [] + + class LongPutCheckpointer(MemorySaver): + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + logs.append("checkpoint.aput.start") + try: + await asyncio.sleep(1) + return await super().aput(config, checkpoint, metadata, new_versions) + finally: + logs.append("checkpoint.aput.end") + + inner_task_cancelled = False + + async def awhile(input: Any) -> None: + logs.append("awhile.start") + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + nonlocal inner_task_cancelled + inner_task_cancelled = True + raise + finally: + logs.append("awhile.end") + + builder = Graph() + builder.add_node("agent", awhile) + builder.set_entry_point("agent") + builder.set_finish_point("agent") + + graph = builder.compile(checkpointer=LongPutCheckpointer()) + thread1 = {"configurable": {"thread_id": "1"}} + + # start the task + s = graph.astream(1, thread1) + t = asyncio.create_task(s.__anext__()) + # cancel after 0.2 seconds + await asyncio.sleep(0.2) + t.cancel() + # check logs before cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + ], "Cancelled before checkpoint put started" + # wait for task to finish + try: + await t + except asyncio.CancelledError: + # check logs after cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + "awhile.end", + "checkpoint.aput.end", + ], "Checkpoint put is not cancelled" + else: + assert False, "Task should be cancelled" + + +async def test_checkpoint_put_after_cancellation_stream_events_anext() -> None: + logs: list[str] = [] + + class LongPutCheckpointer(MemorySaver): + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + logs.append("checkpoint.aput.start") + try: + await asyncio.sleep(1) + return await super().aput(config, checkpoint, metadata, new_versions) + finally: + logs.append("checkpoint.aput.end") + + inner_task_cancelled = False + + async def awhile(input: Any) -> None: + logs.append("awhile.start") + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + nonlocal inner_task_cancelled + inner_task_cancelled = True + raise + finally: + logs.append("awhile.end") + + builder = Graph() + builder.add_node("agent", awhile) + builder.set_entry_point("agent") + builder.set_finish_point("agent") + + graph = builder.compile(checkpointer=LongPutCheckpointer()) + thread1 = {"configurable": {"thread_id": "1"}} + + # start the task + s = graph.astream_events(1, thread1, version="v2", include_names=["LangGraph"]) + # skip first event (happens right away) + await s.__anext__() + # start the task for 2nd event + t = asyncio.create_task(s.__anext__()) + # cancel after 0.2 seconds + await asyncio.sleep(0.2) + t.cancel() + # check logs before cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + ], "Cancelled before checkpoint put started" + # wait for task to finish + try: + await t + except asyncio.CancelledError: + # check logs after cancellation is handled + assert logs == [ + "checkpoint.aput.start", + "awhile.start", + "awhile.end", + "checkpoint.aput.end", + ], "Checkpoint put is not cancelled" + else: + assert False, "Task should be cancelled" + + async def test_node_cancellation_on_external_cancel() -> None: inner_task_cancelled = False From 01e5ecedfd6c06abed7f11ec506f8d775314d36f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 30 Dec 2024 19:26:18 +0000 Subject: [PATCH 2/4] Remove assertion of order --- libs/langgraph/tests/test_pregel_async.py | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 5122d9cec..1bdc051b6 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -284,21 +284,25 @@ async def awhile(input: Any) -> None: await asyncio.sleep(0.2) t.cancel() # check logs before cancellation is handled - assert logs == [ - "checkpoint.aput.start", - "awhile.start", - ], "Cancelled before checkpoint put started" + assert logs == sorted( + [ + "checkpoint.aput.start", + "awhile.start", + ] + ), "Cancelled before checkpoint put started" # wait for task to finish try: await t except asyncio.CancelledError: # check logs after cancellation is handled - assert logs == [ - "checkpoint.aput.start", - "awhile.start", - "awhile.end", - "checkpoint.aput.end", - ], "Checkpoint put is not cancelled" + assert logs == sorted( + [ + "checkpoint.aput.start", + "awhile.start", + "awhile.end", + "checkpoint.aput.end", + ] + ), "Checkpoint put is not cancelled" else: assert False, "Task should be cancelled" From 1d9c7ef46137128c05fdec95c70f15516fbaf7d4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 30 Dec 2024 19:29:14 +0000 Subject: [PATCH 3/4] Fix --- libs/langgraph/tests/test_pregel_async.py | 24 ++++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index 1bdc051b6..a3e5026c2 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -284,25 +284,21 @@ async def awhile(input: Any) -> None: await asyncio.sleep(0.2) t.cancel() # check logs before cancellation is handled - assert logs == sorted( - [ - "checkpoint.aput.start", - "awhile.start", - ] - ), "Cancelled before checkpoint put started" + assert sorted(logs) == [ + "awhile.start", + "checkpoint.aput.start", + ], "Cancelled before checkpoint put started" # wait for task to finish try: await t except asyncio.CancelledError: # check logs after cancellation is handled - assert logs == sorted( - [ - "checkpoint.aput.start", - "awhile.start", - "awhile.end", - "checkpoint.aput.end", - ] - ), "Checkpoint put is not cancelled" + assert sorted(logs) == [ + "awhile.end", + "awhile.start", + "checkpoint.aput.end", + "checkpoint.aput.start", + ], "Checkpoint put is not cancelled" else: assert False, "Task should be cancelled" From 400d83708ad7fd22a235b327d005dd7b226c5394 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 30 Dec 2024 20:01:35 +0000 Subject: [PATCH 4/4] Fix --- libs/langgraph/tests/test_pregel_async.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index a3e5026c2..bb3e9ba08 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -349,20 +349,20 @@ async def awhile(input: Any) -> None: await asyncio.sleep(0.2) t.cancel() # check logs before cancellation is handled - assert logs == [ - "checkpoint.aput.start", + assert sorted(logs) == [ "awhile.start", + "checkpoint.aput.start", ], "Cancelled before checkpoint put started" # wait for task to finish try: await t except asyncio.CancelledError: # check logs after cancellation is handled - assert logs == [ - "checkpoint.aput.start", - "awhile.start", + assert sorted(logs) == [ "awhile.end", + "awhile.start", "checkpoint.aput.end", + "checkpoint.aput.start", ], "Checkpoint put is not cancelled" else: assert False, "Task should be cancelled"