Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvements in core library #1683

Merged
merged 13 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/_integration_test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CLI integration test

Check notice on line 1 in .github/workflows/_integration_test.yml

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

.........................................
fanout_to_subgraph_10x: Mean +- std dev: 60.9 ms +- 1.2 ms
.........................................
fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 104 ms +- 2 ms
.........................................
fanout_to_subgraph_100x: Mean +- std dev: 595 ms +- 18 ms
.........................................
fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 1.06 sec +- 0.04 sec

Check notice on line 1 in .github/workflows/_integration_test.yml

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+------------------------------------+--------------------+------------------------+
| Benchmark                          | benchmark-baseline | benchmark              |
+====================================+====================+========================+
| fanout_to_subgraph_10x             | 126 ms             | 60.9 ms: 2.07x faster  |
+------------------------------------+--------------------+------------------------+
| fanout_to_subgraph_100x            | 1.15 sec           | 595 ms: 1.94x faster   |
+------------------------------------+--------------------+------------------------+
| fanout_to_subgraph_10x_checkpoint  | 171 ms             | 104 ms: 1.64x faster   |
+------------------------------------+--------------------+------------------------+
| fanout_to_subgraph_100x_checkpoint | 1.63 sec           | 1.06 sec: 1.55x faster |
+------------------------------------+--------------------+------------------------+
| Geometric mean                     | (ref)              | 1.79x faster           |
+------------------------------------+--------------------+------------------------+

