Skip to content

Commit

Permalink
Remove runtime value substitution
Browse files Browse the repository at this point in the history
- doesn't work as expected
  • Loading branch information
nfcampos committed Sep 22, 2024
1 parent 84a45d1 commit d6015e7
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 111 deletions.
3 changes: 0 additions & 3 deletions libs/langgraph/langgraph/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@
# denotes push-style tasks, ie. those created by Send objects
PULL = "__pregel_pull"
# denotes pull-style tasks, ie. those triggered by edges
RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"
# placeholder for managed values replaced at runtime
NS_SEP = "|"
# for checkpoint_ns, separates each level (ie. graph|subgraph|subsubgraph)
NS_END = ":"
Expand Down Expand Up @@ -103,7 +101,6 @@
# other constants
PUSH,
PULL,
RUNTIME_PLACEHOLDER,
NS_SEP,
NS_END,
}
48 changes: 1 addition & 47 deletions libs/langgraph/langgraph/managed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@
from langchain_core.runnables import RunnableConfig
from typing_extensions import Self, TypeGuard

from langgraph.constants import RUNTIME_PLACEHOLDER

V = TypeVar("V")
U = TypeVar("U")


class ManagedValue(ABC, Generic[V]):
runtime: bool = False
"""Whether the managed value is always created at runtime, ie. never stored."""

def __init__(self, config: RunnableConfig) -> None:
self.config = config

