kafka: Expose finally_send as public api #1693

Merged · 2 commits · Sep 11, 2024
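
This change renames the finally_executor field on MessageToOrchestrator and MessageToExecutor to finally_send and changes its element type to a new public Sendable TypedDict (a target topic plus optional value and key). Once a run is done, the orchestrator produces each Sendable to its own topic, so follow-up messages are no longer limited to MessageToExecutor payloads on the executor topic. A short usage sketch follows the types.py diff below.
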
13 changes: 9 additions & 4 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/executor.py
@@ -35,6 +35,7 @@
MessageToExecutor,
MessageToOrchestrator,
Producer,
Sendable,
Topics,
)
from langgraph.utils.config import patch_configurable
@@ -129,7 +130,9 @@ async def each(self, msg: MessageToExecutor) -> None:
input=orjson.Fragment(
self.graph.checkpointer.serde.dumps(arg["input"])
),
finally_executor=[msg],
finally_send=[
Sendable(topic=self.topics.executor, value=msg)
],
)
),
# use thread_id, checkpoint_ns as partition key
@@ -211,7 +214,7 @@ async def attempt(self, msg: MessageToExecutor) -> None:
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_executor=msg.get("finally_executor"),
finally_send=msg.get("finally_send"),
)
),
# use thread_id, checkpoint_ns as partition key
@@ -322,7 +325,9 @@ def each(self, msg: MessageToExecutor) -> None:
input=orjson.Fragment(
self.graph.checkpointer.serde.dumps(arg["input"])
),
finally_executor=[msg],
finally_send=[
Sendable(topic=self.topics.executor, value=msg)
],
)
),
# use thread_id, checkpoint_ns as partition key
@@ -403,7 +408,7 @@ def attempt(self, msg: MessageToExecutor) -> None:
MessageToOrchestrator(
input=None,
config=msg["config"],
finally_executor=msg.get("finally_executor"),
finally_send=msg.get("finally_send"),
)
),
# use thread_id, checkpoint_ns as partition key
29 changes: 18 additions & 11 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/orchestrator.py
@@ -187,7 +187,7 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
},
),
task=ExecutorTask(id=task.id, path=task.path),
finally_executor=msg.get("finally_executor"),
finally_send=msg.get("finally_send"),
)
),
)
@@ -212,12 +212,16 @@ async def attempt(self, msg: MessageToOrchestrator) -> None:
)
],
)
elif loop.status == "done" and msg.get("finally_executor"):
# schedule any finally_executor tasks
elif loop.status == "done" and msg.get("finally_send"):
# send any finally_send messages
futs = await asyncio.gather(
*(
self.producer.send(self.topics.executor, value=serde.dumps(m))
for m in msg["finally_executor"]
self.producer.send(
m["topic"],
value=serde.dumps(m["value"]) if m.get("value") else None,
key=serde.dumps(m["key"]) if m.get("key") else None,
)
for m in msg["finally_send"]
)
)
# wait for messages to be sent
@@ -288,7 +292,6 @@ def __next__(self) -> list[MessageToOrchestrator]:
recs = self.consumer.getmany(
timeout_ms=self.batch_max_ms, max_records=self.batch_max_n
)
print("orch.__next__", recs)
# dedupe messages, eg. if multiple nodes finish around same time
uniq = set(msg.value for msgs in recs.values() for msg in msgs)
msgs: list[MessageToOrchestrator] = [serde.loads(msg) for msg in uniq]
@@ -370,7 +373,7 @@ def attempt(self, msg: MessageToOrchestrator) -> None:
},
),
task=ExecutorTask(id=task.id, path=task.path),
finally_executor=msg.get("finally_executor"),
finally_send=msg.get("finally_send"),
)
),
)
@@ -394,11 +397,15 @@ def attempt(self, msg: MessageToOrchestrator) -> None:
)
],
)
elif loop.status == "done" and msg.get("finally_executor"):
# schedule any finally_executor tasks
elif loop.status == "done" and msg.get("finally_send"):
# schedule any finally_send msgs
futs = [
self.producer.send(self.topics.executor, value=serde.dumps(m))
for m in msg["finally_executor"]
self.producer.send(
m["topic"],
value=serde.dumps(m["value"]) if m.get("value") else None,
key=serde.dumps(m["key"]) if m.get("key") else None,
)
for m in msg["finally_send"]
]
# wait for messages to be sent
concurrent.futures.wait(futs)
10 changes: 8 additions & 2 deletions libs/scheduler-kafka/langgraph/scheduler/kafka/types.py
@@ -11,10 +11,16 @@ class Topics(NamedTuple):
error: str


class Sendable(TypedDict):
topic: str
value: Optional[Any]
key: Optional[Any]


class MessageToOrchestrator(TypedDict):
input: Optional[dict[str, Any]]
config: RunnableConfig
finally_executor: Optional[Sequence["MessageToExecutor"]]
finally_send: Optional[Sequence[Sendable]]


class ExecutorTask(TypedDict):
@@ -25,7 +31,7 @@ class ExecutorTask(TypedDict):
class MessageToExecutor(TypedDict):
config: RunnableConfig
task: ExecutorTask
finally_executor: Optional[Sequence["MessageToExecutor"]]
finally_send: Optional[Sequence[Sendable]]


class ErrorMessage(TypedDict):
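
For orientation, a minimal sketch of how caller code might populate the newly public finally_send field. This is hypothetical usage, not part of the PR; the import path is inferred from the file path above, and the topic name and payloads are invented for illustration.

# Hypothetical caller code; module path inferred from
# libs/scheduler-kafka/langgraph/scheduler/kafka/types.py.
from langgraph.scheduler.kafka.types import MessageToOrchestrator, Sendable

# A Sendable names the topic to produce to, plus optional value and key;
# the orchestrator serializes value and key with its serde before sending.
notify: Sendable = {
    "topic": "my-notifications",  # invented topic name
    "value": {"status": "done"},
    "key": None,
}

# finally_send replaces finally_executor and accepts arbitrary Sendables,
# not just MessageToExecutor payloads aimed at the executor topic.
msg: MessageToOrchestrator = {
    "input": None,
    "config": {"configurable": {"thread_id": "1"}},
    "finally_send": [notify],
}

This mirrors what the executor diff above now does internally: it re-enqueues its own message by wrapping it as Sendable(topic=self.topics.executor, value=msg).
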
6 changes: 4 additions & 2 deletions libs/scheduler-kafka/tests/drain.py
@@ -32,7 +32,9 @@ async def drain_topics_async(
def done() -> bool:
return (
len(orch_msgs) > 0
and any(orch_msgs)
and len(exec_msgs) > 0
and any(exec_msgs)
and not orch_msgs[-1]
and not exec_msgs[-1]
)
@@ -97,7 +99,9 @@ def drain_topics(
def done() -> bool:
return (
len(orch_msgs) > 0
and any(orch_msgs)
and len(exec_msgs) > 0
and any(exec_msgs)
and not orch_msgs[-1]
and not exec_msgs[-1]
)
@@ -110,7 +114,6 @@ def orchestrator() -> None:
if debug:
print("\n---\norch", len(msgs), msgs)
if done():
print("am i done? orchestrator")
event.set()
if event.is_set():
break
@@ -126,7 +129,6 @@ def executor() -> None:
if debug:
print("\n---\nexec", len(msgs), msgs)
if done():
print("am i done? executor")
event.set()
if event.is_set():
break
8 changes: 4 additions & 4 deletions libs/scheduler-kafka/tests/test_fanout.py
@@ -136,7 +136,7 @@ async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver)
"tags": [],
},
"input": None,
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
@@ -161,7 +161,7 @@ async def test_fanout_graph(topics: Topics, acheckpointer: BaseCheckpointSaver)
"id": t.id,
"path": list(t.path),
},
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
@@ -218,7 +218,7 @@ async def test_fanout_graph_w_interrupt(
"tags": [],
},
"input": None,
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
@@ -245,7 +245,7 @@
"id": t.id,
"path": list(t.path),
},
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
for t in c.tasks
8 changes: 4 additions & 4 deletions libs/scheduler-kafka/tests/test_fanout_sync.py
@@ -133,7 +133,7 @@ def test_fanout_graph(topics: Topics, checkpointer: BaseCheckpointSaver) -> None
"tags": [],
},
"input": None,
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history)
for _ in c.tasks
@@ -158,7 +158,7 @@
"id": t.id,
"path": list(t.path),
},
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history)
for t in c.tasks
@@ -216,7 +216,7 @@ def test_fanout_graph_w_interrupt(
"tags": [],
},
"input": None,
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
# orchestrator messages appear only after tasks for that checkpoint
@@ -243,7 +243,7 @@
"id": t.id,
"path": list(t.path),
},
"finally_executor": None,
"finally_send": None,
}
for c in reversed(history[1:]) # the last one wasn't executed
for t in c.tasks