on:
workflow_call:
Expand All @@ -22,7 +22,7 @@
- uses: actions/checkout@v4
- name: Get changed files
id: changed-files
uses: Ana06/get-changed-files@v2.2.0
uses: Ana06/get-changed-files@v2.3.0
with:
filter: "libs/cli/**"
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
- uses: actions/checkout@v4
- name: Get changed files
id: changed-files
uses: Ana06/get-changed-files@v2.2.0
uses: Ana06/get-changed-files@v2.3.0
with:
filter: "${{ inputs.working-directory }}/**"
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
Expand Down
44 changes: 36 additions & 8 deletions .github/workflows/bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,25 @@ jobs:
- name: Install dependencies
run: poetry install --with dev
- name: Run benchmarks
run: make benchmark
run: OUTPUT=out/benchmark-baseline.json make -s benchmark
- name: Upload benchmark baseline
uses: actions/upload-artifact@v4
with:
name: benchmark-baseline.json
path: libs/langgraph/out/benchmark.json
compare:
name: benchmark-baseline
path: libs/langgraph/out/benchmark-baseline.json
benchmark:
runs-on: ubuntu-latest
defaults:
run:
working-directory: libs/langgraph
needs: [baseline]
steps:
- uses: actions/checkout@v4
- id: files
name: Get changed files
uses: Ana06/[email protected]
with:
format: json
- name: Set up Python 3.11 + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
Expand All @@ -50,13 +55,36 @@ jobs:
- name: Install dependencies
run: poetry install --with dev
- name: Run benchmarks
run: make benchmark
id: benchmark
run: |
{
echo 'OUTPUT<<EOF'
make -s benchmark
echo EOF
} >> "$GITHUB_OUTPUT"
- name: Download benchmark baseline
uses: actions/download-artifact@v4
with:
name: benchmark-baseline.json
path: libs/langgraph/out
merge-multiple: true
- name: Compare benchmarks
run: poetry run pyperf compare_to out/benchmark-baseline.json out/benchmark.json --table --group-by-speed >> $GITHUB_OUTPUT
id: compare
run: |
{
echo 'OUTPUT<<EOF'
poetry run pyperf compare_to out/benchmark-baseline.json out/benchmark.json --table --group-by-speed
echo EOF
} >> "$GITHUB_OUTPUT"
- name: Annotation
run: echo "::notice file=libs/langgraph/bench/__main__.py::$GITHUB_OUTPUT"
uses: actions/github-script@v7
with:
script: |
const file = JSON.parse(`${{ steps.files.outputs.added_modified_renamed }}`)[0]
core.notice(`${{ steps.benchmark.outputs.OUTPUT }}`, {
title: 'Benchmark results',
file,
})
core.notice(`${{ steps.compare.outputs.OUTPUT }}`, {
title: 'Comparison against main',
file,
})
9 changes: 4 additions & 5 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from langchain_core.runnables.base import Input, Output
from langchain_core.runnables.config import (
RunnableConfig,
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
)
Expand Down Expand Up @@ -86,6 +85,7 @@
from langgraph.pregel.write import ChannelWrite, ChannelWriteEntry
from langgraph.store.base import BaseStore
from langgraph.utils.config import (
ensure_config,
merge_configs,
patch_checkpoint_map,
patch_config,
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def output() -> Iterator:
else:
yield payload

config = ensure_config(merge_configs(self.config, config))
config = ensure_config(self.config, config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
None,
Expand Down Expand Up @@ -1337,7 +1337,7 @@ def output() -> Iterator:
else:
yield payload

config = ensure_config(merge_configs(self.config, config))
config = ensure_config(self.config, config)
callback_manager = get_async_callback_manager_for_config(config)
run_manager = await callback_manager.on_chain_start(
None,
Expand Down Expand Up @@ -1402,8 +1402,7 @@ def output() -> Iterator:
# channel updates from step N are only visible in step N+1
# channels are guaranteed to be immutable for the duration of the step,
# with channel updates applied only at the transition between steps
while await asyncio.to_thread(
loop.tick,
while loop.tick(
input_keys=self.input_channels,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
Expand Down
31 changes: 17 additions & 14 deletions libs/langgraph/langgraph/pregel/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
from langgraph.pregel.types import All, PregelExecutableTask, PregelTask
from langgraph.utils.config import merge_configs, patch_config

EMPTY_SEQ = tuple()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an idea of how helpful this was? I guess this is sort of an inner loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was more because the previous code was actually incorrect, as it could end up trying to evaluate `<str> in None`.



class WritesProtocol(Protocol):
name: str
Expand Down Expand Up @@ -78,7 +80,10 @@ def should_interrupt(
task
for task in tasks
if (
(not task.config or TAG_HIDDEN not in task.config.get("tags"))
(
not task.config
or TAG_HIDDEN not in task.config.get("tags", EMPTY_SEQ)
)
if interrupt_nodes == "*"
else task.name in interrupt_nodes
)
Expand Down Expand Up @@ -182,11 +187,10 @@ def apply_writes(
for chan in task.triggers
if chan not in RESERVED and chan in channels
}:
if channels[chan].consume():
if get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)
if channels[chan].consume() and get_next_version is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this just stylistic or does it actually improve perf?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No this was just style

checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)

# clear pending sends
if checkpoint["pending_sends"]:
Expand Down Expand Up @@ -216,8 +220,7 @@ def apply_writes(
updated_channels: set[str] = set()
for chan, vals in pending_writes_by_channel.items():
if chan in channels:
updated = channels[chan].update(vals)
if updated and get_next_version is not None:
if channels[chan].update(vals) and get_next_version is not None:
checkpoint["channel_versions"][chan] = get_next_version(
max_version, channels[chan]
)
Expand Down Expand Up @@ -370,6 +373,8 @@ def prepare_single_task(
proc = processes[packet.node]
if node := proc.node:
managed.replace_runtime_placeholders(step, packet.arg)
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
Expand All @@ -379,9 +384,7 @@ def prepare_single_task(
writes,
patch_config(
merge_configs(
config,
processes[packet.node].config,
{"metadata": metadata},
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=packet.node,
callbacks=(
Expand Down Expand Up @@ -478,6 +481,8 @@ def prepare_single_task(

if for_execution:
if node := proc.node:
if proc.metadata:
metadata.update(proc.metadata)
writes = deque()
task_checkpoint_ns = f"{checkpoint_ns}:{task_id}"
return PregelExecutableTask(
Expand All @@ -487,9 +492,7 @@ def prepare_single_task(
writes,
patch_config(
merge_configs(
config,
proc.config,
{"metadata": metadata},
config, {"metadata": metadata, "tags": proc.tags}
),
run_name=name,
callbacks=(
Expand Down
5 changes: 4 additions & 1 deletion libs/langgraph/langgraph/pregel/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __radd__(self, other: dict[str, Any]) -> "AddableUpdatesDict":
raise TypeError("AddableUpdatesDict does not support right-side addition")


EMPTY_SEQ = tuple()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just make a single empty_seq in an internal _constants file or something

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess so?



def map_output_updates(
output_channels: Union[str, Sequence[str]],
tasks: list[tuple[PregelExecutableTask, Sequence[tuple[str, Any]]]],
Expand All @@ -106,7 +109,7 @@ def map_output_updates(
output_tasks = [
(t, ww)
for t, ww in tasks
if (not t.config or TAG_HIDDEN not in t.config.get("tags"))
if (not t.config or TAG_HIDDEN not in t.config.get("tags", EMPTY_SEQ))
and ww[0][0] != ERROR
and ww[0][0] != INTERRUPT
]
Expand Down
7 changes: 6 additions & 1 deletion libs/langgraph/langgraph/pregel/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,9 @@ def _output_writes(
self, task_id: str, writes: Sequence[tuple[str, Any]], *, cached: bool = False
) -> None:
if task := self.tasks.get(task_id):
if task.config is not None and TAG_HIDDEN in task.config.get("tags"):
if task.config is not None and TAG_HIDDEN in task.config.get(
"tags", EMPTY_SEQ
):
return
if writes[0][0] != ERROR and writes[0][0] != INTERRUPT:
self._emit(
Expand Down Expand Up @@ -806,3 +808,6 @@ async def __aexit__(
return await asyncio.shield(
self.stack.__aexit__(exc_type, exc_value, traceback)
)


EMPTY_SEQ = tuple()
30 changes: 20 additions & 10 deletions libs/langgraph/langgraph/pregel/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ class PregelNode(Runnable):

retry_policy: Optional[RetryPolicy]

config: RunnableConfig
tags: Optional[Sequence[str]]

metadata: Optional[Mapping[str, Any]]

def __init__(
self,
Expand All @@ -133,17 +135,15 @@ def __init__(
metadata: Optional[Mapping[str, Any]] = None,
bound: Optional[Runnable[Any, Any]] = None,
retry_policy: Optional[RetryPolicy] = None,
config: Optional[RunnableConfig] = None,
) -> None:
self.channels = channels
self.triggers = list(triggers)
self.mapper = mapper
self.writers = writers or []
self.bound = bound if bound is not None else DEFAULT_BOUND
self.retry_policy = retry_policy
self.config = merge_configs(
config, {"tags": tags or [], "metadata": metadata or {}}
)
self.tags = tags
self.metadata = metadata

def copy(self, update: dict[str, Any]) -> PregelNode:
attrs = {**self.__dict__, **update}
Expand All @@ -162,7 +162,7 @@ def flat_writers(self) -> list[Runnable]:
# careful to not modify the original writers list or ChannelWrite
writers[-2] = ChannelWrite(
writes=writers[-2].writes + writers[-1].writes,
tags=writers[-2].config["tags"] if writers[-2].config else None,
tags=writers[-2].tags,
require_at_least_one_of=writers[-2].require_at_least_one_of,
)
writers.pop()
Expand Down Expand Up @@ -238,7 +238,11 @@ def invoke(
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self.bound.invoke(input, merge_configs(self.config, config), **kwargs)
return self.bound.invoke(
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

async def ainvoke(
self,
Expand All @@ -247,7 +251,9 @@ async def ainvoke(
**kwargs: Optional[Any],
) -> Output:
return await self.bound.ainvoke(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

def stream(
Expand All @@ -257,7 +263,9 @@ def stream(
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self.bound.stream(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
)

async def astream(
Expand All @@ -267,6 +275,8 @@ async def astream(
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input, merge_configs(self.config, config), **kwargs
input,
merge_configs({"metadata": self.metadata, "tags": self.tags}, config),
**kwargs,
):
yield item
Loading
Loading