Expand Down Expand Up @@ -105,45 +100,4 @@ def is_writable_managed_value(value: Any) -> TypeGuard[Type[WritableManagedValue
ChannelTypePlaceholder = object()


class ManagedValueMapping(dict[str, ManagedValue]):
def replace_runtime_values(
self, step: int, values: Union[dict[str, Any], Any]
) -> None:
if not self or not values:
return
if all(not mv.runtime for mv in self.values()):
return
if isinstance(values, dict):
for key, value in values.items():
for chan, mv in self.items():
if mv.runtime and mv(step) is value:
values[key] = {RUNTIME_PLACEHOLDER: chan}
elif hasattr(values, "__dir__") and callable(values.__dir__):
for key in dir(values):
try:
value = getattr(values, key)
for chan, mv in self.items():
if mv.runtime and mv(step) is value:
setattr(values, key, {RUNTIME_PLACEHOLDER: chan})
except AttributeError:
pass

def replace_runtime_placeholders(
self, step: int, values: Union[dict[str, Any], Any]
) -> None:
if not self or not values:
return
if all(not mv.runtime for mv in self.values()):
return
if isinstance(values, dict):
for key, value in values.items():
if isinstance(value, dict) and RUNTIME_PLACEHOLDER in value:
values[key] = self[value[RUNTIME_PLACEHOLDER]](step)
elif hasattr(values, "__dir__") and callable(values.__dir__):
for key in dir(values):
try:
value = getattr(values, key)
if isinstance(value, dict) and RUNTIME_PLACEHOLDER in value:
setattr(values, key, self[value[RUNTIME_PLACEHOLDER]](step))
except AttributeError:
pass
ManagedValueMapping = dict[str, ManagedValue]
10 changes: 2 additions & 8 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,11 +852,8 @@ def update_state(
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
step + 1,
writes.extend,
self.nodes,
channels,
managed,
self.nodes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
Expand Down Expand Up @@ -1001,11 +998,8 @@ async def aupdate_state(
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
step + 1,
writes.extend,
self.nodes,
channels,
managed,
self.nodes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
Expand Down
26 changes: 5 additions & 21 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,27 +155,18 @@ def local_read(


def local_write(
step: int,
commit: Callable[[Sequence[tuple[str, Any]]], None],
processes: Mapping[str, PregelNode],
channels: Mapping[str, BaseChannel],
managed: ManagedValueMapping,
process_keys: Iterable[str],
writes: Sequence[tuple[str, Any]],
) -> None:
"""Function injected under CONFIG_KEY_SEND in task config, to write to channels.
Validates writes and forwards them to `commit` function."""
for chan, value in writes:
if chan == TASKS:
if not isinstance(value, Send):
raise InvalidUpdateError(
f"Invalid packet type, expected Packet, got {value}"
)
if value.node not in processes:
raise InvalidUpdateError(f"Expected Send, got {value}")
if value.node not in process_keys:
raise InvalidUpdateError(f"Invalid node name {value.node} in packet")
# replace any runtime values with placeholders
managed.replace_runtime_values(step, value.arg)
elif chan not in channels and chan not in managed:
logger.warning(f"Skipping write for channel '{chan}' which has no readers")
commit(writes)


Expand Down Expand Up @@ -411,7 +402,6 @@ def prepare_single_task(
if for_execution:
proc = processes[packet.node]
if node := proc.node:
managed.replace_runtime_placeholders(step, packet.arg)
if proc.metadata:
metadata.update(proc.metadata)
writes: deque[tuple[str, Any]] = deque()
Expand All @@ -433,11 +423,8 @@ def prepare_single_task(
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
step,
writes.extend,
processes,
channels,
managed,
processes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
Expand Down Expand Up @@ -543,11 +530,8 @@ def prepare_single_task(
# deque.extend is thread-safe
CONFIG_KEY_SEND: partial(
local_write,
step,
writes.extend,
processes,
channels,
managed,
processes.keys(),
),
CONFIG_KEY_READ: partial(
local_read,
Expand Down
19 changes: 5 additions & 14 deletions libs/langgraph/tests/test_pregel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4440,25 +4440,16 @@ def should_continue(data: AgentState) -> str:
), "nodes can pass extra data to their cond edges, which isn't saved in state"
# Logic to decide whether to continue in the loop or exit
if tool_calls := data["messages"][-1].tool_calls:
return [
Send("tools", {"call": tool_call, "my_session": data["session"]})
for tool_call in tool_calls
]
return [Send("tools", tool_call) for tool_call in tool_calls]
else:
return END

class ToolInput(TypedDict):
call: ToolCall
my_session: httpx.Client

def tools_node(input: ToolInput, config: RunnableConfig) -> AgentState:
assert isinstance(input["my_session"], httpx.Client)
tool_call = input["call"]
time.sleep(tool_call["args"].get("idx", 0) / 10)
output = tools_by_name[tool_call["name"]].invoke(tool_call["args"], config)
def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
time.sleep(input["args"].get("idx", 0) / 10)
output = tools_by_name[input["name"]].invoke(input["args"], config)
return {
"messages": ToolMessage(
content=output, name=tool_call["name"], tool_call_id=tool_call["id"]
content=output, name=input["name"], tool_call_id=input["id"]
)
}

Expand Down
23 changes: 5 additions & 18 deletions libs/langgraph/tests/test_pregel_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4204,12 +4204,6 @@ def search_api(query: str) -> str:
]


# defined outside to allow deserializer to see it
class ToolInput(BaseModel, arbitrary_types_allowed=True):
call: ToolCall
my_session: httpx.AsyncClient


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_ASYNC)
async def test_state_graph_packets(checkpointer_name: str) -> None:
from langchain_core.language_models.fake_chat_models import (
Expand Down Expand Up @@ -4273,23 +4267,16 @@ def should_continue(data: AgentState) -> str:
assert isinstance(data["session"], httpx.AsyncClient)
# Logic to decide whether to continue in the loop or exit
if tool_calls := data["messages"][-1].tool_calls:
return [
Send("tools", ToolInput(call=tool_call, my_session=data["session"]))
for tool_call in tool_calls
]
return [Send("tools", tool_call) for tool_call in tool_calls]
else:
return END

async def tools_node(input: ToolInput, config: RunnableConfig) -> AgentState:
assert isinstance(input.my_session, httpx.AsyncClient)
tool_call = input.call
await asyncio.sleep(tool_call["args"].get("idx", 0) / 10)
output = await tools_by_name[tool_call["name"]].ainvoke(
tool_call["args"], config
)
async def tools_node(input: ToolCall, config: RunnableConfig) -> AgentState:
await asyncio.sleep(input["args"].get("idx", 0) / 10)
output = await tools_by_name[input["name"]].ainvoke(input["args"], config)
return {
"messages": ToolMessage(
content=output, name=tool_call["name"], tool_call_id=tool_call["id"]
content=output, name=input["name"], tool_call_id=input["id"]
)
}

Expand Down

0 comments on commit d6015e7

Please sign in to comment.