From 8a80b1d3b1e1e9455574fd3a7a918d9cb71a0d6c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 12 Sep 2024 16:31:16 -0700 Subject: [PATCH] Add 2 more benchmark graphs (#1707) --- libs/langgraph/bench/__main__.py | 138 ++++++++++++++++++-- libs/langgraph/bench/fanout_to_subgraph.py | 2 +- libs/langgraph/bench/react_agent.py | 81 ++++++++++++ libs/langgraph/bench/wide_state.py | 140 +++++++++++++++++++++ 4 files changed, 351 insertions(+), 10 deletions(-) create mode 100644 libs/langgraph/bench/react_agent.py create mode 100644 libs/langgraph/bench/wide_state.py diff --git a/libs/langgraph/bench/__main__.py b/libs/langgraph/bench/__main__.py index 45c9b3dfd..37c8718b6 100644 --- a/libs/langgraph/bench/__main__.py +++ b/libs/langgraph/bench/__main__.py @@ -1,16 +1,30 @@ import random -from typing import Optional +from uuid import uuid4 +from langchain_core.messages import HumanMessage from pyperf._runner import Runner from uvloop import new_event_loop from bench.fanout_to_subgraph import fanout_to_subgraph +from bench.react_agent import react_agent +from bench.wide_state import wide_state from langgraph.checkpoint.memory import MemorySaver from langgraph.pregel import Pregel -async def run(graph: Pregel, input: dict, config: Optional[dict]): - len([c async for c in graph.astream(input, config=config)]) +async def run(graph: Pregel, input: dict): + len( + [ + c + async for c in graph.astream( + input, + { + "configurable": {"thread_id": str(uuid4())}, + "recursion_limit": 1000000000, + }, + ) + ] + ) benchmarks = ( @@ -22,7 +36,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]): random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10) ] }, - None, ), ( "fanout_to_subgraph_10x_checkpoint", @@ -32,7 +45,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]): random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(10) ] }, - {"configurable": {"thread_id": "1"}}, ), ( "fanout_to_subgraph_100x", @@ -42,7 +54,6 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]): random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100) ] }, - None, ), ( "fanout_to_subgraph_100x_checkpoint", @@ -52,12 +63,121 @@ async def run(graph: Pregel, input: dict, config: Optional[dict]): random.choices("abcdefghijklmnopqrstuvwxyz", k=1000) for _ in range(100) ] }, - {"configurable": {"thread_id": "1"}}, + ), + ( + "react_agent_10x", + react_agent(10, checkpointer=None), + {"messages": [HumanMessage("hi?")]}, + ), + ( + "react_agent_10x_checkpoint", + react_agent(10, checkpointer=MemorySaver()), + {"messages": [HumanMessage("hi?")]}, + ), + ( + "react_agent_100x", + react_agent(100, checkpointer=None), + {"messages": [HumanMessage("hi?")]}, + ), + ( + "react_agent_100x_checkpoint", + react_agent(100, checkpointer=MemorySaver()), + {"messages": [HumanMessage("hi?")]}, + ), + ( + "wide_state_25x300", + wide_state(300).compile(checkpointer=None), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(5) + } + for i in range(5) + } + ] + }, + ), + ( + "wide_state_25x300_checkpoint", + wide_state(300).compile(checkpointer=MemorySaver()), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(5) + } + for i in range(5) + } + ] + }, + ), + ( + "wide_state_15x600", + wide_state(600).compile(checkpointer=None), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(5) + } + for i in range(3) + } + ] + }, + ), + ( + "wide_state_15x600_checkpoint", + wide_state(600).compile(checkpointer=MemorySaver()), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(5) + } + for i in range(3) + } + ] + }, + ), + ( + "wide_state_9x1200", + wide_state(1200).compile(checkpointer=None), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(3) + } + for i in range(3) + } + ] + }, + ), + ( + "wide_state_9x1200_checkpoint", + wide_state(1200).compile(checkpointer=MemorySaver()), + { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(3) + } + for i in range(3) + } + ] + }, ), ) r = Runner() -for name, graph, input, config in benchmarks: - r.bench_async_func(name, run, graph, input, config, loop_factory=new_event_loop) +for name, graph, input in benchmarks: + r.bench_async_func(name, run, graph, input, loop_factory=new_event_loop) diff --git a/libs/langgraph/bench/fanout_to_subgraph.py b/libs/langgraph/bench/fanout_to_subgraph.py index 53852609d..3c809a73b 100644 --- a/libs/langgraph/bench/fanout_to_subgraph.py +++ b/libs/langgraph/bench/fanout_to_subgraph.py @@ -1,4 +1,3 @@ -import asyncio import operator from typing import Annotated, TypedDict @@ -55,6 +54,7 @@ async def bump_loop(state: JokeOutput): if __name__ == "__main__": + import asyncio import random import uvloop diff --git a/libs/langgraph/bench/react_agent.py b/libs/langgraph/bench/react_agent.py new file mode 100644 index 000000000..4ad671f89 --- /dev/null +++ b/libs/langgraph/bench/react_agent.py @@ -0,0 +1,81 @@ +from typing import Any, Optional +from uuid import uuid4 + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.fake_chat_models import ( + FakeMessagesListChatModel, +) +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import StructuredTool + +from langgraph.checkpoint.base import BaseCheckpointSaver +from langgraph.prebuilt.chat_agent_executor import create_react_agent +from langgraph.pregel import Pregel + + +def react_agent(n_tools: int, checkpointer: BaseCheckpointSaver) -> Pregel: + class FakeFuntionChatModel(FakeMessagesListChatModel): + def bind_tools(self, functions: list): + return self + + def _generate( + self, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + response = self.responses[self.i].copy() + if self.i < len(self.responses) - 1: + self.i += 1 + else: + self.i = 0 + generation = ChatGeneration(message=response) + return ChatResult(generations=[generation]) + + tool = StructuredTool.from_function( + lambda query: f"result for query: {query}" * 10, + name=str(uuid4()), + description="", + ) + + model = FakeFuntionChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "id": str(uuid4()), + "name": tool.name, + "args": {"query": str(uuid4()) * 100}, + } + ], + id=str(uuid4()), + ) + for _ in range(n_tools) + ] + + [ + AIMessage(content="answer" * 100, id=str(uuid4())), + ] + ) + + return create_react_agent(model, [tool], checkpointer=checkpointer) + + +if __name__ == "__main__": + import asyncio + + import uvloop + + from langgraph.checkpoint.memory import MemorySaver + + graph = react_agent(100, checkpointer=MemorySaver()) + input = {"messages": [HumanMessage("hi?")]} + config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000} + + async def run(): + len([c async for c in graph.astream(input, config=config)]) + + uvloop.install() + asyncio.run(run()) diff --git a/libs/langgraph/bench/wide_state.py b/libs/langgraph/bench/wide_state.py new file mode 100644 index 000000000..8c51538fa --- /dev/null +++ b/libs/langgraph/bench/wide_state.py @@ -0,0 +1,140 @@ +import operator +from dataclasses import dataclass, field +from functools import partial +from typing import Annotated, Optional, Sequence + +from langgraph.constants import END, START +from langgraph.graph.state import StateGraph + + +def wide_state(n: int) -> StateGraph: + @dataclass(kw_only=True) + class State: + messages: Annotated[list, operator.add] = field(default_factory=list) + trigger_events: Annotated[list, operator.add] = field(default_factory=list) + """The external events that are converted by the graph.""" + primary_issue_medium: Annotated[str, lambda x, y: y or x] = field( + default="email" + ) + autoresponse: Annotated[Optional[dict], lambda _, y: y] = field( + default=None + ) # Always overwrite + issue: Annotated[dict | None, lambda x, y: y if y else x] = field(default=None) + relevant_rules: Optional[list[dict]] = field(default=None) + """SOPs fetched from the rulebook that are relevant to the current conversation.""" + memory_docs: Optional[list[dict]] = field(default=None) + """Memory docs fetched from the memory service that are relevant to the current conversation.""" + categorizations: Annotated[list[dict], operator.add] = field( + default_factory=list + ) + """The issue categorizations auto-generated by the AI.""" + responses: Annotated[list[dict], operator.add] = field(default_factory=list) + """The draft responses recommended by the AI.""" + + user_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = ( + field(default=None) + ) + """The current user state (by email).""" + crm_info: Annotated[Optional[dict], lambda x, y: y if y is not None else x] = ( + field(default=None) + ) + """The CRM information for organization the current user is from.""" + email_thread_id: Annotated[ + Optional[str], lambda x, y: y if y is not None else x + ] = field(default=None) + """The current email thread ID.""" + slack_participants: Annotated[dict, operator.or_] = field(default_factory=dict) + """The growing list of current slack participants.""" + bot_id: Optional[str] = field(default=None) + """The ID of the bot user in the slack channel.""" + notified_assignees: Annotated[dict, operator.or_] = field(default_factory=dict) + + def read_write(read: str, write: Sequence[str], input: State) -> dict: + val = getattr(input, read) + val_single = val[-1] if isinstance(val, list) else val + val_list = val if isinstance(val, list) else [val] + return { + k: val_list if isinstance(getattr(input, k), list) else val_single + for k in write + } + + builder = StateGraph(State) + builder.add_edge(START, "one") + builder.add_node( + "one", + partial(read_write, "messages", ["trigger_events", "primary_issue_medium"]), + ) + builder.add_edge("one", "two") + builder.add_node( + "two", + partial(read_write, "trigger_events", ["autoresponse", "issue"]), + ) + builder.add_edge("two", "three") + builder.add_edge("two", "four") + builder.add_node( + "three", + partial(read_write, "autoresponse", ["relevant_rules"]), + ) + builder.add_node( + "four", + partial( + read_write, + "trigger_events", + ["categorizations", "responses", "memory_docs"], + ), + ) + builder.add_node( + "five", + partial( + read_write, + "categorizations", + [ + "user_info", + "crm_info", + "email_thread_id", + "slack_participants", + "bot_id", + "notified_assignees", + ], + ), + ) + builder.add_edge(["three", "four"], "five") + builder.add_edge("five", "six") + builder.add_node( + "six", + partial(read_write, "responses", ["messages"]), + ) + builder.add_conditional_edges( + "six", lambda state: END if len(state.messages) > n else "one" + ) + + return builder + + +if __name__ == "__main__": + import asyncio + + import uvloop + + from langgraph.checkpoint.memory import MemorySaver + + graph = wide_state(1000).compile(checkpointer=MemorySaver()) + input = { + "messages": [ + { + str(i) * 10: { + str(j) * 10: ["hi?" * 10, True, 1, 6327816386138, None] * 5 + for j in range(5) + } + for i in range(5) + } + ] + } + config = {"configurable": {"thread_id": "1"}, "recursion_limit": 20000000000} + + async def run(): + async for c in graph.astream(input, config=config): + print(c.keys()) + + uvloop.install() + asyncio.run(run())