Skip to content

Commit

Permalink
Add Runnable.get_graph() to get a graph representation of a Runnable (l…
Browse files Browse the repository at this point in the history
…angchain-ai#15040)

It can be drawn in ascii with Runnable.get_graph().draw()
  • Loading branch information
nfcampos authored Dec 22, 2023
1 parent aad3d8b commit 7d5800e
Show file tree
Hide file tree
Showing 12 changed files with 739 additions and 27 deletions.
14 changes: 14 additions & 0 deletions libs/core/langchain_core/beta/runnables/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
List,
Mapping,
Optional,
Sequence,
Type,
TypeVar,
Union,
Expand Down Expand Up @@ -163,6 +164,9 @@ class ContextGet(RunnableSerializable):

key: Union[str, List[str]]

def __str__(self) -> str:
return f"ContextGet({_print_keys(self.key)})"

@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
Expand Down Expand Up @@ -243,6 +247,9 @@ def __init__(
prefix=prefix,
)

def __str__(self) -> str:
return f"ContextSet({_print_keys(list(self.keys.keys()))})"

@property
def ids(self) -> List[str]:
prefix = self.prefix + "/" if self.prefix else ""
Expand Down Expand Up @@ -345,3 +352,10 @@ def setter(
**kwargs: SetValue,
) -> ContextSet:
return ContextSet(_key, _value, prefix=self.prefix, **kwargs)


def _print_keys(keys: Union[str, Sequence[str]]) -> str:
if isinstance(keys, str):
return f"'{keys}'"
else:
return ", ".join(f"'{k}'" for k in keys)
58 changes: 58 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from langchain_core.runnables.fallbacks import (
RunnableWithFallbacks as RunnableWithFallbacksT,
)
from langchain_core.runnables.graph import Graph
from langchain_core.tracers.log_stream import RunLog, RunLogPatch
from langchain_core.tracers.root_listeners import Listener

Expand Down Expand Up @@ -352,6 +353,18 @@ class _Config:
},
)

def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
"""Return a graph representation of this runnable."""
from langchain_core.runnables.graph import Graph

graph = Graph()
input_node = graph.add_node(self.get_input_schema(config))
runnable_node = graph.add_node(self)
output_node = graph.add_node(self.get_output_schema(config))
graph.add_edge(input_node, runnable_node)
graph.add_edge(runnable_node, output_node)
return graph

def __or__(
self,
other: Union[
Expand Down Expand Up @@ -1447,6 +1460,26 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:

return get_unique_config_specs(spec for spec, _ in all_specs)

def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
from langchain_core.runnables.graph import Graph

graph = Graph()
for step in self.steps:
current_last_node = graph.last_node()
step_graph = step.get_graph(config)
if step is not self.first:
step_graph.trim_first_node()
if step is not self.last:
step_graph.trim_last_node()
graph.extend(step_graph)
step_first_node = step_graph.first_node()
if not step_first_node:
raise ValueError(f"Runnable {step} has no first node")
if current_last_node:
graph.add_edge(current_last_node, step_first_node)

return graph

def __repr__(self) -> str:
return "\n| ".join(
repr(s) if i == 0 else indent_lines_after_first(repr(s), "| ")
Expand Down Expand Up @@ -1992,6 +2025,31 @@ def config_specs(self) -> List[ConfigurableFieldSpec]:
spec for step in self.steps.values() for spec in step.config_specs
)

def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
from langchain_core.runnables.graph import Graph

graph = Graph()
input_node = graph.add_node(self.get_input_schema(config))
output_node = graph.add_node(self.get_output_schema(config))
for step in self.steps.values():
step_graph = step.get_graph()
step_graph.trim_first_node()
step_graph.trim_last_node()
if not step_graph:
graph.add_edge(input_node, output_node)
else:
graph.extend(step_graph)
step_first_node = step_graph.first_node()
if not step_first_node:
raise ValueError(f"Runnable {step} has no first node")
step_last_node = step_graph.last_node()
if not step_last_node:
raise ValueError(f"Runnable {step} has no last node")
graph.add_edge(input_node, step_first_node)
graph.add_edge(step_last_node, output_node)

return graph

def __repr__(self) -> str:
map_for_repr = ",\n ".join(
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
Expand Down
133 changes: 133 additions & 0 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable
from langchain_core.runnables.graph_draw import draw


class Edge(NamedTuple):
source: str
target: str


class Node(NamedTuple):
id: str
data: Union[Type[BaseModel], Runnable]


@dataclass
class Graph:
nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)

def __bool__(self) -> bool:
return bool(self.nodes)

def next_id(self) -> str:
return uuid4().hex

def add_node(self, data: Union[Type[BaseModel], Runnable]) -> Node:
"""Add a node to the graph and return it."""
node = Node(id=self.next_id(), data=data)
self.nodes[node.id] = node
return node

def remove_node(self, node: Node) -> None:
"""Remove a node from the graphm and all edges connected to it."""
self.nodes.pop(node.id)
self.edges = [
edge
for edge in self.edges
if edge.source != node.id and edge.target != node.id
]

def add_edge(self, source: Node, target: Node) -> Edge:
"""Add an edge to the graph and return it."""
if source.id not in self.nodes:
raise ValueError(f"Source node {source.id} not in graph")
if target.id not in self.nodes:
raise ValueError(f"Target node {target.id} not in graph")
edge = Edge(source=source.id, target=target.id)
self.edges.append(edge)
return edge

def extend(self, graph: Graph) -> None:
"""Add all nodes and edges from another graph.
Note this doesn't check for duplicates, nor does it connect the graphs."""
self.nodes.update(graph.nodes)
self.edges.extend(graph.edges)

def first_node(self) -> Optional[Node]:
"""Find the single node that is not a target of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the origin."""
targets = {edge.target for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in targets:
found.append(node)
return found[0] if len(found) == 1 else None

def last_node(self) -> Optional[Node]:
"""Find the single node that is not a source of any edge.
If there is no such node, or there are multiple, return None.
When drawing the graph this node would be the destination.
"""
sources = {edge.source for edge in self.edges}
found: List[Node] = []
for node in self.nodes.values():
if node.id not in sources:
found.append(node)
return found[0] if len(found) == 1 else None

def trim_first_node(self) -> None:
"""Remove the first node if it exists and has a single outgoing edge,
ie. if removing it would not leave the graph without a "first" node."""
first_node = self.first_node()
if first_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.source == first_node.id])
== 1
):
self.remove_node(first_node)

def trim_last_node(self) -> None:
"""Remove the last node if it exists and has a single incoming edge,
ie. if removing it would not leave the graph without a "last" node."""
last_node = self.last_node()
if last_node:
if (
len(self.nodes) == 1
or len([edge for edge in self.edges if edge.target == last_node.id])
== 1
):
self.remove_node(last_node)

def draw_ascii(self) -> str:
def node_data(node: Node) -> str:
if isinstance(node.data, Runnable):
try:
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 36:
data = data[:36] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data

return draw(
{node.id: node_data(node) for node in self.nodes.values()},
[(edge.source, edge.target) for edge in self.edges],
)
Loading

0 comments on commit 7d5800e

Please sign in to comment.