From 217b07ae65cdfece3eda1c9e728aaf7015ae3cb7 Mon Sep 17 00:00:00 2001
From: Nuno Campos <nuno@langchain.dev>
Date: Wed, 11 Sep 2024 14:27:08 -0700
Subject: [PATCH] kafka: Make consumer and producer classes configurable

- Define protocol for sync and async producer and consumer
- Accept consumer/producer as init args in Orchestrator/Executor
- If not passed in, create default consumer/producer as before
---
 libs/scheduler-kafka/README.md                |  13 +-
 .../scheduler/kafka/default_async.py          |  16 +++
 .../langgraph/scheduler/kafka/executor.py     | 128 ++++++++++--------
 .../langgraph/scheduler/kafka/orchestrator.py | 112 ++++++++-------
 .../langgraph/scheduler/kafka/types.py        |  40 +++++-
 5 files changed, 194 insertions(+), 115 deletions(-)
 create mode 100644 libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py

diff --git a/libs/scheduler-kafka/README.md b/libs/scheduler-kafka/README.md
index c279f5e38..2d934d59f 100644
--- a/libs/scheduler-kafka/README.md
+++ b/libs/scheduler-kafka/README.md
@@ -34,9 +34,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -64,9 +64,9 @@ from your_lib import graph # graph expected to be a compiled LangGraph graph
 logger = logging.getLogger(__name__)
 
 topics = Topics(
-    orchestrator: os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
-    executor: os.environ['KAFKA_TOPIC_EXECUTOR'],
-    error: os.environ['KAFKA_TOPIC_ERROR'],
+    orchestrator=os.environ['KAFKA_TOPIC_ORCHESTRATOR'],
+    executor=os.environ['KAFKA_TOPIC_EXECUTOR'],
+    error=os.environ['KAFKA_TOPIC_ERROR'],
 )
 
 async def main():
@@ -91,7 +91,6 @@ python executor.py &
 
 You can pass any of the following values as `kwargs` to either `KafkaOrchestrator` or `KafkaExecutor` to configure the consumer:
 
-- group_id (str): a name for the consumer group. Defaults to 'orchestrator' or 'executor', respectively.
 - batch_max_n (int): Maximum number of messages to include in a single batch. Default: 10.
 - batch_max_ms (int): Maximum time in milliseconds to wait for messages to include in a batch. Default: 1000.
 - retry_policy (langgraph.pregel.types.RetryPolicy): Controls which graph-level errors will be retried when processing messages. A good use for this is to retry database errors thrown by the checkpointer. Defaults to None.
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
new file mode 100644
index 000000000..ce6703a03
--- /dev/null
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/default_async.py
@@ -0,0 +1,16 @@
+import dataclasses
+from typing import Any, Sequence
+
+import aiokafka
+
+
+class DefaultAsyncConsumer(aiokafka.AIOKafkaConsumer):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]:
+        batch = await super().getmany(timeout_ms=timeout_ms, max_records=max_records)
+        return {t: [dataclasses.asdict(m) for m in msgs] for t, msgs in batch.items()}
+
+
+class DefaultAsyncProducer(aiokafka.AIOKafkaProducer):
+    pass
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
index 86812c258..f7bceb86a 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Any, Optional, Sequence
 
-import aiokafka
 import orjson
 from langchain_core.runnables import RunnableConfig
 from typing_extensions import Self
@@ -19,6 +18,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     MessageToExecutor,
     MessageToOrchestrator,
