Skip to content

Commit

Permalink
Add sync support for the AsyncPostgresStore (#2673)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Dec 9, 2024
1 parent b37c9d8 commit 3f1bdb9
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 6 deletions.
3 changes: 0 additions & 3 deletions libs/checkpoint-postgres/langgraph/store/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,6 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:

return results

def batch(self, ops: Iterable[Op]) -> list[Result]:
return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result()

@classmethod
@asynccontextmanager
async def from_conn_string(
Expand Down
132 changes: 131 additions & 1 deletion libs/checkpoint-postgres/tests/test_async_store.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# type: ignore
import asyncio
import itertools
import sys
import uuid
from collections.abc import AsyncIterator
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from typing import Any, Optional

import pytest
from langchain_core.embeddings import Embeddings
from psycopg import AsyncConnection

from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp
from langgraph.store.base import (
GetOp,
Item,
ListNamespacesOp,
PutOp,
SearchOp,
)
from langgraph.store.postgres import AsyncPostgresStore
from tests.conftest import (
DEFAULT_URI,
Expand Down Expand Up @@ -63,6 +71,128 @@ async def store(request) -> AsyncIterator[AsyncPostgresStore]:
await conn.execute(f"DROP DATABASE {database}")


async def test_no_running_loop(store: AsyncPostgresStore) -> None:
with pytest.raises(asyncio.InvalidStateError):
store.put(("foo", "bar"), "baz", {"val": "baz"})
with pytest.raises(asyncio.InvalidStateError):
store.get(("foo", "bar"), "baz")
with pytest.raises(asyncio.InvalidStateError):
store.delete(("foo", "bar"), "baz")
with pytest.raises(asyncio.InvalidStateError):
store.search(("foo", "bar"))
with pytest.raises(asyncio.InvalidStateError):
store.list_namespaces(prefix=("foo",))
with pytest.raises(asyncio.InvalidStateError):
store.batch([PutOp(namespace=("foo", "bar"), key="baz", value={"val": "baz"})])
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(store.put, ("foo", "bar"), "baz", {"val": "baz"})
result = await asyncio.wrap_future(future)
assert result is None
future = executor.submit(store.get, ("foo", "bar"), "baz")
result = await asyncio.wrap_future(future)
assert result.value == {"val": "baz"}
result = await asyncio.wrap_future(
executor.submit(store.list_namespaces, prefix=("foo",))
)


async def test_large_batches(request: Any, store: AsyncPostgresStore) -> None:
N = 100 # less important that we are performant here
M = 10

with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for m in range(M):
for i in range(N):
futures += [
executor.submit(
store.put,
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
value={"foo": "bar" + str(i)},
),
executor.submit(
store.get,
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
),
executor.submit(
store.list_namespaces,
prefix=None,
max_depth=m + 1,
),
executor.submit(
store.search,
("test",),
),
executor.submit(
store.put,
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
value={"foo": "bar" + str(i)},
),
executor.submit(
store.put,
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
None,
),
]

results = await asyncio.gather(
*(asyncio.wrap_future(future) for future in futures)
)
assert len(results) == M * N * 6


async def test_large_batches_async(store: AsyncPostgresStore) -> None:
N = 1000
M = 10
coros = []
for m in range(M):
for i in range(N):
coros.append(
store.aput(
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
value={"foo": "bar" + str(i)},
)
)
coros.append(
store.aget(
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
)
)
coros.append(
store.alist_namespaces(
prefix=None,
max_depth=m + 1,
)
)
coros.append(
store.asearch(
("test",),
)
)
coros.append(
store.aput(
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
value={"foo": "bar" + str(i)},
)
)
coros.append(
store.adelete(
("test", "foo", "bar", "baz", str(m % 2)),
f"key{i}",
)
)

results = await asyncio.gather(*coros)
assert len(results) == M * N * 6


async def test_abatch_order(store: AsyncPostgresStore) -> None:
# Setup test data
await store.aput(("test", "foo"), "key1", {"data": "value1"})
Expand Down
111 changes: 109 additions & 2 deletions libs/checkpoint/langgraph/store/base/batch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import functools
import weakref
from typing import Any, Literal, Optional, Union
from typing import Any, Callable, Iterable, Literal, Optional, TypeVar, Union

from langgraph.store.base import (
BaseStore,
Expand All @@ -11,18 +12,47 @@
NamespacePath,
Op,
PutOp,
Result,
SearchItem,
SearchOp,
_validate_namespace,
)

F = TypeVar("F", bound=Callable)


def _check_loop(func: F) -> F:
@functools.wraps(func)
def wrapper(store: "AsyncBatchedBaseStore", *args: Any, **kwargs: Any) -> Any:
method_name: str = func.__name__
try:
current_loop = asyncio.get_running_loop()
if current_loop is store._loop:
replacement_str = (
f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)"
if method_name
else "For example, replace `store.get(...)` with `await store.aget(...)`"
)
raise asyncio.InvalidStateError(
f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. "
"This can lead to deadlocks or performance issues. "
"Please use the asynchronous interface for main thread operations. "
f"{replacement_str} "
)
except RuntimeError:
pass
return func(store, *args, **kwargs)

return wrapper


class AsyncBatchedBaseStore(BaseStore):
"""Efficiently batch operations in a background task."""

__slots__ = ("_loop", "_aqueue", "_task")

def __init__(self) -> None:
super().__init__()
self._loop = asyncio.get_running_loop()
self._aqueue: dict[asyncio.Future, Op] = {}
self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self)))
Expand Down Expand Up @@ -99,6 +129,82 @@ async def alist_namespaces(
self._aqueue[fut] = op
return await fut

@_check_loop
def batch(self, ops: Iterable[Op]) -> list[Result]:
return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result()

@_check_loop
def get(
self,
namespace: tuple[str, ...],
key: str,
) -> Optional[Item]:
return asyncio.run_coroutine_threadsafe(
self.aget(namespace, key=key), self._loop
).result()

@_check_loop
def search(
self,
namespace_prefix: tuple[str, ...],
/,
*,
query: Optional[str] = None,
filter: Optional[dict[str, Any]] = None,
limit: int = 10,
offset: int = 0,
) -> list[SearchItem]:
return asyncio.run_coroutine_threadsafe(
self.asearch(
namespace_prefix, query=query, filter=filter, limit=limit, offset=offset
),
self._loop,
).result()

@_check_loop
def put(
self,
namespace: tuple[str, ...],
key: str,
value: dict[str, Any],
index: Optional[Union[Literal[False], list[str]]] = None,
) -> None:
_validate_namespace(namespace)
asyncio.run_coroutine_threadsafe(
self.aput(namespace, key=key, value=value, index=index), self._loop
).result()

@_check_loop
def delete(
self,
namespace: tuple[str, ...],
key: str,
) -> None:
asyncio.run_coroutine_threadsafe(
self.adelete(namespace, key=key), self._loop
).result()

@_check_loop
def list_namespaces(
self,
*,
prefix: Optional[NamespacePath] = None,
suffix: Optional[NamespacePath] = None,
max_depth: Optional[int] = None,
limit: int = 100,
offset: int = 0,
) -> list[tuple[str, ...]]:
return asyncio.run_coroutine_threadsafe(
self.alist_namespaces(
prefix=prefix,
suffix=suffix,
max_depth=max_depth,
limit=limit,
offset=offset,
),
self._loop,
).result()


def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]:
"""Dedupe operations while preserving order for results.
Expand Down Expand Up @@ -144,7 +250,8 @@ def _dedupe_ops(values: list[Op]) -> tuple[Optional[list[int]], list[Op]]:


async def _run(
aqueue: dict[asyncio.Future, Op], store: weakref.ReferenceType[BaseStore]
aqueue: dict[asyncio.Future, Op],
store: weakref.ReferenceType[BaseStore],
) -> None:
while True:
await asyncio.sleep(0)
Expand Down
10 changes: 10 additions & 0 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12621,9 +12621,19 @@ async def __call__(
)
return {"count": 1}

def other_node(inputs: State, config: RunnableConfig, store: BaseStore):
assert isinstance(store, BaseStore)
store.put(("not", "interesting"), "key", {"val": "val"})
item = store.get(("not", "interesting"), "key")
assert item is not None
assert item.value == {"val": "val"}
return {"count": 0}

builder = StateGraph(State)
builder.add_node("node", Node())
builder.add_node("other_node", other_node)
builder.add_edge("__start__", "node")
builder.add_edge("node", "other_node")

N = 500
M = 1
Expand Down

0 comments on commit 3f1bdb9

Please sign in to comment.