Skip to content

Commit

Permalink
Use AsyncBatch for postgres store (#2020)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 8, 2024
1 parent 7c2a89d commit 254b12a
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 19 deletions.
8 changes: 6 additions & 2 deletions libs/checkpoint-postgres/langgraph/store/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from psycopg.rows import dict_row

from langgraph.store.base import GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp
from langgraph.store.base.batch import AsyncBatchedBaseStore
from langgraph.store.postgres.base import (
BasePostgresStore,
Row,
Expand All @@ -29,7 +30,9 @@
logger = logging.getLogger(__name__)


class AsyncPostgresStore(BasePostgresStore[AsyncConnection]):
class AsyncPostgresStore(AsyncBatchedBaseStore, BasePostgresStore[AsyncConnection]):
__slots__ = ("_deserializer",)

def __init__(
self,
conn: AsyncConnection[Any],
Expand All @@ -38,7 +41,8 @@ def __init__(
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
) -> None:
super().__init__(deserializer=deserializer)
super().__init__()
self._deserializer = deserializer
self.conn = conn
self.conn = conn
self.loop = asyncio.get_running_loop()
Expand Down
25 changes: 10 additions & 15 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,10 @@
C = TypeVar("C", bound=BaseConnection)


class BasePostgresStore(BaseStore, Generic[C]):
class BasePostgresStore(Generic[C]):
MIGRATIONS = MIGRATIONS
conn: C
__slots__ = ("_deserializer",)

def __init__(
self,
*,
deserializer: Optional[
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
) -> None:
super().__init__()
self._deserializer = deserializer
_deserializer: Optional[Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]]

def _get_batch_GET_ops_queries(
self,
Expand Down Expand Up @@ -166,7 +156,9 @@ def _get_batch_search_queries(
params.extend([key, json.dumps(value)])
query += " AND " + " AND ".join(filter_conditions)

query += " LIMIT %s OFFSET %s"
# Note: we will need to not do this if sim/keyword search
# is used
query += " ORDER BY updated_at DESC LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])

queries.append((query, params))
Expand Down Expand Up @@ -227,7 +219,9 @@ def _get_batch_list_namespaces_queries(
return queries


class PostgresStore(BasePostgresStore[Connection]):
class PostgresStore(BaseStore, BasePostgresStore[Connection]):
__slots__ = ("_deserializer",)

def __init__(
self,
conn: Connection[Any],
Expand All @@ -236,7 +230,8 @@ def __init__(
Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
] = None,
) -> None:
super().__init__(deserializer=deserializer)
super().__init__()
self._deserializer = deserializer
self.conn = conn

def batch(self, ops: Iterable[Op]) -> list[Result]:
Expand Down
2 changes: 1 addition & 1 deletion libs/checkpoint-postgres/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langgraph-checkpoint-postgres"
version = "2.0.0"
version = "2.0.1"
description = "Library with a Postgres implementation of LangGraph checkpoint saver."
authors = []
license = "MIT"
Expand Down
6 changes: 5 additions & 1 deletion libs/checkpoint/langgraph/store/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def _validate_namespace(namespace: tuple[str, ...]) -> None:


class BaseStore(ABC):
"""Abstract base class for key-value stores."""
"""Abstract base class for persistent key-value stores.
Stores enable persistence and memory that can be shared across threads,
scoped to user IDs, assistant IDs, or other arbitrary namespaces.
"""

__slots__ = ("__weakref__",)

Expand Down

0 comments on commit 254b12a

Please sign in to comment.