Skip to content

Commit

Permalink
Performance improvements in core library (#1683)
Browse files Browse the repository at this point in the history
* Performance improvements in core library

- Avoid creating new callback manager when received one as arg
- Avoid looking for config when already received one as arg
- Avoid copies of values in ensure_config/merge_configs
- Implement version of ensure_config that accepts multiple configs (avoids calling merge_configs first)
- Avoid calling merge_configs when we only need to attach extra tags/metadata

* Fix

* Fix

* Try again

* Debug ci job

* Fix

* Try again

* Try again

* Try again

* Some more variations

* Attach annotation to first changed file

* Fix

* Re-enable benchmarks
  • Loading branch information
nfcampos authored Sep 12, 2024
1 parent 2ef551f commit 8f87cd1
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/_integration_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- name: Get changed files
id: changed-files
uses: Ana06/get-changed-files@v2.2.0
uses: Ana06/get-changed-files@v2.3.0
with:
filter: "libs/cli/**"
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- uses: actions/checkout@v4
- name: Get changed files
id: changed-files
uses: Ana06/get-changed-files@v2.2.0
uses: Ana06/get-changed-files@v2.3.0
with:
filter: "${{ inputs.working-directory }}/**"
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand Down
44 changes: 36 additions & 8 deletions .github/workflows/bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,25 @@ jobs:
- name: Install dependencies
run: poetry install --with dev
- name: Run benchmarks
run: make benchmark
run: OUTPUT=out/benchmark-baseline.json make -s benchmark
- name: Upload benchmark baseline
uses: actions/upload-artifact@v4
with:
name: benchmark-baseline.json
path: libs/langgraph/out/benchmark.json
compare:
name: benchmark-baseline
path: libs/langgraph/out/benchmark-baseline.json
benchmark:
runs-on: ubuntu-latest
defaults:
run:
working-directory: libs/langgraph
needs: [baseline]
steps:
- uses: actions/checkout@v4
- id: files
name: Get changed files
uses: Ana06/[email protected]
with:
format: json
- name: Set up Python 3.11 + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
Expand All @@ -50,13 +55,36 @@ jobs:
- name: Install dependencies
run: poetry install --with dev
- name: Run benchmarks
run: make benchmark
id: benchmark
run: |
{
echo 'OUTPUT<<EOF'
make -s benchmark
echo EOF
} >> "$GITHUB_OUTPUT"
- name: Download benchmark baseline
uses: actions/download-artifact@v4
with:
name: benchmark-baseline.json
path: libs/langgraph/out
merge-multiple: true
- name: Compare benchmarks
run: poetry run pyperf compare_to out/benchmark-baseline.json out/benchmark.json --table --group-by-speed >> $GITHUB_OUTPUT
id: compare
run: |
{
echo 'OUTPUT<<EOF'
poetry run pyperf compare_to out/benchmark-baseline.json out/benchmark.json --table --group-by-speed
echo EOF
} >> "$GITHUB_OUTPUT"
- name: Annotation
run: echo "::notice file=libs/langgraph/bench/__main__.py::$GITHUB_OUTPUT"
uses: actions/github-script@v7
with:
script: |
const file = JSON.parse(`${{ steps.files.outputs.added_modified_renamed }}`)[0]
core.notice(`${{ steps.benchmark.outputs.OUTPUT }}`, {
title: 'Benchmark results',
file,
})
core.notice(`${{ steps.compare.outputs.OUTPUT }}`, {
title: 'Comparison against main',
file,
})
9 changes: 4 additions & 5 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from langchain_core.runnables.base import Input, Output
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
Expand Down Expand Up @@ -86,6 +85,7 @@
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.config import (
ensure_config,
merge_configs,
patch_checkpoint_map,
patch_config,
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def output() -> Iterator:
else:
yield payload

config = ensure_config(merge_configs(self.config, config))
config = ensure_config(self.config, config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
None,
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def output() -> Iterator:
else:
yield payload

config = ensure_config(merge_configs(self.config, config))
config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
None,
Expand Down Expand Up @@ -1402,8 +1402,7 @@ def output() -> Iterator:
# channel updates from step N are only visible in step N+1
# channels are guaranteed to be immutable for the duration of the step,
# with channel updates applied only at the transition between steps
while await asyncio.to_thread(
loop.tick,
while loop.tick(
input_keys=self.input_channels,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
Expand Down
31 changes: 17 additions & 14 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from langgraph.pregel.types import All, PregelExecutableTask, PregelTask
from langgraph.utils.config import merge_configs, patch_config

EMPTY_SEQ = tuple()


class WritesProtocol(Protocol):
name: str
Expand Down Expand Up @@ -78,7 +80,10 @@ def should_interrupt(
task
for task in tasks
if (
(not task.config or TAG_HIDDEN not in task.config.get("tags"))
(
not task.config
or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ)
)
if interrupt_nodes == "*"
else task.name in interrupt_nodes
)
Expand Down Expand Up @@ -182,11 +187,10 @@ def apply_writes(
for chan in task.triggers
if chan not in RESERVED and chan in channels
}:
if channels[chan].consume():
if get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)
if channels[chan].consume() and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)

# clear pending sends
if checkpoint["pending_sends"]:
Expand Down Expand Up @@ -216,8 +220,7 @@ def apply_writes(
updated_channels: set[str] = set()
for chan, vals in pending_writes_by_channel.items():
if chan in channels:
updated = channels[chan].update(vals)
if updated and get_next_version is not None:
if channels[chan].update(vals) and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)
Expand Down Expand Up @@ -370,6 +373,8 @@ def prepare_single_task(
proc = processes[packet.node]
if node := proc.node:
managed.replace_runtime_placeholders(step, packet.arg)
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
Expand All @@ -379,9 +384,7 @@ def prepare_single_task(
writes,
patch_config(
merge_configs(
config,
processes[packet.node].config,
{"metadata": metadata},
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=packet.node,
callbacks=(
Expand Down Expand Up @@ -478,6 +481,8 @@ def prepare_single_task(

if for_execution:
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
Expand All @@ -487,9 +492,7 @@ def prepare_single_task(
writes,
patch_config(
merge_configs(
config,
proc.config,
{"metadata": metadata},
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=name,
callbacks=(
Expand Down
5 changes: 4 additions & 1 deletion libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict":
raise TypeError("AddableUpdatesDict does not support right-side addition")


EMPTY_SEQ = tuple()


def map_output_updates(
output_channels: Union[str, Sequence[str]],
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
Expand All @@ -106,7 +109,7 @@ def map_output_updates(
output_tasks = [
(t, ww)
for t, ww in tasks
if (not t.config or TAG_HIDDEN not in t.config.get("tags"))
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
and ww[0][0] != ERROR
and ww[0][0] != INTERRUPT
]
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ def _output_writes(
self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False
) -> None:
if task := self.tasks.get(task_id):
if task.config is not None and TAG_HIDDEN in task.config.get("tags"):
if task.config is not None and TAG_HIDDEN in task.config.get(
"tags", EMPTY_SEQ
):
return
if writes[0][0] != ERROR and writes[0][0] != INTERRUPT:
self._emit(
Expand Down Expand Up @@ -806,3 +808,6 @@ async def __aexit__(
return await asyncio.shield(
self.stack.__aexit__(exc_type, exc_value, traceback)
)


EMPTY_SEQ = tuple()
30 changes: 20 additions & 10 deletions libs/langgraph/langgraph/pregel/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class PregelNode(Runnable):

retry_policy: Optional[RetryPolicy]

config: RunnableConfig
tags: Optional[Sequence[str]]

metadata: Optional[Mapping[str, Any]]

def __init__(
self,
Expand All @@ -133,17 +135,15 @@ def __init__(
metadata: Optional[Mapping[str, Any]] = None,
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
config: Optional[RunnableConfig] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
self.mapper = mapper
self.writers = writers or []
self.bound = bound if bound is not None else DEFAULT_BOUND
self.retry_policy = retry_policy
self.config = merge_configs(
config, {"tags": tags or [], "metadata": metadata or {}}
)
self.tags = tags
self.metadata = metadata

def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
Expand All @@ -162,7 +162,7 @@ def flat_writers(self) -> list[Runnable]:
# careful to not modify the original writers list or ChannelWrite
writers[-2] = ChannelWrite(
writes=writers[-2].writes + writers[-1].writes,
tags=writers[-2].config["tags"] if writers[-2].config else None,
tags=writers[-2].tags,
require_at_least_one_of=writers[-2].require_at_least_one_of,
)
writers.pop()
Expand Down Expand Up @@ -238,7 +238,11 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self.bound.invoke(input, merge_configs(self.config, config), **kwargs)
return self.bound.invoke(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

async def ainvoke(
self,
Expand All @@ -247,7 +251,9 @@ async def ainvoke(
**kwargs: Optional[Any],
) -> Output:
return await self.bound.ainvoke(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

def stream(
Expand All @@ -257,7 +263,9 @@ def stream(
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self.bound.stream(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

async def astream(
Expand All @@ -267,6 +275,8 @@ async def astream(
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
):
yield item
Loading

0 comments on commit 8f87cd1

Please sign in to comment.