Skip to content

Commit

Permalink
Add more comments and docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Sep 22, 2024
1 parent 41537c5 commit 89fb8d5
Show file tree
Hide file tree
Showing 17 changed files with 274 additions and 164 deletions.
167 changes: 70 additions & 97 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,106 @@
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any, Literal, Mapping
from typing import Any, Mapping

from langgraph.pregel.types import Interrupt, Send # noqa: F401

# Interrupt, Send re-exported for backwards compatibility

# --- Empty read-only containers ---
EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})
EMPTY_SEQ: tuple[str, ...] = tuple()

# --- Reserved write keys ---
INPUT = "__input__"
# for values passed as input to the graph
INTERRUPT = "__interrupt__"
# for dynamic interrupts raised by nodes
ERROR = "__error__"
# for errors raised by nodes
NO_WRITES = "__no_writes__"
# marker to signal node didn't write anything
SCHEDULED = "__scheduled__"
# marker to signal node was scheduled (in distributed mode)
TASKS = "__pregel_tasks"
# for Send objects returned by nodes/edges, corresponds to PUSH below
START = "__start__"
# marker for the first (maybe virtual) node in graph-style Pregel
END = "__end__"
# marker for the last (maybe virtual) node in graph-style Pregel

# --- Reserved config.configurable keys ---
CONFIG_KEY_SEND = "__pregel_send"
# holds the `write` function that accepts writes to state/edges/reserved keys
CONFIG_KEY_READ = "__pregel_read"
# holds the `read` function that returns a copy of the current state
CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer"
# holds a `BaseCheckpointSaver` passed from parent graph to child graphs
CONFIG_KEY_STREAM = "__pregel_stream"
# holds a `StreamProtocol` passed from parent graph to child graphs
CONFIG_KEY_STREAM_WRITER = "__pregel_stream_writer"
# holds a `StreamWriter` for stream_mode=custom
CONFIG_KEY_STORE = "__pregel_store"
# holds a `BaseStore` made available to managed values
CONFIG_KEY_RESUMING = "__pregel_resuming"
# holds a boolean indicating if subgraphs should resume from a previous checkpoint
CONFIG_KEY_TASK_ID = "__pregel_task_id"
# holds the task ID for the current task
CONFIG_KEY_DEDUPE_TASKS = "__pregel_dedupe_tasks"
# holds a boolean indicating if tasks should be deduplicated (for distributed mode)
CONFIG_KEY_ENSURE_LATEST = "__pregel_ensure_latest"
# holds a boolean indicating whether to assert the requested checkpoint is the latest
# (for distributed mode)
CONFIG_KEY_DELEGATE = "__pregel_delegate"
# this one part of public API so more readable
# holds a boolean indicating whether to delegate subgraphs (for distributed mode)
CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map"
INTERRUPT = "__interrupt__"
ERROR = "__error__"
NO_WRITES = "__no_writes__"
SCHEDULED = "__scheduled__"
TASKS = "__pregel_tasks" # for backwards compat, this is the original name of PUSH
# holds a mapping of checkpoint_ns -> checkpoint_id for parent graphs
CONFIG_KEY_CHECKPOINT_ID = "checkpoint_id"
# holds the current checkpoint_id, if any
CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns"
# holds the current checkpoint_ns, "" for root graph

# --- Other constants ---
PUSH = "__pregel_push"
# denotes push-style tasks, ie. those created by Send objects
PULL = "__pregel_pull"
# denotes pull-style tasks, ie. those triggered by edges
RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"
# placeholder for managed values replaced at runtime
TAG_HIDDEN = "langsmith:hidden"
# tag to hide a node/edge from certain tracing/streaming environments
NS_SEP = "|"
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = ":"
# for checkpoint_ns, for each level, separates the namespace from the task_id

RESERVED = {
SCHEDULED,
# reserved write keys
INPUT,
INTERRUPT,
ERROR,
NO_WRITES,
SCHEDULED,
TASKS,
PUSH,
PULL,
# reserved config.configurable keys
CONFIG_KEY_SEND,
CONFIG_KEY_READ,
CONFIG_KEY_CHECKPOINTER,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_STREAM,
CONFIG_KEY_STREAM_WRITER,
CONFIG_KEY_STORE,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_RESUMING,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_DEDUPE_TASKS,
CONFIG_KEY_ENSURE_LATEST,
CONFIG_KEY_DELEGATE,
INPUT,
CONFIG_KEY_CHECKPOINT_MAP,
CONFIG_KEY_CHECKPOINT_ID,
CONFIG_KEY_CHECKPOINT_NS,
# other constants
PUSH,
PULL,
RUNTIME_PLACEHOLDER,
TAG_HIDDEN,
NS_SEP,
NS_END,
}
TAG_HIDDEN = "langsmith:hidden"

