Skip to content

Commit

Permalink
Add 2 more benchmark graphs (#1707)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Sep 12, 2024
1 parent 11319ae commit 8a80b1d
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 10 deletions.
138 changes: 129 additions & 9 deletions libs/langgraph/bench/__main__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
import random
from typing import Optional
from uuid import uuid4

from langchain_core.messages import HumanMessage
from pyperf._runner import Runner
from uvloop import new_event_loop

from bench.fanout_to_subgraph import fanout_to_subgraph
from bench.react_agent import react_agent
from bench.wide_state import wide_state
from langgraph.checkpoint.memory import MemorySaver
from langgraph.pregel import Pregel


async def run(graph: Pregel, input: dict, config: Optional[dict]):
len([c async for c in graph.astream(input, config=config)])
async def run(graph: Pregel, input: dict):
len(
[
c
async for c in graph.astream(
input,
{
"configurable": {"thread_id": str(uuid4())},
"recursion_limit": 1000000000,
},
)
]
)


benchmarks = (
Expand All @@ -22,7 +36,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]):
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10)
]
},
None,
),
(
"fanout_to_subgraph_10x_checkpoint",
Expand All @@ -32,7 +45,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]):
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10)
]
},
{"configurable": {"thread_id": "1"}},
),
(
"fanout_to_subgraph_100x",
Expand All @@ -42,7 +54,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]):
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100)
]
},
None,
),
(
"fanout_to_subgraph_100x_checkpoint",
Expand All @@ -52,12 +63,121 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]):
random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100)
]
},
{"configurable": {"thread_id": "1"}},
),
(
"react_agent_10x",
react_agent(10, checkpointer=None),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_10x_checkpoint",
react_agent(10, checkpointer=MemorySaver()),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_100x",
react_agent(100, checkpointer=None),
{"messages": [HumanMessage("hi?")]},
),
(
"react_agent_100x_checkpoint",
react_agent(100, checkpointer=MemorySaver()),
{"messages": [HumanMessage("hi?")]},
),
(
"wide_state_25x300",
wide_state(300).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(5)
}
]
},
),
(
"wide_state_25x300_checkpoint",
wide_state(300).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(5)
}
]
},
),
(
"wide_state_15x600",
wide_state(600).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(3)
}
]
},
),
(
"wide_state_15x600_checkpoint",
wide_state(600).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(5)
}
for i in range(3)
}
]
},
),
(
"wide_state_9x1200",
wide_state(1200).compile(checkpointer=None),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(3)
}
for i in range(3)
}
]
},
),
(
"wide_state_9x1200_checkpoint",
wide_state(1200).compile(checkpointer=MemorySaver()),
{
"messages": [
{
str(i) * 10: {
str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5
for j in range(3)
}
for i in range(3)
}
]
},
),
)


r = Runner()

for name, graph, input, config in benchmarks:
r.bench_async_func(name, run, graph, input, config, loop_factory=new_event_loop)
for name, graph, input in benchmarks:
r.bench_async_func(name, run, graph, input, loop_factory=new_event_loop)
2 changes: 1 addition & 1 deletion libs/langgraph/bench/fanout_to_subgraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import operator
from typing import Annotated, TypedDict

Expand Down Expand Up @@ -55,6 +54,7 @@ async def bump_loop(state: JokeOutput):


if __name__ == "__main__":
import asyncio
import random

import uvloop
Expand Down
81 changes: 81 additions & 0 deletions libs/langgraph/bench/react_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Any, Optional
from uuid import uuid4

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.fake_chat_models import (
FakeMessagesListChatModel,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.tools import StructuredTool

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.prebuilt.chat_agent_executor import create_react_agent
from langgraph.pregel import Pregel


def react_agent(n_tools: int, checkpointer: BaseCheckpointSaver) -> Pregel:
class FakeFuntionChatModel(FakeMessagesListChatModel):
def bind_tools(self, functions: list):
return self

def _generate(
self,
messages: list[BaseMessage],
stop: Optional[list[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
response = self.responses[self.i].copy()
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
generation = ChatGeneration(message=response)
return ChatResult(generations=[generation])

tool = StructuredTool.from_function(
lambda query: f"result for query: {query}" * 10,
name=str(uuid4()),
description="",
)

model = FakeFuntionChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"id": str(uuid4()),
"name": tool.name,
"args": {"query": str(uuid4()) * 100},
}
],
id=str(uuid4()),
)
for _ in range(n_tools)
]
+ [
AIMessage(content="answer" * 100, id=str(uuid4())),
]
)

return create_react_agent(model, [tool], checkpointer=checkpointer)


if __name__ == "__main__":
import asyncio

import uvloop

from langgraph.checkpoint.memory import MemorySaver

graph = react_agent(100, checkpointer=MemorySaver())
input = {"messages": [HumanMessage("hi?")]}
config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000}

async def run():
len([c async for c in graph.astream(input, config=config)])

uvloop.install()
asyncio.run(run())
Loading

0 comments on commit 8a80b1d

Please sign in to comment.