Skip to content

Commit

Permalink
Merge pull request #1766 from langchain-ai/nc/18sep/mypy-checkpoint
Browse files (browse the repository at this point in the history)
ci: Enable mypy checks for checkpoint lib
  • Loading branch information
nfcampos authored Sep 19, 2024
2 parents d4ba315 + c793a9e commit bc95a79
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 37 deletions.
3 changes: 2 additions & 1 deletion libs/checkpoint/Makefile
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,7 +27,8 @@ lint lint_diff lint_package lint_tests:
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
Expand Down
16 changes: 11 additions & 5 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@
Any,
AsyncIterator,
Dict,
Generic,
Iterator,
List,
Literal,
Expand All @@ -11,11 +12,11 @@
Optional,
Tuple,
TypedDict,
TypeVar,
Union,
)

from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from typing_extensions import TypeVar

from langgraph.checkpoint.base.id import uuid6
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
Expand All @@ -27,7 +28,7 @@
SendProtocol,
)

V = TypeVar("V", int, float, str)
V = TypeVar("V", int, float, str, default=int)
PendingWrite = Tuple[str, str, Any]


Expand Down Expand Up @@ -135,7 +136,7 @@ def create_checkpoint(
if channels is None:
values = checkpoint["channel_values"]
else:
values: dict[str, Any] = {}
values = {}
for k, v in channels.items():
if k not in checkpoint["channel_versions"]:
continue
Expand Down Expand Up @@ -192,7 +193,7 @@ class CheckpointTuple(NamedTuple):
)


class BaseCheckpointSaver:
class BaseCheckpointSaver(Generic[V]):
"""Base class for creating a graph checkpointer.
Checkpointers allow LangGraph agents to persist their state
Expand Down Expand Up @@ -420,7 +421,12 @@ def get_next_version(self, current: Optional[V], channel: ChannelProtocol) -> V:
Returns:
V: The next version identifier, which must be increasing.
"""
return current + 1 if current is not None else 1
if isinstance(current, str):
raise NotImplementedError
elif current is None:
return 1
else:
return current + 1


class EmptyChannelError(Exception):
Expand Down
17 changes: 11 additions & 6 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,7 @@


class MemorySaver(
BaseCheckpointSaver, AbstractContextManager, AbstractAsyncContextManager
BaseCheckpointSaver[str], AbstractContextManager, AbstractAsyncContextManager
):
"""An in-memory checkpoint saver.
Expand Down Expand Up @@ -54,9 +54,14 @@ class MemorySaver(
"""

# thread ID -> checkpoint NS -> checkpoint ID -> checkpoint mapping
storage: defaultdict[str, dict[str, dict[str, tuple[bytes, bytes, Optional[str]]]]]
storage: defaultdict[
str,
dict[
str, dict[str, tuple[tuple[str, bytes], tuple[str, bytes], Optional[str]]]
],
]
writes: defaultdict[
tuple[str, str, str], dict[tuple[str, int], tuple[str, str, bytes]]
tuple[str, str, str], dict[tuple[str, int], tuple[str, str, tuple[str, bytes]]]
]

