Skip to content

Commit

Permalink
Validate in async batched store (#2017)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 6, 2024
1 parent b35fe58 commit 05dbc1d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
11 changes: 10 additions & 1 deletion libs/checkpoint/langgraph/store/base/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
import weakref
from typing import Any, Optional

from langgraph.store.base import BaseStore, GetOp, Item, Op, PutOp, SearchOp
from langgraph.store.base import (
BaseStore,
GetOp,
Item,
Op,
PutOp,
SearchOp,
_validate_namespace,
)


class AsyncBatchedBaseStore(BaseStore):
Expand Down Expand Up @@ -46,6 +54,7 @@ async def aput(
key: str,
value: dict[str, Any],
) -> None:
_validate_namespace(namespace)
fut = self._loop.create_future()
self._aqueue[fut] = PutOp(namespace, key, value)
return await fut
Expand Down
38 changes: 38 additions & 0 deletions libs/checkpoint/tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,41 @@ async def test_cannot_put_empty_namespace() -> None:
assert store.search(("langgraph", "foo"))[0].value == doc
store.delete(("langgraph", "foo"), "bar")
assert store.get(("langgraph", "foo"), "bar") is None

class MockAsyncBatchedStore(AsyncBatchedBaseStore):
def __init__(self):
super().__init__()
self._store = InMemoryStore()

def batch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)

async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return self._store.batch(ops)

async_store = MockAsyncBatchedStore()
doc = {"foo": "bar"}

with pytest.raises(InvalidNamespaceError):
await async_store.aput((), "foo", doc)

with pytest.raises(InvalidNamespaceError):
await async_store.aput(("the", "thing.about"), "foo", doc)

with pytest.raises(InvalidNamespaceError):
await async_store.aput(("some", "fun", ""), "foo", doc)

with pytest.raises(InvalidNamespaceError):
await async_store.aput(("langgraph", "foo"), "bar", doc)

await async_store.aput(("foo", "langgraph", "foo"), "bar", doc)
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")).value == doc
assert (await async_store.asearch(("foo", "langgraph", "foo")))[0].value == doc
await async_store.adelete(("foo", "langgraph", "foo"), "bar")
assert (await async_store.aget(("foo", "langgraph", "foo"), "bar")) is None

await async_store.abatch([PutOp(("valid", "namespace"), "key", doc)])
assert (await async_store.aget(("valid", "namespace"), "key")).value == doc
assert (await async_store.asearch(("valid", "namespace")))[0].value == doc
await async_store.adelete(("valid", "namespace"), "key")
assert (await async_store.aget(("valid", "namespace"), "key")) is None

0 comments on commit 05dbc1d

Please sign in to comment.