From e0a0958a60079d2425965c199baae6eb9d4886b6 Mon Sep 17 00:00:00 2001
From: Nuno Campos <nuno@langchain.dev>
Date: Mon, 9 Dec 2024 08:14:41 -0800
Subject: [PATCH] lib: imperative api: Generators use yield to publish
 stream_mode=custom events

---
 libs/langgraph/langgraph/func/__init__.py | 29 ++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/libs/langgraph/langgraph/func/__init__.py b/libs/langgraph/langgraph/func/__init__.py
index 2dda24754..26d583ed2 100644
--- a/libs/langgraph/langgraph/func/__init__.py
+++ b/libs/langgraph/langgraph/func/__init__.py
@@ -1,6 +1,7 @@
 import asyncio
 import concurrent
 import concurrent.futures
+import inspect
 import types
 from functools import partial, update_wrapper
 from typing import (
@@ -24,7 +25,7 @@
 from langgraph.pregel.read import PregelNode
 from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
 from langgraph.store.base import BaseStore
-from langgraph.types import RetryPolicy
+from langgraph.types import RetryPolicy, StreamMode, StreamWriter
 
 P = ParamSpec("P")
 P1 = TypeVar("P1")
@@ -76,10 +77,32 @@ def entrypoint(
     store: Optional[BaseStore] = None,
 ) -> Callable[[types.FunctionType], Pregel]:
     def _imp(func: types.FunctionType) -> Pregel:
+        if inspect.isgeneratorfunction(func):
+
+            def gen_wrapper(*args: Any, writer: StreamWriter, **kwargs: Any) -> Any:
+                for chunk in func(*args, **kwargs):
+                    writer(chunk)
+
+            bound = get_runnable_for_func(gen_wrapper)
+            stream_mode: StreamMode = "custom"
+        elif inspect.isasyncgenfunction(func):
+
+            async def agen_wrapper(
+                *args: Any, writer: StreamWriter, **kwargs: Any
+            ) -> Any:
+                async for chunk in func(*args, **kwargs):
+                    writer(chunk)
+
+            bound = get_runnable_for_func(agen_wrapper)
+            stream_mode = "custom"
+        else:
+            bound = get_runnable_for_func(func)
+            stream_mode = "updates"
+
         return Pregel(
             nodes={
                 func.__name__: PregelNode(
-                    bound=get_runnable_for_func(func),
+                    bound=bound,
                     triggers=[START],
                     channels=[START],
                     writers=[ChannelWrite([ChannelWriteEntry(END)], tags=[TAG_HIDDEN])],
@@ -89,7 +112,7 @@ def _imp(func: types.FunctionType) -> Pregel:
             input_channels=START,
             output_channels=END,
             stream_channels=END,
-            stream_mode="updates",
+            stream_mode=stream_mode,
             checkpointer=checkpointer,
             store=store,
         )