From 3f1bdb9ebf5effb39b7177cb536e5d0e4b9106ce Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 9 Dec 2024 07:12:52 -0800 Subject: [PATCH] Add sync support for the AsyncPostgresStore (#2673) --- .../langgraph/store/postgres/aio.py | 3 - .../tests/test_async_store.py | 132 +++++++++++++++++- libs/checkpoint/langgraph/store/base/batch.py | 111 ++++++++++++++- libs/langgraph/tests/test_pregel_async.py | 10 ++ 4 files changed, 250 insertions(+), 6 deletions(-) diff --git a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py index e62d360cc..2354b3a8f 100644 --- a/libs/checkpoint-postgres/langgraph/store/postgres/aio.py +++ b/libs/checkpoint-postgres/langgraph/store/postgres/aio.py @@ -155,9 +155,6 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: return results - def batch(self, ops: Iterable[Op]) -> list[Result]: - return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result() - @classmethod @asynccontextmanager async def from_conn_string( diff --git a/libs/checkpoint-postgres/tests/test_async_store.py b/libs/checkpoint-postgres/tests/test_async_store.py index eda0e2820..068ec1502 100644 --- a/libs/checkpoint-postgres/tests/test_async_store.py +++ b/libs/checkpoint-postgres/tests/test_async_store.py @@ -1,8 +1,10 @@ # type: ignore +import asyncio import itertools import sys import uuid from collections.abc import AsyncIterator +from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from typing import Any, Optional @@ -10,7 +12,13 @@ from langchain_core.embeddings import Embeddings from psycopg import AsyncConnection -from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp +from langgraph.store.base import ( + GetOp, + Item, + ListNamespacesOp, + PutOp, + SearchOp, +) from langgraph.store.postgres import AsyncPostgresStore from tests.conftest import ( DEFAULT_URI, @@ -63,6 +71,128 @@ async def store(request) -> AsyncIterator[AsyncPostgresStore]: await conn.execute(f"DROP DATABASE {database}") +async def test_no_running_loop(store: AsyncPostgresStore) -> None: + with pytest.raises(asyncio.InvalidStateError): + store.put(("foo", "bar"), "baz", {"val": "baz"}) + with pytest.raises(asyncio.InvalidStateError): + store.get(("foo", "bar"), "baz") + with pytest.raises(asyncio.InvalidStateError): + store.delete(("foo", "bar"), "baz") + with pytest.raises(asyncio.InvalidStateError): + store.search(("foo", "bar")) + with pytest.raises(asyncio.InvalidStateError): + store.list_namespaces(prefix=("foo",)) + with pytest.raises(asyncio.InvalidStateError): + store.batch([PutOp(namespace=("foo", "bar"), key="baz", value={"val": "baz"})]) + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(store.put, ("foo", "bar"), "baz", {"val": "baz"}) + result = await asyncio.wrap_future(future) + assert result is None + future = executor.submit(store.get, ("foo", "bar"), "baz") + result = await asyncio.wrap_future(future) + assert result.value == {"val": "baz"} + result = await asyncio.wrap_future( + executor.submit(store.list_namespaces, prefix=("foo",)) + ) + + +async def test_large_batches(request: Any, store: AsyncPostgresStore) -> None: + N = 100 # less important that we are performant here + M = 10 + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for m in range(M): + for i in range(N): + futures += [ + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.get, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ), + executor.submit( + store.list_namespaces, + prefix=None, + max_depth=m + 1, + ), + executor.submit( + store.search, + ("test",), + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ), + executor.submit( + store.put, + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + None, + ), + ] + + results = await asyncio.gather( + *(asyncio.wrap_future(future) for future in futures) + ) + assert len(results) == M * N * 6 + + +async def test_large_batches_async(store: AsyncPostgresStore) -> None: + N = 1000 + M = 10 + coros = [] + for m in range(M): + for i in range(N): + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.aget( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + coros.append( + store.alist_namespaces( + prefix=None, + max_depth=m + 1, + ) + ) + coros.append( + store.asearch( + ("test",), + ) + ) + coros.append( + store.aput( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + value={"foo": "bar" + str(i)}, + ) + ) + coros.append( + store.adelete( + ("test", "foo", "bar", "baz", str(m % 2)), + f"key{i}", + ) + ) + + results = await asyncio.gather(*coros) + assert len(results) == M * N * 6 + + async def test_abatch_order(store: AsyncPostgresStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) diff --git a/libs/checkpoint/langgraph/store/base/batch.py b/libs/checkpoint/langgraph/store/base/batch.py index 33c502574..6cfc11419 100644 --- a/libs/checkpoint/langgraph/store/base/batch.py +++ b/libs/checkpoint/langgraph/store/base/batch.py @@ -1,6 +1,7 @@ import asyncio +import functools import weakref -from typing import Any, Literal, Optional, Union +from typing import Any, Callable, Iterable, Literal, Optional, TypeVar, Union from langgraph.store.base import ( BaseStore, @@ -11,11 +12,39 @@ NamespacePath, Op, PutOp, + Result, SearchItem, SearchOp, _validate_namespace, ) +F = TypeVar("F", bound=Callable) + + +def _check_loop(func: F) -> F: + @functools.wraps(func) + def wrapper(store: "AsyncBatchedBaseStore", *args: Any, **kwargs: Any) -> Any: + method_name: str = func.__name__ + try: + current_loop = asyncio.get_running_loop() + if current_loop is store._loop: + replacement_str = ( + f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)" + if method_name + else "For example, replace `store.get(...)` with `await store.aget(...)`" + ) + raise asyncio.InvalidStateError( + f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. " + "This can lead to deadlocks or performance issues. " + "Please use the asynchronous interface for main thread operations. " + f"{replacement_str} " + ) + except RuntimeError: + pass + return func(store, *args, **kwargs) + + return wrapper + class AsyncBatchedBaseStore(BaseStore): """Efficiently batch operations in a background task.""" @@ -23,6 +52,7 @@ class AsyncBatchedBaseStore(BaseStore): __slots__ = ("_loop", "_aqueue", "_task") def __init__(self) -> None: + super().__init__() self._loop = asyncio.get_running_loop() self._aqueue: dict[asyncio.Future, Op] = {} self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) @@ -99,6 +129,82 @@ async def alist_namespaces( self._aqueue[fut] = op return await fut + @_check_loop + def batch(self, ops: Iterable[Op]) -> list[Result]: + return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result() + + @_check_loop + def get( + self, + namespace: tuple[str, ...], + key: str, + ) -> Optional[Item]: + return asyncio.run_coroutine_threadsafe( + self.aget(namespace, key=key), self._loop + ).result() + + @_check_loop + def search( + self, + namespace_prefix: tuple[str, ...], + /, + *, + query: Optional[str] = None, + filter: Optional[dict[str, Any]] = None, + limit: int = 10, + offset: int = 0, + ) -> list[SearchItem]: + return asyncio.run_coroutine_threadsafe( + self.asearch( + namespace_prefix, query=query, filter=filter, limit=limit, offset=offset + ), + self._loop, + ).result() + + @_check_loop + def put( + self, + namespace: tuple[str, ...], + key: str, + value: dict[str, Any], + index: Optional[Union[Literal[False], list[str]]] = None, + ) -> None: + _validate_namespace(namespace) + asyncio.run_coroutine_threadsafe( + self.aput(namespace, key=key, value=value, index=index), self._loop + ).result() + + @_check_loop + def delete( + self, + namespace: tuple[str, ...], + key: str, + ) -> None: + asyncio.run_coroutine_threadsafe( + self.adelete(namespace, key=key), self._loop + ).result() + + @_check_loop + def list_namespaces( + self, + *, + prefix: Optional[NamespacePath] = None, + suffix: Optional[NamespacePath] = None, + max_depth: Optional[int] = None, + limit: int = 100, + offset: int = 0, + ) -> list[tuple[str, ...]]: + return asyncio.run_coroutine_threadsafe( + self.alist_namespaces( + prefix=prefix, + suffix=suffix, + max_depth=max_depth, + limit=limit, + offset=offset, + ), + self._loop, + ).result() + def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: """Dedupe operations while preserving order for results. @@ -144,7 +250,8 @@ def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]: async def _run( - aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore] + aqueue: dict[asyncio.Future, Op], + store: weakref.ReferenceType[BaseStore], ) -> None: while True: await asyncio.sleep(0) diff --git a/libs/langgraph/tests/test_pregel_async.py b/libs/langgraph/tests/test_pregel_async.py index f43e76d31..b3317e5ae 100644 --- a/libs/langgraph/tests/test_pregel_async.py +++ b/libs/langgraph/tests/test_pregel_async.py @@ -12621,9 +12621,19 @@ async def __call__( ) return {"count": 1} + def other_node(inputs: State, config: RunnableConfig, store: BaseStore): + assert isinstance(store, BaseStore) + store.put(("not", "interesting"), "key", {"val": "val"}) + item = store.get(("not", "interesting"), "key") + assert item is not None + assert item.value == {"val": "val"} + return {"count": 0} + builder = StateGraph(State) builder.add_node("node", Node()) + builder.add_node("other_node", other_node) builder.add_edge("__start__", "node") + builder.add_edge("node", "other_node") N = 500 M = 1