From 11e80210a2ca8a0691132d2315cd881dc010cf77 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 5 Dec 2024 15:11:09 -0800 Subject: [PATCH 1/3] lib: Performance improvements - don't create contextvars.Context/asyncio.Task in RunnableSeq (not needed as each step creates it if necessary) - don't run in-memory-saver methods in background threads (no point as they hold the gil) - avoid calling should_interrupt when no interrupts set --- .../langgraph/checkpoint/memory/__init__.py | 34 +++---------------- libs/langgraph/langgraph/pregel/loop.py | 12 ++++--- libs/langgraph/langgraph/utils/runnable.py | 16 +++------ 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index e30c082c7..b11f6e21c 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -1,4 +1,3 @@ -import asyncio import logging import os import pickle @@ -6,7 +5,6 @@ import shutil from collections import defaultdict from contextlib import AbstractAsyncContextManager, AbstractContextManager, ExitStack -from functools import partial from types import TracebackType from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple, Type @@ -395,9 +393,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: Returns: Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. """ - return await asyncio.get_running_loop().run_in_executor( - None, self.get_tuple, config - ) + return self.get_tuple(config) async def alist( self, @@ -418,24 +414,8 @@ async def alist( Yields: AsyncIterator[CheckpointTuple]: An asynchronous iterator of checkpoint tuples. """ - loop = asyncio.get_running_loop() - iter = await loop.run_in_executor( - None, - partial( - self.list, - before=before, - limit=limit, - filter=filter, - ), - config, - ) - while True: - # handling StopIteration exception inside coroutine won't work - # as expected, so using next() with default value to break the loop - if item := await loop.run_in_executor(None, next, iter, None): - yield item - else: - break + for item in self.list(config, filter=filter, before=before, limit=limit): + yield item async def aput( self, @@ -455,9 +435,7 @@ async def aput( Returns: RunnableConfig: The updated config containing the saved checkpoint's timestamp. """ - return await asyncio.get_running_loop().run_in_executor( - None, self.put, config, checkpoint, metadata, new_versions - ) + return self.put(config, checkpoint, metadata, new_versions) async def aput_writes( self, @@ -474,10 +452,8 @@ async def aput_writes( config (RunnableConfig): The config to associate with the writes. writes (List[Tuple[str, Any]]): The writes to save, each as a (channel, value) pair. task_id (str): Identifier for the task creating the writes. + return self.put_writes(config, writes, task_id) """ - return await asyncio.get_running_loop().run_in_executor( - None, self.put_writes, config, writes, task_id - ) def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index a8e945edd..cf0716e7c 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -311,7 +311,9 @@ def accept_push( ) -> Optional[PregelExecutableTask]: """Accept a PUSH from a task, potentially returning a new task to start.""" # don't start if we should interrupt *after* the original task - if should_interrupt(self.checkpoint, self.interrupt_after, [task]): + if self.interrupt_after and should_interrupt( + self.checkpoint, self.interrupt_after, [task] + ): self.to_interrupt.append(task) return if pushed := cast( @@ -333,7 +335,9 @@ def accept_push( ), ): # don't start if we should interrupt *before* the new task - if should_interrupt(self.checkpoint, self.interrupt_before, [pushed]): + if self.interrupt_before and should_interrupt( + self.checkpoint, self.interrupt_before, [pushed] + ): self.to_interrupt.append(pushed) return # produce debug output @@ -409,7 +413,7 @@ def tick( } ) # after execution, check if we should interrupt - if should_interrupt( + if self.interrupt_after and should_interrupt( self.checkpoint, self.interrupt_after, self.tasks.values() ): self.status = "interrupt_after" @@ -481,7 +485,7 @@ def tick( return self.tick(input_keys=input_keys) # before execution, check if we should interrupt - if should_interrupt( + if self.interrupt_before and should_interrupt( self.checkpoint, self.interrupt_before, self.tasks.values() ): self.status = "interrupt_before" diff --git a/libs/langgraph/langgraph/utils/runnable.py b/libs/langgraph/langgraph/utils/runnable.py index ccebba862..7cd6a85b9 100644 --- a/libs/langgraph/langgraph/utils/runnable.py +++ b/libs/langgraph/langgraph/utils/runnable.py @@ -404,12 +404,10 @@ def invoke( config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) - context = copy_context() - context.run(_set_config_context, config) if i == 0: - input = context.run(step.invoke, input, config, **kwargs) + input = step.invoke(input, config, **kwargs) else: - input = context.run(step.invoke, input, config) + input = step.invoke(input, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) @@ -443,16 +441,10 @@ async def ainvoke( config = patch_config( config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) - context = copy_context() - context.run(_set_config_context, config) if i == 0: - coro = step.ainvoke(input, config, **kwargs) - else: - coro = step.ainvoke(input, config) - if ASYNCIO_ACCEPTS_CONTEXT: - input = await asyncio.create_task(coro, context=context) + input = await step.ainvoke(input, config, **kwargs) else: - input = await asyncio.create_task(coro) + input = await step.ainvoke(input, config) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) From 7f8ec2c5905dece1d8f2b188c572d557f244ffe6 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 10 Dec 2024 13:46:45 -0800 Subject: [PATCH 2/3] Fix --- libs/checkpoint/langgraph/checkpoint/memory/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py index b11f6e21c..cb6b7b852 100644 --- a/libs/checkpoint/langgraph/checkpoint/memory/__init__.py +++ b/libs/checkpoint/langgraph/checkpoint/memory/__init__.py @@ -454,6 +454,7 @@ async def aput_writes( task_id (str): Identifier for the task creating the writes. return self.put_writes(config, writes, task_id) """ + return self.put_writes(config, writes, task_id) def get_next_version(self, current: Optional[str], channel: ChannelProtocol) -> str: if current is None: From 30f852e7b29452dc3e05864afc9c6d685162a519 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 10 Dec 2024 13:56:26 -0800 Subject: [PATCH 3/3] Fix --- libs/langgraph/tests/test_pregel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/langgraph/tests/test_pregel.py b/libs/langgraph/tests/test_pregel.py index 48c15e133..ad48108a9 100644 --- a/libs/langgraph/tests/test_pregel.py +++ b/libs/langgraph/tests/test_pregel.py @@ -12901,7 +12901,7 @@ def edit(state: JokeState): metadata={ "step": 1, "source": "loop", - "writes": {"edit": None}, + "writes": None, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"), @@ -12946,7 +12946,7 @@ def edit(state: JokeState): metadata={ "step": 1, "source": "loop", - "writes": {"edit": None}, + "writes": None, "parents": {"": AnyStr()}, "thread_id": "1", "checkpoint_ns": AnyStr("generate_joke:"),