START = "__start__"
END = "__end__"

NS_SEP = "|"
NS_END = ":"

EMPTY_MAP: Mapping[str, Any] = MappingProxyType({})


class Send:
"""A message or packet to send to a specific node in the graph.
The `Send` class is used within a `StateGraph`'s conditional edges to
dynamically invoke a node with a custom state at the next step.
Importantly, the sent state can differ from the core graph's state,
allowing for flexible and dynamic workflow management.
One such example is a "map-reduce" workflow where your graph invokes
the same node multiple times in parallel with different states,
before aggregating the results back into the main graph's state.
Attributes:
node (str): The name of the target node to send the message to.
arg (Any): The state or message to send to the target node.
Examples:
>>> from typing import Annotated
>>> import operator
>>> class OverallState(TypedDict):
... subjects: list[str]
... jokes: Annotated[list[str], operator.add]
...
>>> from langgraph.constants import Send
>>> from langgraph.graph import END, START
>>> def continue_to_jokes(state: OverallState):
... return [Send("generate_joke", {"subject": s}) for s in state['subjects']]
...
>>> from langgraph.graph import StateGraph
>>> builder = StateGraph(OverallState)
>>> builder.add_node("generate_joke", lambda state: {"jokes": [f"Joke about {state['subject']}"]})
>>> builder.add_conditional_edges(START, continue_to_jokes)
>>> builder.add_edge("generate_joke", END)
>>> graph = builder.compile()
>>>
>>> # Invoking with two subjects results in a generated joke for each
>>> graph.invoke({"subjects": ["cats", "dogs"]})
{'subjects': ['cats', 'dogs'], 'jokes': ['Joke about cats', 'Joke about dogs']}
"""

node: str
arg: Any

def __init__(self, /, node: str, arg: Any) -> None:
"""
Initialize a new instance of the Send class.
Args:
node (str): The name of the target node to send the message to.
arg (Any): The state or message to send to the target node.
"""
self.node = node
self.arg = arg

def __hash__(self) -> int:
return hash((self.node, self.arg))

def __repr__(self) -> str:
return f"Send(node={self.node!r}, arg={self.arg!r})"

def __eq__(self, value: object) -> bool:
return (
isinstance(value, Send)
and self.node == value.node
and self.arg == value.arg
)


@dataclass
class Interrupt:
value: Any
when: Literal["during"] = "during"
25 changes: 9 additions & 16 deletions libs/langgraph/langgraph/errors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Sequence

from langgraph.checkpoint.base import EmptyChannelError
from langgraph.checkpoint.base import EmptyChannelError # noqa: F401
from langgraph.constants import Interrupt

# EmptyChannelError re-exported for backwards compatibility


class GraphRecursionError(RecursionError):
"""Raised when the graph has exhausted the maximum number of steps.
Expand All @@ -24,13 +26,14 @@ class GraphRecursionError(RecursionError):


class InvalidUpdateError(Exception):
"""Raised when attempting to update a channel with an invalid sequence of updates."""
"""Raised when attempting to update a channel with an invalid set of updates."""

pass


class GraphInterrupt(Exception):
"""Raised when a subgraph is interrupted."""
"""Raised when a subgraph is interrupted, supressed by the root graph.

Check failure on line 35 in libs/langgraph/langgraph/errors.py

View workflow job for this annotation

GitHub Actions / (Check for spelling errors)

supressed ==> suppressed
Never raised directly, or surfaced to the user."""

def __init__(self, interrupts: Sequence[Interrupt] = ()) -> None:
super().__init__(interrupts)
Expand All @@ -44,7 +47,7 @@ def __init__(self, value: Any) -> None:


class GraphDelegate(Exception):
"""Raised when a graph is delegated."""
"""Raised when a graph is delegated (for distributed mode)."""

def __init__(self, *args: dict[str, Any]) -> None:
super().__init__(*args)
Expand All @@ -57,22 +60,12 @@ class EmptyInputError(Exception):


class TaskNotFound(Exception):
"""Raised when the executor is unable to find a task."""
"""Raised when the executor is unable to find a task (for distributed mode)."""

pass


class CheckpointNotLatest(Exception):
"""Raised when the checkpoint is not the latest version."""
"""Raised when the checkpoint is not the latest version (for distributed mode)."""

pass


__all__ = [
"GraphRecursionError",
"InvalidUpdateError",
"GraphInterrupt",
"NodeInterrupt",
"EmptyInputError",
"EmptyChannelError",
]
25 changes: 23 additions & 2 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CONFIG_KEY_READ,
CONFIG_KEY_SEND,
CONFIG_KEY_TASK_ID,
EMPTY_SEQ,
INTERRUPT,
NO_WRITES,
NS_END,
Expand All @@ -55,10 +56,11 @@

GetNextVersion = Callable[[Optional[V], BaseChannel], V]

EMPTY_SEQ: tuple[str, ...] = tuple()


class WritesProtocol(Protocol):
"""Protocol for objects containing writes to be applied to checkpoint.
Implemented by PregelTaskWrites and PregelExecutableTask."""

@property
def name(self) -> str: ...

Expand All @@ -70,6 +72,9 @@ def triggers(self) -> Sequence[str]: ...


class PregelTaskWrites(NamedTuple):
"""Simplest implementation of WritesProtocol, for usage with writes that
don't originate from a runnable task, eg. graph input, update_state, etc."""

