Skip to content

Commit

Permalink
Use correct serde in postgres checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
vermapratyush committed Sep 12, 2024
1 parent 889cc9b commit 82dd8a7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
10 changes: 7 additions & 3 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ def __init__(
@classmethod
@asynccontextmanager
async def from_conn_string(
cls, conn_string: str, *, pipeline: bool = False
cls,
conn_string: str,
*,
pipeline: bool = False,
serde: Optional[SerializerProtocol] = None,
) -> AsyncIterator["AsyncPostgresSaver"]:
"""Create a new PostgresSaver instance from a connection string.
Expand All @@ -73,9 +77,9 @@ async def from_conn_string(
) as conn:
if pipeline:
async with conn.pipeline() as pipe:
yield AsyncPostgresSaver(conn, pipe)
yield AsyncPostgresSaver(conn=conn, pipe=pipe, serde=serde)
else:
yield AsyncPostgresSaver(conn)
yield AsyncPostgresSaver(conn=conn, serde=serde)

async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
EmptyChannelError,
get_checkpoint_id,
)
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
from langgraph.checkpoint.serde.types import TASKS, ChannelProtocol

MetadataInput = Optional[dict[str, Any]]
Expand Down Expand Up @@ -131,8 +130,6 @@ class BasePostgresSaver(BaseCheckpointSaver):
UPSERT_CHECKPOINT_WRITES_SQL = UPSERT_CHECKPOINT_WRITES_SQL
INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL

jsonplus_serde = JsonPlusSerializer()

def _load_checkpoint(
self,
checkpoint: dict[str, Any],
Expand Down Expand Up @@ -224,12 +221,10 @@ def _dump_writes(
]

def _load_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata))
return self.serde.loads(self.serde.dumps(metadata))

def _dump_metadata(self, metadata) -> str:
serialized_metadata_type, serialized_metadata = self.jsonplus_serde.dumps_typed(
metadata
)
serialized_metadata_type, serialized_metadata = self.serde.dumps_typed(metadata)
if serialized_metadata_type != "json":
raise TypeError(
f"Failed to properly serialize metadata -- expected 'json', got '{serialized_metadata_type}'"
Expand Down

0 comments on commit 82dd8a7

Please sign in to comment.