Skip to content

Commit

Permalink
Add more tests for async cancellation (#2902)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Dec 30, 2024
2 parents effddca + 400d837 commit 3aaa3e3
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 3 deletions.
4 changes: 1 addition & 3 deletions libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,4 @@ async def __aexit__(
traceback: Optional[TracebackType],
) -> Optional[bool]:
# unwind stack
return await asyncio.shield(
self.stack.__aexit__(exc_type, exc_value, traceback)
)
return await self.stack.__aexit__(exc_type, exc_value, traceback)
256 changes: 256 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,262 @@ def logic(inp: str) -> str:
pass


async def test_py_async_with_cancel_behavior() -> None:
"""This test confirms that in all versions of Python we support, __aexit__
is not cancelled when the coroutine containing the async with block is cancelled."""

logs: list[str] = []

class MyContextManager:
async def __aenter__(self):
logs.append("Entering")
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
logs.append("Starting exit")
try:
# Simulate some cleanup work
await asyncio.sleep(2)
logs.append("Cleanup completed")
except asyncio.CancelledError:
logs.append("Cleanup was cancelled!")
raise
logs.append("Exit finished")

async def main():
try:
async with MyContextManager():
logs.append("In context")
await asyncio.sleep(1)
logs.append("This won't print if cancelled")
except asyncio.CancelledError:
logs.append("Context was cancelled")
raise

# create task
t = asyncio.create_task(main())
# cancel after 0.2 seconds
await asyncio.sleep(0.2)
t.cancel()
# check logs before cancellation is handled
assert logs == [
"Entering",
"In context",
], "Cancelled before cleanup started"
# wait for task to finish
try:
await t
except asyncio.CancelledError:
# check logs after cancellation is handled
assert logs == [
"Entering",
"In context",
"Starting exit",
"Cleanup completed",
"Exit finished",
"Context was cancelled",
], "Cleanup started and finished after cancellation"
else:
assert False, "Task should be cancelled"


async def test_checkpoint_put_after_cancellation() -> None:
logs: list[str] = []

class LongPutCheckpointer(MemorySaver):
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
logs.append("checkpoint.aput.start")
try:
await asyncio.sleep(1)
return await super().aput(config, checkpoint, metadata, new_versions)
finally:
logs.append("checkpoint.aput.end")

inner_task_cancelled = False

async def awhile(input: Any) -> None:
logs.append("awhile.start")
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
finally:
logs.append("awhile.end")

builder = Graph()
builder.add_node("agent", awhile)
builder.set_entry_point("agent")
builder.set_finish_point("agent")

graph = builder.compile(checkpointer=LongPutCheckpointer())
thread1 = {"configurable": {"thread_id": "1"}}

# start the task
t = asyncio.create_task(graph.ainvoke(1, thread1))
# cancel after 0.2 seconds
await asyncio.sleep(0.2)
t.cancel()
# check logs before cancellation is handled
assert sorted(logs) == [
"awhile.start",
"checkpoint.aput.start",
], "Cancelled before checkpoint put started"
# wait for task to finish
try:
await t
except asyncio.CancelledError:
# check logs after cancellation is handled
assert sorted(logs) == [
"awhile.end",
"awhile.start",
"checkpoint.aput.end",
"checkpoint.aput.start",
], "Checkpoint put is not cancelled"
else:
assert False, "Task should be cancelled"


async def test_checkpoint_put_after_cancellation_stream_anext() -> None:
logs: list[str] = []

class LongPutCheckpointer(MemorySaver):
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
logs.append("checkpoint.aput.start")
try:
await asyncio.sleep(1)
return await super().aput(config, checkpoint, metadata, new_versions)
finally:
logs.append("checkpoint.aput.end")

inner_task_cancelled = False

async def awhile(input: Any) -> None:
logs.append("awhile.start")
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
finally:
logs.append("awhile.end")

builder = Graph()
builder.add_node("agent", awhile)
builder.set_entry_point("agent")
builder.set_finish_point("agent")

graph = builder.compile(checkpointer=LongPutCheckpointer())
thread1 = {"configurable": {"thread_id": "1"}}

# start the task
s = graph.astream(1, thread1)
t = asyncio.create_task(s.__anext__())
# cancel after 0.2 seconds
await asyncio.sleep(0.2)
t.cancel()
# check logs before cancellation is handled
assert sorted(logs) == [
"awhile.start",
"checkpoint.aput.start",
], "Cancelled before checkpoint put started"
# wait for task to finish
try:
await t
except asyncio.CancelledError:
# check logs after cancellation is handled
assert sorted(logs) == [
"awhile.end",
"awhile.start",
"checkpoint.aput.end",
"checkpoint.aput.start",
], "Checkpoint put is not cancelled"
else:
assert False, "Task should be cancelled"


async def test_checkpoint_put_after_cancellation_stream_events_anext() -> None:
logs: list[str] = []

class LongPutCheckpointer(MemorySaver):
async def aput(
self,
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata,
new_versions: ChannelVersions,
) -> RunnableConfig:
logs.append("checkpoint.aput.start")
try:
await asyncio.sleep(1)
return await super().aput(config, checkpoint, metadata, new_versions)
finally:
logs.append("checkpoint.aput.end")

inner_task_cancelled = False

async def awhile(input: Any) -> None:
logs.append("awhile.start")
try:
await asyncio.sleep(1)
except asyncio.CancelledError:
nonlocal inner_task_cancelled
inner_task_cancelled = True
raise
finally:
logs.append("awhile.end")

builder = Graph()
builder.add_node("agent", awhile)
builder.set_entry_point("agent")
builder.set_finish_point("agent")

graph = builder.compile(checkpointer=LongPutCheckpointer())
thread1 = {"configurable": {"thread_id": "1"}}

# start the task
s = graph.astream_events(1, thread1, version="v2", include_names=["LangGraph"])
# skip first event (happens right away)
await s.__anext__()
# start the task for 2nd event
t = asyncio.create_task(s.__anext__())
# cancel after 0.2 seconds
await asyncio.sleep(0.2)
t.cancel()
# check logs before cancellation is handled
assert logs == [
"checkpoint.aput.start",
"awhile.start",
], "Cancelled before checkpoint put started"
# wait for task to finish
try:
await t
except asyncio.CancelledError:
# check logs after cancellation is handled
assert logs == [
"checkpoint.aput.start",
"awhile.start",
"awhile.end",
"checkpoint.aput.end",
], "Checkpoint put is not cancelled"
else:
assert False, "Task should be cancelled"


async def test_node_cancellation_on_external_cancel() -> None:
inner_task_cancelled = False

Expand Down

0 comments on commit 3aaa3e3

Please sign in to comment.