name: str
writes: Sequence[tuple[str, Any]]
triggers: Sequence[str]
Expand All @@ -80,6 +85,7 @@ def should_interrupt(
interrupt_nodes: Union[All, Sequence[str]],
tasks: Iterable[PregelExecutableTask],
) -> list[PregelExecutableTask]:
"""Check if the graph should be interrupted based on current state."""
version_type = type(next(iter(checkpoint["channel_versions"].values()), None))
null_version = version_type() # type: ignore[misc]
seen = checkpoint["versions_seen"].get(INTERRUPT, {})
Expand Down Expand Up @@ -117,6 +123,9 @@ def local_read(
select: Union[list[str], str],
fresh: bool = False,
) -> Union[dict[str, Any], Any]:
"""Function injected under CONFIG_KEY_READ in task config, to read current state.
Used by conditional edges to read a copy of the state with reflecting the writes
from that node only."""
if isinstance(select, str):
managed_keys = []
for c, _ in task.writes:
Expand Down Expand Up @@ -153,6 +162,8 @@ def local_write(
managed: ManagedValueMapping,
writes: Sequence[tuple[str, Any]],
) -> None:
"""Function injected under CONFIG_KEY_SEND in task config, to write to channels.
Validates writes and forwards them to `commit` function."""
for chan, value in writes:
if chan == TASKS:
if not isinstance(value, Send):
Expand All @@ -169,6 +180,7 @@ def local_write(


def increment(current: Optional[int], channel: BaseChannel) -> int:
"""Default channel versioning function, increments the current int version."""
return current + 1 if current is not None else 1


Expand All @@ -178,6 +190,9 @@ def apply_writes(
tasks: Iterable[WritesProtocol],
get_next_version: Optional[GetNextVersion],
) -> dict[str, list[Any]]:
"""Apply writes from a set of tasks (usually the tasks from a Pregel step)
to the checkpoint and channels, and return managed values writes to be applied
externally."""
# update seen versions
for task in tasks:
checkpoint["versions_seen"].setdefault(task.name, {}).update(
Expand Down Expand Up @@ -297,6 +312,9 @@ def prepare_next_tasks(
checkpointer: Optional[BaseCheckpointSaver] = None,
manager: Union[None, ParentRunManager, AsyncParentRunManager] = None,
) -> Union[dict[str, PregelTask], dict[str, PregelExecutableTask]]:
"""Prepare the set of tasks that will make up the next Pregel step.
This is the union of all PUSH tasks (Sends) and PULL tasks (nodes triggered
by edges)."""
tasks: dict[str, Union[PregelTask, PregelExecutableTask]] = {}
# Consume pending packets
for idx, _ in enumerate(checkpoint["pending_sends"]):
Expand Down Expand Up @@ -348,6 +366,8 @@ def prepare_single_task(
checkpointer: Optional[BaseCheckpointSaver] = None,
manager: Union[None, ParentRunManager, AsyncParentRunManager] = None,
) -> Union[None, PregelTask, PregelExecutableTask]:
"""Prepares a single task for the next Pregel step, given a task path, which
uniquely identifies a PUSH or PULL task within the graph."""
checkpoint_id = UUID(checkpoint["id"]).bytes
configurable = config.get("configurable", {})
parent_ns = configurable.get("checkpoint_ns", "")
Expand Down Expand Up @@ -568,6 +588,7 @@ def _proc_input(
*,
for_execution: bool,
) -> Iterator[Any]:
"""Prepare input for a PULL task, based on the process's channels and triggers."""
# If all trigger channels subscribed by this process are not empty
# then invoke the process with the values of all non-empty channels
if isinstance(proc.channels, dict):
Expand Down
Loading

0 comments on commit 89fb8d5

Please sign in to comment.