Skip to content

Commit

Permalink
Update PG Implementation (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 1, 2024
1 parent db6f8a5 commit 5c3ac5d
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 34 deletions.
4 changes: 2 additions & 2 deletions libs/checkpoint-postgres/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ test:
EXIT_CODE=$$?; \
make stop-postgres; \
exit $$EXIT_CODE

TEST ?= .
test_watch:
make start-postgres; \
poetry run ptw .; \
poetry run ptw $(TEST); \
EXIT_CODE=$$?; \
make stop-postgres; \
exit $$EXIT_CODE
Expand Down
75 changes: 48 additions & 27 deletions libs/checkpoint-postgres/langgraph/store/postgres/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,9 @@

MIGRATIONS = [
"""
CREATE EXTENSION IF NOT EXISTS ltree;
""",
"""
CREATE TABLE IF NOT EXISTS store (
-- 'prefix' represents the doc's 'namespace'
prefix ltree NOT NULL,
prefix text NOT NULL,
key text NOT NULL,
value jsonb NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
Expand All @@ -54,8 +51,8 @@
);
""",
"""
-- For faster listing of namespaces & lookups by namespace with prefix/suffix matching
CREATE INDEX IF NOT EXISTS store_prefix_idx ON store USING gist (prefix);
-- For faster lookups by prefix
CREATE INDEX IF NOT EXISTS store_prefix_idx ON store USING btree (prefix text_pattern_ops);
""",
]

Expand Down Expand Up @@ -93,7 +90,7 @@ def _get_batch_GET_ops_queries(
FROM store
WHERE prefix = %s AND key IN ({keys_to_query})
"""
params = (_namespace_to_ltree(namespace), *keys)
params = (_namespace_to_text(namespace), *keys)
results.append((query, params, namespace, items))
return results

Expand All @@ -120,7 +117,7 @@ def _get_batch_PUT_queries(
query = (
f"DELETE FROM store WHERE prefix = %s AND key IN ({placeholders})"
)
params = (_namespace_to_ltree(namespace), *keys)
params = (_namespace_to_text(namespace), *keys)
queries.append((query, params))
if inserts:
values = []
Expand All @@ -129,7 +126,7 @@ def _get_batch_PUT_queries(
values.append("(%s, %s, %s, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)")
insertion_params.extend(
[
_namespace_to_ltree(op.namespace),
_namespace_to_text(op.namespace),
op.key,
Jsonb(op.value),
]
Expand All @@ -152,11 +149,11 @@ def _get_batch_search_queries(
queries: list[tuple[str, Sequence]] = []
for _, op in search_ops:
query = """
SELECT prefix, key, value, created_at, updated_at, prefix
SELECT prefix, key, value, created_at, updated_at
FROM store
WHERE prefix <@ %s
WHERE prefix LIKE %s
"""
params: list = [_namespace_to_ltree(op.namespace_prefix)]
params: list = [f"{_namespace_to_text(op.namespace_prefix)}%"]

if op.filter:
filter_conditions = []
Expand All @@ -181,34 +178,52 @@ def _get_batch_list_namespaces_queries(
) -> list[tuple[str, Sequence]]:
queries: list[tuple[str, Sequence]] = []
for _, op in list_ops:
query = "SELECT DISTINCT subltree(prefix, 0, LEAST(nlevel(prefix), %s)) AS truncated_prefix FROM store"
# https://www.postgresql.org/docs/current/ltree.html
# The length of a label path cannot exceed 65535 labels.
params: list[Any] = [op.max_depth if op.max_depth is not None else 65536]
query = """
SELECT DISTINCT ON (truncated_prefix) truncated_prefix, prefix
FROM (
SELECT
prefix,
CASE
WHEN %s::integer IS NOT NULL THEN
(SELECT STRING_AGG(part, '.' ORDER BY idx)
FROM (
SELECT part, ROW_NUMBER() OVER () AS idx
FROM UNNEST(REGEXP_SPLIT_TO_ARRAY(prefix, '\.')) AS part
LIMIT %s::integer
) subquery
)
ELSE prefix
END AS truncated_prefix
FROM store
"""
params: list[Any] = [op.max_depth, op.max_depth]

conditions = []
if op.match_conditions:
for condition in op.match_conditions:
if condition.match_type == "prefix":
conditions.append("prefix ~ %s::lquery")
lquery_pattern = f"{_namespace_to_ltree(condition.path)}.*"
params.append(lquery_pattern)
conditions.append("prefix LIKE %s")
params.append(
f"{_namespace_to_text(condition.path, handle_wildcards=True)}%"
)
elif condition.match_type == "suffix":
conditions.append("prefix ~ %s::lquery")
lquery_pattern = f"*.{_namespace_to_ltree(condition.path)}"
params.append(lquery_pattern)
conditions.append("prefix LIKE %s")
params.append(
f"%{_namespace_to_text(condition.path, handle_wildcards=True)}"
)
else:
logger.warning(
f"Unknown match_type in list_namespaces: {condition.match_type}"
)

if conditions:
query += " WHERE " + " AND ".join(conditions)
query += ") AS subquery "

query += " ORDER BY truncated_prefix LIMIT %s OFFSET %s"
params.extend([op.limit, op.offset])

queries.append((query, params))

return queries


Expand Down Expand Up @@ -386,13 +401,17 @@ def setup(self) -> None:
class Row(TypedDict):
key: str
value: Any
prefix: bytes
prefix: str
created_at: datetime
updated_at: datetime


def _namespace_to_ltree(namespace: tuple[str, ...]) -> str:
"""Convert namespace tuple to ltree-compatible string."""
def _namespace_to_text(
namespace: tuple[str, ...], handle_wildcards: bool = False
) -> str:
"""Convert namespace tuple to text string."""
if handle_wildcards:
namespace = tuple("%" if val == "*" else val for val in namespace)
return ".".join(namespace)


Expand Down Expand Up @@ -435,7 +454,9 @@ def _json_loads(content: Union[bytes, orjson.Fragment]) -> Any:
return orjson.loads(cast(bytes, content))


def _decode_ns_bytes(namespace: Union[str, bytes]) -> tuple[str, ...]:
def _decode_ns_bytes(namespace: Union[str, bytes, list]) -> tuple[str, ...]:
if isinstance(namespace, list):
return tuple(namespace)
if isinstance(namespace, bytes):
namespace = namespace.decode()[1:]
return tuple(namespace.split("."))
5 changes: 2 additions & 3 deletions libs/checkpoint-postgres/tests/test_async_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def cursor_side_effect(binary: bool = False) -> Any:

async def execute_side_effect(query: str, *params: Any) -> None:
# My super sophisticated database.
if "WHERE prefix <@" in query:
if "SELECT prefix, key," in query:
cursor.fetchall = mock_search_cursor.fetchall
elif "SELECT DISTINCT subltree" in query:
elif "SELECT DISTINCT ON (truncated_prefix)" in query:
cursor.fetchall = mock_list_namespaces_cursor.fetchall
elif "WHERE prefix = %s AND key" in query:
cursor.fetchall = mock_get_cursor.fetchall
Expand Down Expand Up @@ -390,7 +390,6 @@ async def test_list_namespaces(self) -> None:

max_depth_result = await store.alist_namespaces(max_depth=3)
assert all([len(ns) <= 3 for ns in max_depth_result])

max_depth_result = await store.alist_namespaces(
max_depth=4, prefix=[test_pref, "*", "documents"]
)
Expand Down
4 changes: 2 additions & 2 deletions libs/checkpoint-postgres/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def cursor_side_effect(binary: bool = False) -> Any:

def execute_side_effect(query: str, *params: Any) -> None:
# My super sophisticated database.
if "WHERE prefix <@" in query:
if "SELECT prefix, key, value" in query:
cursor.fetchall = mock_search_cursor.fetchall
elif "SELECT DISTINCT subltree" in query:
elif "SELECT DISTINCT ON (truncated_prefix)" in query:
cursor.fetchall = mock_list_namespaces_cursor.fetchall
elif "WHERE prefix = %s AND key" in query:
cursor.fetchall = mock_get_cursor.fetchall
Expand Down

0 comments on commit 5c3ac5d

Please sign in to comment.