From 89fb8d506ff33b441f90972d127746d15da6192e Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sat, 21 Sep 2024 17:21:18 -0700 Subject: [PATCH] Add more comments and docstrings --- libs/langgraph/langgraph/constants.py | 167 ++++++++---------- libs/langgraph/langgraph/errors.py | 25 +-- libs/langgraph/langgraph/pregel/algo.py | 25 ++- libs/langgraph/langgraph/pregel/debug.py | 4 + libs/langgraph/langgraph/pregel/executor.py | 17 +- libs/langgraph/langgraph/pregel/io.py | 5 +- libs/langgraph/langgraph/pregel/loop.py | 6 +- libs/langgraph/langgraph/pregel/messages.py | 3 + libs/langgraph/langgraph/pregel/metadata.py | 0 libs/langgraph/langgraph/pregel/read.py | 24 ++- libs/langgraph/langgraph/pregel/runner.py | 8 + libs/langgraph/langgraph/pregel/types.py | 112 ++++++++++-- libs/langgraph/langgraph/pregel/utils.py | 2 +- libs/langgraph/langgraph/pregel/validate.py | 2 +- libs/langgraph/langgraph/pregel/write.py | 29 ++- libs/langgraph/langgraph/version.py | 2 +- libs/langgraph/tests/test_tracing_interops.py | 7 +- 17 files changed, 274 insertions(+), 164 deletions(-) delete mode 100644 libs/langgraph/langgraph/pregel/metadata.py diff --git a/libs/langgraph/langgraph/constants.py b/libs/langgraph/langgraph/constants.py index e8719e664..eb0f703be 100644 --- a/libs/langgraph/langgraph/constants.py +++ b/libs/langgraph/langgraph/constants.py @@ -1,133 +1,106 @@ -from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Literal, Mapping +from typing import Any, Mapping +from langgraph.pregel.types import Interrupt, Send # noqa: F401 + +# Interrupt, Send re-exported for backwards compatibility + +# --- Empty read-only containers --- +EMPTY_MAP: Mapping[str, Any] = MappingProxyType({}) +EMPTY_SEQ: tuple[str, ...] = tuple() + +# --- Reserved write keys --- INPUT = "__input__" +# for values passed as input to the graph +INTERRUPT = "__interrupt__" +# for dynamic interrupts raised by nodes +ERROR = "__error__" +# for errors raised by nodes +NO_WRITES = "__no_writes__" +# marker to signal node didn't write anything +SCHEDULED = "__scheduled__" +# marker to signal node was scheduled (in distributed mode) +TASKS = "__pregel_tasks" +# for Send objects returned by nodes/edges, corresponds to PUSH below +START = "__start__" +# marker for the first (maybe virtual) node in graph-style Pregel +END = "__end__" +# marker for the last (maybe virtual) node in graph-style Pregel + +# --- Reserved config.configurable keys --- CONFIG_KEY_SEND = "__pregel_send" +# holds the `write` function that accepts writes to state/edges/reserved keys CONFIG_KEY_READ = "__pregel_read" +# holds the `read` function that returns a copy of the current state CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer" +# holds a `BaseCheckpointSaver` passed from parent graph to child graphs CONFIG_KEY_STREAM = "__pregel_stream" +# holds a `StreamProtocol` passed from parent graph to child graphs CONFIG_KEY_STREAM_WRITER = "__pregel_stream_writer" +# holds a `StreamWriter` for stream_mode=custom CONFIG_KEY_STORE = "__pregel_store" +# holds a `BaseStore` made available to managed values CONFIG_KEY_RESUMING = "__pregel_resuming" +# holds a boolean indicating if subgraphs should resume from a previous checkpoint CONFIG_KEY_TASK_ID = "__pregel_task_id" +# holds the task ID for the current task CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks" +# holds a boolean indicating if tasks should be deduplicated (for distributed mode) CONFIG_KEY_ENSURE_LATEST = "__pregel_ensure_latest" +# holds a boolean indicating whether 
to assert the requested checkpoint is the latest +# (for distributed mode) CONFIG_KEY_DELEGATE = "__pregel_delegate" -# this one part of public API so more readable +# holds a boolean indicating whether to delegate subgraphs (for distributed mode) CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map" -INTERRUPT = "__interrupt__" -ERROR = "__error__" -NO_WRITES = "__no_writes__" -SCHEDULED = "__scheduled__" -TASKS = "__pregel_tasks" # for backwards compat, this is the original name of PUSH +# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs +CONFIG_KEY_CHECKPOINT_ID = "checkpoint_id" +# holds the current checkpoint_id, if any +CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns" +# holds the current checkpoint_ns, "" for root graph + +# --- Other constants --- PUSH = "__pregel_push" +# denotes push-style tasks, ie. those created by Send objects PULL = "__pregel_pull" +# denotes pull-style tasks, ie. those triggered by edges RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__" +# placeholder for managed values replaced at runtime +TAG_HIDDEN = "langsmith:hidden" +# tag to hide a node/edge from certain tracing/streaming environments +NS_SEP = "|" +# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph) +NS_END = ":" +# for checkpoint_ns, for each level, separates the namespace from the task_id + RESERVED = { - SCHEDULED, + # reserved write keys + INPUT, INTERRUPT, ERROR, NO_WRITES, + SCHEDULED, TASKS, - PUSH, - PULL, + # reserved config.configurable keys CONFIG_KEY_SEND, CONFIG_KEY_READ, CONFIG_KEY_CHECKPOINTER, - CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_STREAM, CONFIG_KEY_STREAM_WRITER, CONFIG_KEY_STORE, + CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_RESUMING, CONFIG_KEY_TASK_ID, CONFIG_KEY_DEDUPE_TASKS, CONFIG_KEY_ENSURE_LATEST, CONFIG_KEY_DELEGATE, - INPUT, + CONFIG_KEY_CHECKPOINT_MAP, + CONFIG_KEY_CHECKPOINT_ID, + CONFIG_KEY_CHECKPOINT_NS, + # other constants + PUSH, + PULL, RUNTIME_PLACEHOLDER, + TAG_HIDDEN, + NS_SEP, + NS_END, } -TAG_HIDDEN = "langsmith:hidden" - -START = "__start__" -END = "__end__" - -NS_SEP = "|" -NS_END = ":" - -EMPTY_MAP: Mapping[str, Any] = MappingProxyType({}) - - -class Send: - """A message or packet to send to a specific node in the graph. - - The `Send` class is used within a `StateGraph`'s conditional edges to - dynamically invoke a node with a custom state at the next step. - - Importantly, the sent state can differ from the core graph's state, - allowing for flexible and dynamic workflow management. - - One such example is a "map-reduce" workflow where your graph invokes - the same node multiple times in parallel with different states, - before aggregating the results back into the main graph's state. - - Attributes: - node (str): The name of the target node to send the message to. - arg (Any): The state or message to send to the target node. - - Examples: - >>> from typing import Annotated - >>> import operator - >>> class OverallState(TypedDict): - ... subjects: list[str] - ... jokes: Annotated[list[str], operator.add] - ... - >>> from langgraph.constants import Send - >>> from langgraph.graph import END, START - >>> def continue_to_jokes(state: OverallState): - ... return [Send("generate_joke", {"subject": s}) for s in state['subjects']] - ... 
- >>> from langgraph.graph import StateGraph - >>> builder = StateGraph(OverallState) - >>> builder.add_node("generate_joke", lambda state: {"jokes": [f"Joke about {state['subject']}"]}) - >>> builder.add_conditional_edges(START, continue_to_jokes) - >>> builder.add_edge("generate_joke", END) - >>> graph = builder.compile() - >>> - >>> # Invoking with two subjects results in a generated joke for each - >>> graph.invoke({"subjects": ["cats", "dogs"]}) - {'subjects': ['cats', 'dogs'], 'jokes': ['Joke about cats', 'Joke about dogs']} - """ - - node: str - arg: Any - - def __init__(self, /, node: str, arg: Any) -> None: - """ - Initialize a new instance of the Send class. - - Args: - node (str): The name of the target node to send the message to. - arg (Any): The state or message to send to the target node. - """ - self.node = node - self.arg = arg - - def __hash__(self) -> int: - return hash((self.node, self.arg)) - - def __repr__(self) -> str: - return f"Send(node={self.node!r}, arg={self.arg!r})" - - def __eq__(self, value: object) -> bool: - return ( - isinstance(value, Send) - and self.node == value.node - and self.arg == value.arg - ) - - -@dataclass -class Interrupt: - value: Any - when: Literal["during"] = "during" diff --git a/libs/langgraph/langgraph/errors.py b/libs/langgraph/langgraph/errors.py index ec84e0b28..ad3150942 100644 --- a/libs/langgraph/langgraph/errors.py +++ b/libs/langgraph/langgraph/errors.py @@ -1,8 +1,10 @@ from typing import Any, Sequence -from langgraph.checkpoint.base import EmptyChannelError +from langgraph.checkpoint.base import EmptyChannelError # noqa: F401 from langgraph.constants import Interrupt +# EmptyChannelError re-exported for backwards compatibility + class GraphRecursionError(RecursionError): """Raised when the graph has exhausted the maximum number of steps. @@ -24,13 +26,14 @@ class GraphRecursionError(RecursionError): class InvalidUpdateError(Exception): - """Raised when attempting to update a channel with an invalid sequence of updates.""" + """Raised when attempting to update a channel with an invalid set of updates.""" pass class GraphInterrupt(Exception): - """Raised when a subgraph is interrupted.""" + """Raised when a subgraph is interrupted, suppressed by the root graph.
+ Never raised directly, nor surfaced to the user.""" def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None: super().__init__(interrupts) @@ -44,7 +47,7 @@ def __init__(self, value: Any) -> None: class GraphDelegate(Exception): - """Raised when a graph is delegated.""" + """Raised when a graph is delegated (for distributed mode).""" def __init__(self, *args: dict[str, Any]) -> None: super().__init__(*args) @@ -57,22 +60,12 @@ class EmptyInputError(Exception): class TaskNotFound(Exception): - """Raised when the executor is unable to find a task.""" + """Raised when the executor is unable to find a task (for distributed mode).""" pass class CheckpointNotLatest(Exception): - """Raised when the checkpoint is not the latest version.""" + """Raised when the checkpoint is not the latest version (for distributed mode).""" pass - - -__all__ = [ - "GraphRecursionError", - "InvalidUpdateError", - "GraphInterrupt", - "NodeInterrupt", - "EmptyInputError", - "EmptyChannelError", -] diff --git a/libs/langgraph/langgraph/pregel/algo.py b/libs/langgraph/langgraph/pregel/algo.py index 8a63d4cc6..44da0bbd5 100644 --- a/libs/langgraph/langgraph/pregel/algo.py +++ b/libs/langgraph/langgraph/pregel/algo.py @@ -33,6 +33,7 @@ CONFIG_KEY_READ, CONFIG_KEY_SEND, CONFIG_KEY_TASK_ID, + EMPTY_SEQ, INTERRUPT, NO_WRITES, NS_END, @@ -55,10 +56,11 @@ GetNextVersion = Callable[[Optional[V], BaseChannel], V] -EMPTY_SEQ: tuple[str, ...] = tuple() - class WritesProtocol(Protocol): + """Protocol for objects containing writes to be applied to the checkpoint. + Implemented by PregelTaskWrites and PregelExecutableTask.""" + @property def name(self) -> str: ... @@ -70,6 +72,9 @@ def triggers(self) -> Sequence[str]: ... class PregelTaskWrites(NamedTuple): + """Simplest implementation of WritesProtocol, for use with writes that + don't originate from a runnable task, e.g. graph input, update_state, etc.""" + name: str writes: Sequence[tuple[str, Any]] triggers: Sequence[str] @@ -80,6 +85,7 @@ def should_interrupt( interrupt_nodes: Union[All, Sequence[str]], tasks: Iterable[PregelExecutableTask], ) -> list[PregelExecutableTask]: + """Check if the graph should be interrupted based on current state.""" version_type = type(next(iter(checkpoint["channel_versions"].values()), None)) null_version = version_type() # type: ignore[misc] seen = checkpoint["versions_seen"].get(INTERRUPT, {}) @@ -117,6 +123,9 @@ def local_read( select: Union[list[str], str], fresh: bool = False, ) -> Union[dict[str, Any], Any]: + """Function injected under CONFIG_KEY_READ in task config, to read the current state. + Used by conditional edges to read a copy of the state reflecting the writes + from that node only.""" if isinstance(select, str): managed_keys = [] for c, _ in task.writes: @@ -153,6 +162,8 @@ def local_write( managed: ManagedValueMapping, writes: Sequence[tuple[str, Any]], ) -> None: + """Function injected under CONFIG_KEY_SEND in task config, to write to channels.
+ Validates writes and forwards them to `commit` function.""" for chan, value in writes: if chan == TASKS: if not isinstance(value, Send): @@ -169,6 +180,7 @@ def local_write( def increment(current: Optional[int], channel: BaseChannel) -> int: + """Default channel versioning function, increments the current int version.""" return current + 1 if current is not None else 1 @@ -178,6 +190,9 @@ def apply_writes( tasks: Iterable[WritesProtocol], get_next_version: Optional[GetNextVersion], ) -> dict[str, list[Any]]: + """Apply writes from a set of tasks (usually the tasks from a Pregel step) + to the checkpoint and channels, and return managed values writes to be applied + externally.""" # update seen versions for task in tasks: checkpoint["versions_seen"].setdefault(task.name, {}).update( @@ -297,6 +312,9 @@ def prepare_next_tasks( checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]: + """Prepare the set of tasks that will make up the next Pregel step. + This is the union of all PUSH tasks (Sends) and PULL tasks (nodes triggered + by edges).""" tasks: dict[str, Union[PregelTask, PregelExecutableTask]] = {} # Consume pending packets for idx, _ in enumerate(checkpoint["pending_sends"]): @@ -348,6 +366,8 @@ def prepare_single_task( checkpointer: Optional[BaseCheckpointSaver] = None, manager: Union[None, ParentRunManager, AsyncParentRunManager] = None, ) -> Union[None, PregelTask, PregelExecutableTask]: + """Prepares a single task for the next Pregel step, given a task path, which + uniquely identifies a PUSH or PULL task within the graph.""" checkpoint_id = UUID(checkpoint["id"]).bytes configurable = config.get("configurable", {}) parent_ns = configurable.get("checkpoint_ns", "") @@ -568,6 +588,7 @@ def _proc_input( *, for_execution: bool, ) -> Iterator[Any]: + """Prepare input for a PULL task, based on the process's channels and triggers.""" # If all trigger channels subscribed by this process are not empty # then invoke the process with the values of all non-empty channels if isinstance(proc.channels, dict): diff --git a/libs/langgraph/langgraph/pregel/debug.py b/libs/langgraph/langgraph/pregel/debug.py index 56f1eb9c6..9c4661c0c 100644 --- a/libs/langgraph/langgraph/pregel/debug.py +++ b/libs/langgraph/langgraph/pregel/debug.py @@ -84,6 +84,7 @@ class DebugOutputCheckpoint(DebugOutputBase): def map_debug_tasks( step: int, tasks: Iterable[PregelExecutableTask] ) -> Iterator[DebugOutputTask]: + """Produce "task" events for stream_mode=debug.""" ts = datetime.now(timezone.utc).isoformat() for task in tasks: if task.config is not None and TAG_HIDDEN in task.config.get("tags", []): @@ -107,6 +108,7 @@ def map_debug_task_results( task_tup: tuple[PregelExecutableTask, Sequence[tuple[str, Any]]], stream_keys: Union[str, Sequence[str]], ) -> Iterator[DebugOutputTaskResult]: + """Produce "task_result" events for stream_mode=debug.""" stream_channels_list = ( [stream_keys] if isinstance(stream_keys, str) else stream_keys ) @@ -135,6 +137,7 @@ def map_debug_checkpoint( tasks: Iterable[PregelExecutableTask], pending_writes: list[PendingWrite], ) -> Iterator[DebugOutputCheckpoint]: + """Produce "checkpoint" events for stream_mode=debug.""" yield { "type": "checkpoint", "timestamp": checkpoint["ts"], @@ -213,6 +216,7 @@ def tasks_w_writes( pending_writes: Optional[list[PendingWrite]], states: Optional[dict[str, Union[RunnableConfig, StateSnapshot]]], ) -> tuple[PregelTask, 
...]: + """Apply writes / subgraph states to tasks to be returned in a StateSnapshot.""" pending_writes = pending_writes or [] return tuple( PregelTask( diff --git a/libs/langgraph/langgraph/pregel/executor.py b/libs/langgraph/langgraph/pregel/executor.py index 46f1c3f64..691098b7a 100644 --- a/libs/langgraph/langgraph/pregel/executor.py +++ b/libs/langgraph/langgraph/pregel/executor.py @@ -39,6 +39,13 @@ def __call__( class BackgroundExecutor(ContextManager): + """A context manager that runs sync tasks in the background. + Uses a thread pool executor to delegate tasks to separate threads. + On exit, + - cancels any (not yet started) tasks with `__cancel_on_exit__=True` + - waits for all tasks to finish + - re-raises the first exception from tasks with `__reraise_on_exit__=True`""" + def __init__(self, config: RunnableConfig) -> None: self.stack = ExitStack() self.executor = self.stack.enter_context(get_executor_for_config(config)) @@ -49,7 +56,7 @@ def submit( # type: ignore[valid-type] fn: Callable[P, T], *args: P.args, __name__: Optional[str] = None, # currently not used in sync version - __cancel_on_exit__: bool = False, + __cancel_on_exit__: bool = False, # for sync, can cancel only if not started __reraise_on_exit__: bool = True, **kwargs: P.kwargs, ) -> concurrent.futures.Future[T]: @@ -101,6 +108,14 @@ def __exit__( class AsyncBackgroundExecutor(AsyncContextManager): + """A context manager that runs async tasks in the background. + Uses the current event loop to delegate tasks to asyncio tasks. + On exit, + - cancels any tasks with `__cancel_on_exit__=True` + - waits for all tasks to finish + - re-raises the first exception from tasks with `__reraise_on_exit__=True` + ignoring CancelledError""" + def __init__(self) -> None: self.context_not_supported = sys.version_info < (3, 11) self.tasks: dict[asyncio.Task, tuple[bool, bool]] = {} diff --git a/libs/langgraph/langgraph/pregel/io.py b/libs/langgraph/langgraph/pregel/io.py index ad2252c9d..6542b1d91 100644 --- a/libs/langgraph/langgraph/pregel/io.py +++ b/libs/langgraph/langgraph/pregel/io.py @@ -3,7 +3,7 @@ from langchain_core.runnables.utils import AddableDict from langgraph.channels.base import BaseChannel, EmptyChannelError -from langgraph.constants import ERROR, INTERRUPT, TAG_HIDDEN +from langgraph.constants import EMPTY_SEQ, ERROR, INTERRUPT, TAG_HIDDEN from langgraph.pregel.log import logger from langgraph.pregel.types import PregelExecutableTask @@ -97,9 +97,6 @@ def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict": raise TypeError("AddableUpdatesDict does not support right-side addition") -EMPTY_SEQ: tuple[str, ...] 
= tuple() - - def map_output_updates( output_channels: Union[str, Sequence[str]], tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]], diff --git a/libs/langgraph/langgraph/pregel/loop.py b/libs/langgraph/langgraph/pregel/loop.py index 49baa1846..c97cc8a69 100644 --- a/libs/langgraph/langgraph/pregel/loop.py +++ b/libs/langgraph/langgraph/pregel/loop.py @@ -44,6 +44,7 @@ CONFIG_KEY_RESUMING, CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, + EMPTY_SEQ, ERROR, INPUT, INTERRUPT, @@ -101,13 +102,12 @@ V = TypeVar("V") P = ParamSpec("P") +StreamChunk = tuple[tuple[str, ...], str, Any] + INPUT_DONE = object() INPUT_RESUMING = object() -EMPTY_SEQ = () SPECIAL_CHANNELS = (ERROR, INTERRUPT, SCHEDULED) -StreamChunk = tuple[tuple[str, ...], str, Any] - class StreamProtocol: __slots__ = ("modes", "__call__") diff --git a/libs/langgraph/langgraph/pregel/messages.py b/libs/langgraph/langgraph/pregel/messages.py index 7c3f90b10..d0ae539e2 100644 --- a/libs/langgraph/langgraph/pregel/messages.py +++ b/libs/langgraph/langgraph/pregel/messages.py @@ -24,6 +24,9 @@ class StreamMessagesHandler(BaseCallbackHandler, _StreamingCallbackHandler): + """A callback handler that implements stream_mode=messages. + Collects messages from (1) chat model stream events and (2) node outputs.""" + def __init__(self, stream: Callable[[StreamChunk], None]): self.stream = stream self.metadata: dict[UUID, Meta] = {} diff --git a/libs/langgraph/langgraph/pregel/metadata.py b/libs/langgraph/langgraph/pregel/metadata.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/libs/langgraph/langgraph/pregel/read.py b/libs/langgraph/langgraph/pregel/read.py index 3ad988b89..097e76fa2 100644 --- a/libs/langgraph/langgraph/pregel/read.py +++ b/libs/langgraph/langgraph/pregel/read.py @@ -31,6 +31,9 @@ class ChannelRead(RunnableCallable): + """Implements the logic for reading state from CONFIG_KEY_READ. + Usable both as a runnable and as a static method to call imperatively.""" + channel: Union[str, list[str]] fresh: bool = False @@ -108,21 +111,39 @@ def do_read( class PregelNode(Runnable): + """A node in a Pregel graph. This won't be invoked as a runnable by the graph + itself, but instead acts as a container for the components necessary to make + a PregelExecutableTask for a node.""" + channels: Union[list[str], Mapping[str, str]] + """The channels that will be passed as input to `bound`. + If a list, the node will be invoked with the first of these that isn't empty. + If a dict, the keys are the names of the channels, and the values are the keys + to use in the input to `bound`.""" triggers: list[str] + """If any of these channels is written to, this node will be triggered in + the next step.""" mapper: Optional[Callable[[Any], Any]] + """A function to transform the input before passing it to `bound`.""" writers: list[Runnable] + """A list of writers that will be executed after `bound`, responsible for + taking the output of `bound` and writing it to the appropriate channels.""" bound: Runnable[Any, Any] + """The main logic of the node.
This will be invoked with the input from + `channels`.""" retry_policy: Optional[RetryPolicy] + """The retry policy to use when invoking the node.""" tags: Optional[Sequence[str]] + """Tags to attach to the node for tracing.""" metadata: Optional[Mapping[str, Any]] + """Metadata to attach to the node for tracing.""" def __init__( self, @@ -151,7 +172,7 @@ def copy(self, update: dict[str, Any]) -> PregelNode: @cached_property def flat_writers(self) -> list[Runnable]: - """Get writers with optimizations applied.""" + """Get writers with optimizations applied. Dedupes consecutive ChannelWrites.""" writers = self.writers.copy() while ( len(writers) > 1 @@ -170,6 +191,7 @@ def flat_writers(self) -> list[Runnable]: @cached_property def node(self) -> Optional[Runnable[Any, Any]]: + """Get a runnable that combines `bound` and `writers`.""" writers = self.flat_writers if self.bound is DEFAULT_BOUND and not writers: return None diff --git a/libs/langgraph/langgraph/pregel/runner.py b/libs/langgraph/langgraph/pregel/runner.py index 14e84352f..32bbf74bd 100644 --- a/libs/langgraph/langgraph/pregel/runner.py +++ b/libs/langgraph/langgraph/pregel/runner.py @@ -22,6 +22,10 @@ class PregelRunner: + """Responsible for executing a set of Pregel tasks concurrently, committing + their writes, yielding control to the caller when there is output to emit, and + interrupting other tasks if appropriate.""" + def __init__( self, *, @@ -215,6 +219,8 @@ def commit( def _should_stop_others( done: Union[set[concurrent.futures.Future[Any]], set[asyncio.Future[Any]]], ) -> bool: + """Check if any task failed; if so, cancel all other tasks. + GraphInterrupts are not considered failures.""" for fut in done: if fut.cancelled(): return True @@ -227,6 +233,7 @@ def _should_stop_others( def _exception( fut: Union[concurrent.futures.Future[Any], asyncio.Future[Any]], ) -> Optional[BaseException]: + """Return the exception from a future, without raising CancelledError.""" if fut.cancelled(): if isinstance(fut, asyncio.Future): return asyncio.CancelledError() @@ -245,6 +252,7 @@ def _panic_or_proceed( timeout_exc_cls: Type[Exception] = TimeoutError, panic: bool = True, ) -> None: + """Cancel remaining tasks if any failed; re-raise the exception if panic is True.""" done: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() inflight: set[Union[concurrent.futures.Future[Any], asyncio.Future[Any]]] = set() for fut, val in futs.items(): diff --git a/libs/langgraph/langgraph/pregel/types.py b/libs/langgraph/langgraph/pregel/types.py index d34845483..4cca9946c 100644 --- a/libs/langgraph/langgraph/pregel/types.py +++ b/libs/langgraph/langgraph/pregel/types.py @@ -1,10 +1,28 @@ from collections import deque +from dataclasses import dataclass from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Type, Union from langchain_core.runnables import Runnable, RunnableConfig from langgraph.checkpoint.base import CheckpointMetadata -from langgraph.constants import Interrupt + +All = Literal["*"] + +StreamMode = Literal["values", "updates", "debug", "messages", "custom"] +"""How the stream method should emit outputs. + +- 'values': Emit all values of the state for each step. +- 'updates': Emit only the node name(s) and updates + that were returned by the node(s) **after** each step. +- 'debug': Emit debug events for each step. +- 'messages': Emit LLM messages token-by-token. +- 'custom': Emit custom output written via the `write: StreamWriter` kwarg of each node.
+""" + +StreamWriter = Callable[[Any], None] +"""Callable that accepts a single argument and writes it to the output stream. +Always injected into nodes if requested as a keyword argument, but it's a no-op +when not using stream_mode="custom".""" def default_retry_on(exc: Exception) -> bool: @@ -63,6 +81,12 @@ class CachePolicy(NamedTuple): pass +@dataclass +class Interrupt: + value: Any + when: Literal["during"] = "during" + + class PregelTask(NamedTuple): id: str name: str @@ -105,20 +129,72 @@ class StateSnapshot(NamedTuple): """Tasks to execute in this step. If already attempted, may contain an error.""" -All = Literal["*"] - -StreamMode = Literal["values", "updates", "debug", "messages", "custom"] -"""How the stream method should emit outputs. - -- 'values': Emit all values of the state for each step. -- 'updates': Emit only the node name(s) and updates - that were returned by the node(s) **after** each step. -- 'debug': Emit debug events for each step. -- 'messages': Emit LLM messages token-by-token. -- 'custom': Emit custom output `write: StreamWriter` kwarg of each node. -""" - -StreamWriter = Callable[[Any], None] -"""Callable that accepts a single argument and writes it to the output stream. -Always injected into nodes if requested, -but it's a no-op when not using stream_mode="custom".""" +class Send: + """A message or packet to send to a specific node in the graph. + + The `Send` class is used within a `StateGraph`'s conditional edges to + dynamically invoke a node with a custom state at the next step. + + Importantly, the sent state can differ from the core graph's state, + allowing for flexible and dynamic workflow management. + + One such example is a "map-reduce" workflow where your graph invokes + the same node multiple times in parallel with different states, + before aggregating the results back into the main graph's state. + + Attributes: + node (str): The name of the target node to send the message to. + arg (Any): The state or message to send to the target node. + + Examples: + >>> from typing import Annotated + >>> import operator + >>> class OverallState(TypedDict): + ... subjects: list[str] + ... jokes: Annotated[list[str], operator.add] + ... + >>> from langgraph.constants import Send + >>> from langgraph.graph import END, START + >>> def continue_to_jokes(state: OverallState): + ... return [Send("generate_joke", {"subject": s}) for s in state['subjects']] + ... + >>> from langgraph.graph import StateGraph + >>> builder = StateGraph(OverallState) + >>> builder.add_node("generate_joke", lambda state: {"jokes": [f"Joke about {state['subject']}"]}) + >>> builder.add_conditional_edges(START, continue_to_jokes) + >>> builder.add_edge("generate_joke", END) + >>> graph = builder.compile() + >>> + >>> # Invoking with two subjects results in a generated joke for each + >>> graph.invoke({"subjects": ["cats", "dogs"]}) + {'subjects': ['cats', 'dogs'], 'jokes': ['Joke about cats', 'Joke about dogs']} + """ + + __slots__ = ("node", "arg") + + node: str + arg: Any + + def __init__(self, /, node: str, arg: Any) -> None: + """ + Initialize a new instance of the Send class. + + Args: + node (str): The name of the target node to send the message to. + arg (Any): The state or message to send to the target node. 
+ """ + self.node = node + self.arg = arg + + def __hash__(self) -> int: + return hash((self.node, self.arg)) + + def __repr__(self) -> str: + return f"Send(node={self.node!r}, arg={self.arg!r})" + + def __eq__(self, value: object) -> bool: + return ( + isinstance(value, Send) + and self.node == value.node + and self.arg == value.arg + ) diff --git a/libs/langgraph/langgraph/pregel/utils.py b/libs/langgraph/langgraph/pregel/utils.py index c6dc064d3..3a29e5ed1 100644 --- a/libs/langgraph/langgraph/pregel/utils.py +++ b/libs/langgraph/langgraph/pregel/utils.py @@ -4,7 +4,7 @@ def get_new_channel_versions( previous_versions: ChannelVersions, current_versions: ChannelVersions ) -> ChannelVersions: - """Get new channel versions.""" + """Get subset of current_versions that are newer than previous_versions.""" if previous_versions: version_type = type(next(iter(current_versions.values()), None)) null_version = version_type() # type: ignore[misc] diff --git a/libs/langgraph/langgraph/pregel/validate.py b/libs/langgraph/langgraph/pregel/validate.py index 232014240..965fab54e 100644 --- a/libs/langgraph/langgraph/pregel/validate.py +++ b/libs/langgraph/langgraph/pregel/validate.py @@ -17,7 +17,7 @@ def validate_graph( ) -> None: for chan in channels: if chan in RESERVED: - raise ValueError(f"Channel names {RESERVED} are reserved") + raise ValueError(f"Channel names {chan} are reserved") subscribed_channels = set[str]() for name, node in nodes.items(): diff --git a/libs/langgraph/langgraph/pregel/write.py b/libs/langgraph/langgraph/pregel/write.py index 2adcab757..c2795c67c 100644 --- a/libs/langgraph/langgraph/pregel/write.py +++ b/libs/langgraph/langgraph/pregel/write.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from typing import ( Any, Callable, @@ -22,30 +21,29 @@ TYPE_SEND = Callable[[Sequence[tuple[str, Any]]], None] R = TypeVar("R", bound=Runnable) - SKIP_WRITE = object() PASSTHROUGH = object() class ChannelWriteEntry(NamedTuple): channel: str + """Channel name to write to.""" value: Any = PASSTHROUGH + """Value to write, or PASSTHROUGH to use the input.""" skip_none: bool = False + """Whether to skip writing if the value is None.""" mapper: Optional[Callable] = None + """Function to transform the value before writing.""" class ChannelWrite(RunnableCallable): + """Implements th logic for sending writes to CONFIG_KEY_SEND. + Can be used as a runnable or as a static method to call imperatively.""" + writes: list[Union[ChannelWriteEntry, Send]] - """ - Sequence of write entries, each of which is a tuple of: - - channel name - - runnable to map input, or None to use the input, or any other value to use instead - - whether to skip writing if the mapped value is None - """ + """Sequence of write entries or Send objects to write.""" require_at_least_one_of: Optional[Sequence[str]] - """ - If defined, at least one of these channels must be written to. - """ + """If defined, at least one of these channels must be written to.""" def __init__( self, @@ -145,6 +143,7 @@ def do_write( @staticmethod def is_writer(runnable: Runnable) -> bool: + """Used by PregelNode to distinguish between writers and other runnables.""" return ( isinstance(runnable, ChannelWrite) or getattr(runnable, "_is_channel_writer", False) is True @@ -152,13 +151,9 @@ def is_writer(runnable: Runnable) -> bool: @staticmethod def register_writer(runnable: R) -> R: + """Used to mark a runnable as a writer, so that it can be detected by is_writer. 
+ Instances of ChannelWrite are automatically marked as writers.""" # using object.__setattr__ to work around objects that override __setattr__ # eg. pydantic models and dataclasses object.__setattr__(runnable, "_is_channel_writer", True) return runnable - - -def _mk_future(val: Any) -> asyncio.Future: - fut: asyncio.Future[Any] = asyncio.Future() - fut.set_result(val) - return fut diff --git a/libs/langgraph/langgraph/version.py b/libs/langgraph/langgraph/version.py index 3368893c0..f5cb757f5 100644 --- a/libs/langgraph/langgraph/version.py +++ b/libs/langgraph/langgraph/version.py @@ -1,4 +1,4 @@ -"""Main entrypoint into package.""" +"""Exports package version.""" from importlib import metadata diff --git a/libs/langgraph/tests/test_tracing_interops.py b/libs/langgraph/tests/test_tracing_interops.py index 9bc8750b5..f1f0ad7fb 100644 --- a/libs/langgraph/tests/test_tracing_interops.py +++ b/libs/langgraph/tests/test_tracing_interops.py @@ -5,11 +5,14 @@ from unittest.mock import MagicMock import langsmith as ls +import pytest from langchain_core.runnables import RunnableConfig from langchain_core.tracers import LangChainTracer from langgraph.graph import StateGraph +pytestmark = pytest.mark.anyio + def _get_mock_client(**kwargs: Any) -> ls.Client: mock_session = MagicMock() @@ -76,7 +79,7 @@ async def child_node(state: State) -> State: child_builder = StateGraph(State) child_builder.add_node(child_node) child_builder.add_edge("__start__", "child_node") - child_graph = child_builder.compile() + child_graph = child_builder.compile().with_config(run_name="child_graph") parent_builder = StateGraph(State) parent_builder.add_node(parent_node) @@ -101,7 +104,7 @@ def get_posts(): # If the callbacks weren't propagated correctly, we'd # end up with broken dotted_orders parent_run = next(data for data in posts if data["name"] == "parent_node") - child_run = next(data for data in posts if data["name"] == "child_node") + child_run = next(data for data in posts if data["name"] == "child_graph") traceable_run = next(data for data in posts if data["name"] == "some_traceable") assert child_run["dotted_order"].startswith(traceable_run["dotted_order"])
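Usage sketch (illustrative, not part of the diff above): the new StreamWriter and StreamMode docstrings describe stream_mode="custom", but the patch contains no end-to-end example. The following minimal sketch shows how those pieces are intended to fit together, assuming the writer is injected when requested as a keyword argument, as the StreamWriter docstring states; the state schema, node name, and streamed payload are made up for the example.

from typing import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.pregel.types import StreamWriter


class State(TypedDict):
    topic: str


def my_node(state: State, writer: StreamWriter) -> State:
    # `writer` is a no-op unless the graph is streamed with stream_mode="custom"
    writer({"progress": f"working on {state['topic']}"})
    return {"topic": state["topic"]}


builder = StateGraph(State)
builder.add_node("my_node", my_node)
builder.add_edge(START, "my_node")
builder.add_edge("my_node", END)
graph = builder.compile()

# Emits only what the node wrote via `writer`, e.g. {'progress': 'working on cats'}
for chunk in graph.stream({"topic": "cats"}, stream_mode="custom"):
    print(chunk)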