@@ -28,51 +29,56 @@
 
 
 class KafkaExecutor(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
         *,
-        group_id: str = "executor",
         batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
         **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.executor,
-                value_deserializer=serde.loads,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                key_serializer=serde.dumps,
-                value_serializer=serde.dumps,
-                **self.kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.executor,
+                    auto_offset_reset="earliest",
+                    group_id="executor",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
@@ -83,15 +89,12 @@ def __aiter__(self) -> Self:
 
     async def __anext__(self) -> Sequence[MessageToExecutor]:
         # wait for next batch
-        try:
-            recs = await self.consumer.getmany(
-                timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
-            )
-            msgs: list[MessageToExecutor] = [
-                msg.value for msgs in recs.values() for msg in msgs
-            ]
-        except aiokafka.ConsumerStoppedError:
-            raise StopAsyncIteration from None
+        recs = await self.consumer.getmany(
+            timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
+        )
+        msgs: list[MessageToExecutor] = [
+            serde.loads(msg["value"]) for msgs in recs.values() for msg in msgs
+        ]
         # process batch
         await asyncio.gather(*(self.each(msg) for msg in msgs))
         # commit offsets
@@ -106,30 +109,38 @@ async def each(self, msg: MessageToExecutor) -> None:
             pass
         except GraphDelegate as exc:
             for arg in exc.args:
-                await self.producer.send_and_wait(
+                fut = await self.producer.send(
                     self.topics.orchestrator,
-                    value=MessageToOrchestrator(
-                        config=arg["config"],
-                        input=orjson.Fragment(
-                            self.graph.checkpointer.serde.dumps(arg["input"])
-                        ),
-                        finally_executor=[msg],
+                    value=serde.dumps(
+                        MessageToOrchestrator(
+                            config=arg["config"],
+                            input=orjson.Fragment(
+                                self.graph.checkpointer.serde.dumps(arg["input"])
+                            ),
+                            finally_executor=[msg],
+                        )
                     ),
                     # use thread_id, checkpoint_ns as partition key
-                    key=(
-                        arg["config"]["configurable"]["thread_id"],
-                        arg["config"]["configurable"].get("checkpoint_ns"),
+                    key=serde.dumps(
+                        (
+                            arg["config"]["configurable"]["thread_id"],
+                            arg["config"]["configurable"].get("checkpoint_ns"),
+                        )
                     ),
                 )
+                await fut
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.executor,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.executor,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToExecutor) -> None:
         # find graph
@@ -182,19 +193,24 @@ async def attempt(self, msg: MessageToExecutor) -> None:
                     msg["config"], [(ERROR, TaskNotFound())]
                 )
         # notify orchestrator
-        await self.producer.send_and_wait(
+        fut = await self.producer.send(
             self.topics.orchestrator,
-            value=MessageToOrchestrator(
-                input=None,
-                config=msg["config"],
-                finally_executor=msg.get("finally_executor"),
+            value=serde.dumps(
+                MessageToOrchestrator(
+                    input=None,
+                    config=msg["config"],
+                    finally_executor=msg.get("finally_executor"),
+                )
             ),
             # use thread_id, checkpoint_ns as partition key
-            key=(
-                msg["config"]["configurable"]["thread_id"],
-                msg["config"]["configurable"].get("checkpoint_ns"),
+            key=serde.dumps(
+                (
+                    msg["config"]["configurable"]["thread_id"],
+                    msg["config"]["configurable"].get("checkpoint_ns"),
+                )
             ),
         )
+        await fut
 
     def _put_writes(
         self,
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
index c13de9c08..e594d88ba 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
@@ -2,7 +2,6 @@
 from contextlib import AbstractAsyncContextManager, AsyncExitStack
 from typing import Any, Optional
 
-import aiokafka
 from langchain_core.runnables import ensure_config
 from typing_extensions import Self
 
@@ -21,6 +20,8 @@
 from langgraph.pregel.types import RetryPolicy
 from langgraph.scheduler.kafka.retry import aretry
 from langgraph.scheduler.kafka.types import (
+    AsyncConsumer,
+    AsyncProducer,
     ErrorMessage,
     ExecutorTask,
     MessageToExecutor,
@@ -31,50 +32,55 @@
 
 
 class KafkaOrchestrator(AbstractAsyncContextManager):
+    consumer: AsyncConsumer
+
+    producer: AsyncProducer
+
     def __init__(
         self,
         graph: Pregel,
         topics: Topics,
-        group_id: str = "orchestrator",
         batch_max_n: int = 10,
         batch_max_ms: int = 1000,
         retry_policy: Optional[RetryPolicy] = None,
-        consumer_kwargs: Optional[dict[str, Any]] = None,
-        producer_kwargs: Optional[dict[str, Any]] = None,
+        consumer: Optional[AsyncConsumer] = None,
+        producer: Optional[AsyncProducer] = None,
         **kwargs: Any,
     ) -> None:
         self.graph = graph
         self.topics = topics
         self.stack = AsyncExitStack()
         self.kwargs = kwargs
-        self.consumer_kwargs = consumer_kwargs or {}
-        self.producer_kwargs = producer_kwargs or {}
-        self.group_id = group_id
+        self.consumer = consumer
+        self.producer = producer
         self.batch_max_n = batch_max_n
         self.batch_max_ms = batch_max_ms
         self.retry_policy = retry_policy
 
     async def __aenter__(self) -> Self:
-        self.consumer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaConsumer(
-                self.topics.orchestrator,
-                auto_offset_reset="earliest",
-                group_id=self.group_id,
-                enable_auto_commit=False,
-                **self.kwargs,
-                **self.consumer_kwargs,
-            )
-        )
-        self.producer = await self.stack.enter_async_context(
-            aiokafka.AIOKafkaProducer(
-                value_serializer=serde.dumps,
-                **self.kwargs,
-                **self.producer_kwargs,
-            )
-        )
         self.subgraphs = {
             k: v async for k, v in self.graph.aget_subgraphs(recurse=True)
         }
+        if self.consumer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncConsumer
+
+            self.consumer = await self.stack.enter_async_context(
+                DefaultAsyncConsumer(
+                    self.topics.orchestrator,
+                    auto_offset_reset="earliest",
+                    group_id="orchestrator",
+                    enable_auto_commit=False,
+                    **self.kwargs,
+                )
+            )
+        if self.producer is None:
+            from langgraph.scheduler.kafka.default_async import DefaultAsyncProducer
+
+            self.producer = await self.stack.enter_async_context(
+                DefaultAsyncProducer(
+                    **self.kwargs,
+                )
+            )
         return self
 
     async def __aexit__(self, *args: Any) -> None:
@@ -85,15 +91,12 @@ def __aiter__(self) -> Self:
 
     async def __anext__(self) -> list[MessageToOrchestrator]:
         # wait for next batch
-        try:
-            recs = await self.consumer.getmany(
-                timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
-            )
-            # dedupe messages, eg. if multiple nodes finish around same time
-            uniq = set(msg.value for msgs in recs.values() for msg in msgs)
-            msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
-        except aiokafka.ConsumerStoppedError:
-            raise StopAsyncIteration from None
+        recs = await self.consumer.getmany(
+            timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
+        )
+        # dedupe messages, eg. if multiple nodes finish around same time
+        uniq = set(msg["value"] for msgs in recs.values() for msg in msgs)
+        msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
         # process batch
         await asyncio.gather(*(self.each(msg) for msg in msgs))
         # commit offsets
@@ -109,14 +112,17 @@ async def each(self, msg: MessageToOrchestrator) -> None:
         except GraphInterrupt:
             pass
         except Exception as exc:
-            await self.producer.send_and_wait(
+            fut = await self.producer.send(
                 self.topics.error,
-                value=ErrorMessage(
-                    topic=self.topics.orchestrator,
-                    msg=msg,
-                    error=repr(exc),
+                value=serde.dumps(
+                    ErrorMessage(
+                        topic=self.topics.orchestrator,
+                        msg=msg,
+                        error=repr(exc),
+                    )
                 ),
             )
+            await fut
 
     async def attempt(self, msg: MessageToOrchestrator) -> None:
         # find graph
@@ -155,21 +161,25 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
                 # schedule any new tasks
                 if new_tasks := [t for t in loop.tasks.values() if not t.scheduled]:
                     # send messages to executor
-                    futures: list[asyncio.Future] = await asyncio.gather(
+                    futures = await asyncio.gather(
                         *(
                             self.producer.send(
                                 self.topics.executor,
-                                value=MessageToExecutor(
-                                    config=patch_configurable(
-                                        loop.config,
-                                        {
-                                            **loop.checkpoint_config["configurable"],
-                                            CONFIG_KEY_DEDUPE_TASKS: True,
-                                            CONFIG_KEY_ENSURE_LATEST: True,
-                                        },
-                                    ),
-                                    task=ExecutorTask(id=task.id, path=task.path),
-                                    finally_executor=msg.get("finally_executor"),
+                                value=serde.dumps(
+                                    MessageToExecutor(
+                                        config=patch_configurable(
+                                            loop.config,
+                                            {
+                                                **loop.checkpoint_config[
+                                                    "configurable"
+                                                ],
+                                                CONFIG_KEY_DEDUPE_TASKS: True,
+                                                CONFIG_KEY_ENSURE_LATEST: True,
+                                            },
+                                        ),
+                                        task=ExecutorTask(id=task.id, path=task.path),
+                                        finally_executor=msg.get("finally_executor"),
+                                    )
                                 ),
                             )
                             for task in new_tasks
@@ -197,7 +207,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
                 # schedule any finally_executor tasks
                 futs = await asyncio.gather(
                     *(
-                        self.producer.send(self.topics.executor, value=m)
+                        self.producer.send(self.topics.executor, value=serde.dumps(m))
                         for m in msg["finally_executor"]
                     )
                 )
diff --git a/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py b/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
index 3bc298b93..65385a3a0 100644
--- a/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
+++ b/libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
@@ -1,4 +1,6 @@
-from typing import Any, NamedTuple, Optional, Sequence, TypedDict, Union
+import asyncio
+import concurrent.futures
+from typing import Any, NamedTuple, Optional, Protocol, Sequence, TypedDict, Union
 
 from langchain_core.runnables import RunnableConfig
 
@@ -30,3 +32,39 @@ class ErrorMessage(TypedDict):
     topic: str
     error: str
     msg: Union[MessageToExecutor, MessageToOrchestrator]
+
+
+class Consumer(Protocol):
+    def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]: ...
+
+    def commit(self) -> None: ...
+
+
+class AsyncConsumer(Protocol):
+    async def getmany(
+        self, timeout_ms: int, max_records: int
+    ) -> dict[str, Sequence[dict[str, Any]]]: ...
+
+    async def commit(self) -> None: ...
+
+
+class Producer(Protocol):
+    def send(
+        self,
+        topic: str,
+        *,
+        key: Optional[bytes] = None,
+        value: Optional[bytes] = None,
+    ) -> concurrent.futures.Future: ...
+
+
+class AsyncProducer(Protocol):
+    async def send(
+        self,
+        topic: str,
+        *,
+        key: Optional[bytes] = None,
+        value: Optional[bytes] = None,
+    ) -> asyncio.Future: ...