def __init__(
Expand Down Expand Up @@ -316,7 +321,7 @@ def put(
RunnableConfig: The updated config containing the saved checkpoint's timestamp.
"""
c = checkpoint.copy()
c.pop("pending_sends")
c.pop("pending_sends") # type: ignore[misc]
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"]["checkpoint_ns"]
self.storage[thread_id][checkpoint_ns].update(
Expand All @@ -341,7 +346,7 @@ def put_writes(
config: RunnableConfig,
writes: List[Tuple[str, Any]],
task_id: str,
) -> RunnableConfig:
) -> None:
"""Save a list of writes to the in-memory storage.
This method saves a list of writes to the in-memory storage. The writes are associated
Expand Down Expand Up @@ -444,7 +449,7 @@ async def aput_writes(
config: RunnableConfig,
writes: List[Tuple[str, Any]],
task_id: str,
) -> RunnableConfig:
) -> None:
"""Asynchronous version of put_writes.
This method is an asynchronous wrapper around put_writes that runs the synchronous
Expand Down
6 changes: 6 additions & 0 deletions libs/checkpoint/langgraph/checkpoint/serde/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,6 +25,12 @@ class SerializerCompat(SerializerProtocol):
def __init__(self, serde: SerializerProtocol) -> None:
self.serde = serde

def dumps(self, obj: Any) -> bytes:
return self.serde.dumps(obj)

def loads(self, data: bytes) -> Any:
return self.serde.loads(data)

def dumps_typed(self, obj: Any) -> tuple[str, bytes]:
return type(obj).__name__, self.serde.dumps(obj)

Expand Down
27 changes: 15 additions & 12 deletions libs/checkpoint/langgraph/checkpoint/serde/jsonplus.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,10 +16,10 @@
IPv6Interface,
IPv6Network,
)
from typing import Any, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union, cast
from uuid import UUID

import msgpack
import msgpack # type: ignore[import-untyped]
from langchain_core.load.load import Reviver
from langchain_core.load.serializable import Serializable
from zoneinfo import ZoneInfo
Expand All @@ -33,12 +33,12 @@
class JsonPlusSerializer(SerializerProtocol):
def _encode_constructor_args(
self,
constructor: type[Any],
constructor: Union[Callable, type[Any]],
*,
method: Optional[str] = None,
method: Union[None, str, Sequence[Union[None, str]]] = None,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
):
) -> dict[str, Any]:
out = {
"lc": 2,
"type": "constructor",
Expand All @@ -52,9 +52,9 @@ def _encode_constructor_args(
out["kwargs"] = kwargs
return out

def _default(self, obj):
def _default(self, obj: Any) -> Union[str, dict[str, Any]]:
if isinstance(obj, Serializable):
return obj.to_json()
return cast(dict[str, Any], obj.to_json())
elif hasattr(obj, "model_dump") and callable(obj.model_dump):
return self._encode_constructor_args(
obj.__class__, method=(None, "model_construct"), kwargs=obj.model_dump()
Expand Down Expand Up @@ -87,7 +87,10 @@ def _default(self, obj):
datetime, method="fromisoformat", args=(obj.isoformat(),)
)
elif isinstance(obj, timezone):
return self._encode_constructor_args(timezone, args=obj.__getinitargs__())
return self._encode_constructor_args(
timezone,
args=obj.__getinitargs__(), # type: ignore[attr-defined]
)
elif isinstance(obj, ZoneInfo):
return self._encode_constructor_args(ZoneInfo, args=(obj.key,))
elif isinstance(obj, timedelta):
Expand Down Expand Up @@ -217,7 +220,7 @@ def loads_typed(self, data: tuple[str, bytes]) -> Any:
EXT_PYDANTIC_V2 = 5


def _msgpack_default(obj):
def _msgpack_default(obj: Any) -> Union[str, msgpack.ExtType]:
if hasattr(obj, "model_dump") and callable(obj.model_dump): # pydantic v2
return msgpack.ExtType(
EXT_PYDANTIC_V2,
Expand Down Expand Up @@ -360,7 +363,7 @@ def _msgpack_default(obj):
(
obj.__class__.__module__,
obj.__class__.__name__,
obj.__getinitargs__(),
obj.__getinitargs__(), # type: ignore[attr-defined]
),
),
)
Expand Down Expand Up @@ -406,7 +409,7 @@ def _msgpack_default(obj):
raise TypeError(f"Object of type {obj.__class__.__name__} is not serializable")


def _msgpack_ext_hook(code: int, data: bytes):
def _msgpack_ext_hook(code: int, data: bytes) -> Any:
if code == EXT_CONSTRUCTOR_SINGLE_ARG:
try:
tup = msgpack.unpackb(data, ext_hook=_msgpack_ext_hook)
Expand Down Expand Up @@ -461,7 +464,7 @@ def _msgpack_ext_hook(code: int, data: bytes):
return


ENC_POOL = deque(maxlen=32)
ENC_POOL: deque[msgpack.Packer] = deque(maxlen=32)


def _msgpack_enc(data: Any) -> bytes:
Expand Down
4 changes: 2 additions & 2 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,8 +16,8 @@
SCHEDULED = "__scheduled__"
TASKS = "__pregel_tasks"

Value = TypeVar("Value")
Update = TypeVar("Update")
Value = TypeVar("Value", covariant=True)
Update = TypeVar("Update", contravariant=True)
C = TypeVar("C")


Expand Down
10 changes: 10 additions & 0 deletions libs/checkpoint/pyproject.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,3 +53,13 @@ now = true
delay = 0.1
runner_args = ["--ff", "-v", "--tb", "short"]
patterns = ["*.py"]

[tool.mypy]
# https://mypy.readthedocs.io/en/stable/config_file.html
disallow_untyped_defs = "True"
explicit_package_bases = "True"
warn_no_return = "False"
warn_unused_ignores = "True"
warn_redundant_casts = "True"
allow_redefinition = "True"
disable_error_code = "typeddict-item, return-value"
24 changes: 13 additions & 11 deletions libs/checkpoint/tests/test_memory.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import pytest
from langchain_core.runnables import RunnableConfig

Expand All @@ -12,7 +14,7 @@

class TestMemorySaver:
@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> None:
self.memory_saver = MemorySaver()

# objects for test setup
Expand Down Expand Up @@ -57,21 +59,21 @@ def setup(self):
}
self.metadata_3: CheckpointMetadata = {}

async def test_search(self):
async def test_search(self) -> None:
# set up test
# save checkpoints
self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})

# call method / assertions
query_1: CheckpointMetadata = {"source": "input"} # search by 1 key
query_2: CheckpointMetadata = {
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: CheckpointMetadata = {} # search by no keys, return all checkpoints
query_4: CheckpointMetadata = {"source": "update", "step": 1} # no match
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match

search_results_1 = list(self.memory_saver.list(None, filter=query_1))
assert len(search_results_1) == 1
Expand Down Expand Up @@ -99,21 +101,21 @@ async def test_search(self):

# TODO: test before and limit params

async def test_asearch(self):
async def test_asearch(self) -> None:
# set up test
# save checkpoints
self.memory_saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
self.memory_saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})
self.memory_saver.put(self.config_3, self.chkpnt_3, self.metadata_3, {})

# call method / assertions
query_1: CheckpointMetadata = {"source": "input"} # search by 1 key
query_2: CheckpointMetadata = {
query_1 = {"source": "input"} # search by 1 key
query_2 = {
"step": 1,
"writes": {"foo": "bar"},
} # search by multiple keys
query_3: CheckpointMetadata = {} # search by no keys, return all checkpoints
query_4: CheckpointMetadata = {"source": "update", "step": 1} # no match
query_3: dict[str, Any] = {} # search by no keys, return all checkpoints
query_4 = {"source": "update", "step": 1} # no match

search_results_1 = [
c async for c in self.memory_saver.alist(None, filter=query_1)
Expand Down

0 comments on commit bc95a79

Please sign in to comment.