Skip to content

Commit

Permalink
langgraph: add structured output to create_react_agent (#2848)
Browse files Browse the repository at this point in the history
```python
class WeatherResponse(BaseModel):
    """Respond to the user with this"""

    temperature: float = Field(description="The temperature in fahrenheit")
    wind_direction: str = Field(
        description="The direction of the wind in abbreviated form"
    )
    wind_speed: float = Field(description="The speed of the wind in mph")

@tool
def get_weather(city: Literal["nyc", "sf"]):
    """Use this to get weather information."""
    if city == "nyc":
        return "It is cloudy in NYC, with 5 mph winds in the North-East direction and a temperature of 70 degrees"
    elif city == "sf":
        return "It is 75 degrees and sunny in SF, with 3 mph winds in the South-East direction"
    else:
        raise AssertionError("Unknown city")

model = ChatOpenAI()
tools = [get_weather]
agent_with_structured_output = create_react_agent(model, tools, response_format=WeatherResponse)
agent_with_structured_output.invoke({"messages": [("user", "what's the weather in nyc?")]})
```

```pycon
{
    'messages': [...],
    'structured_response': WeatherResponse(temperature=70.0, wind_directon='NE', wind_speed=5.0)
}
```
  • Loading branch information
vbarda authored Jan 10, 2025
1 parent 35c3ba0 commit 10d46ac
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 13 deletions.
127 changes: 118 additions & 9 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast
from typing import (
Callable,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
cast,
)

from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
Expand All @@ -8,11 +17,12 @@
RunnableConfig,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing_extensions import Annotated, TypedDict

from langgraph._api.deprecation import deprecated_parameter
from langgraph.errors import ErrorCode, create_error_message
from langgraph.graph import StateGraph
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.message import add_messages
from langgraph.managed import IsLastStep, RemainingSteps
Expand All @@ -22,11 +32,14 @@
from langgraph.types import Checkpointer
from langgraph.utils.runnable import RunnableCallable

StructuredResponse = Union[dict, BaseModel]
StructuredResponseSchema = Union[dict, type[BaseModel]]


# We create the AgentState that we will pass around
# This simply involves a list of messages
# We want steps to return messages to append to the list
# So we annotate the messages attribute with operator.add
# So we annotate the messages attribute with `add_messages` reducer
class AgentState(TypedDict):
"""The state of the agent."""

Expand All @@ -36,6 +49,8 @@ class AgentState(TypedDict):

remaining_steps: RemainingSteps

structured_response: StructuredResponse


StateSchema = TypeVar("StateSchema", bound=AgentState)
StateSchemaType = Type[StateSchema]
Expand Down Expand Up @@ -162,6 +177,19 @@ def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> b
return False


def _get_model(model: LanguageModelLike) -> BaseChatModel:
"""Get the underlying model from a RunnableBinding or return the model itself."""
if isinstance(model, RunnableBinding):
model = model.bound

if not isinstance(model, BaseChatModel):
raise TypeError(
f"Expected `model` to be a ChatModel or RunnableBinding (e.g. model.bind_tools(...)), got {type(model)}"
)

return model


def _validate_chat_history(
messages: Sequence[BaseMessage],
) -> None:
Expand Down Expand Up @@ -201,6 +229,9 @@ def create_react_agent(
state_schema: Optional[StateSchemaType] = None,
messages_modifier: Optional[MessagesModifier] = None,
state_modifier: Optional[StateModifier] = None,
response_format: Optional[
Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]
] = None,
checkpointer: Optional[Checkpointer] = None,
store: Optional[BaseStore] = None,
interrupt_before: Optional[list[str]] = None,
Expand Down Expand Up @@ -236,6 +267,25 @@ def create_react_agent(
- str: This is converted to a SystemMessage and added to the beginning of the list of messages in state["messages"].
- Callable: This function should take in full graph state and the output is then passed to the language model.
- Runnable: This runnable should take in full graph state and the output is then passed to the language model.
response_format: An optional schema for the final agent output.
If provided, output will be formatted to match the given schema and returned in the 'structured_response' state key.
If not provided, `structured_response` will not be present in the output state.
Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class,
- or a Pydantic class.
- a tuple (prompt, schema), where schema is one of the above.
The prompt will be used together with the model that is being used to generate the structured response.
!!! Important
`response_format` requires the model to support `.with_structured_output`
!!! Note
The graph will make a separate call to the LLM to generate the structured response after the agent loop is finished.
This is not the only strategy to get structured responses, see more options in [this guide](https://langchain-ai.github.io/langgraph/how-tos/react-agent-structured-output/).
checkpointer: An optional checkpoint saver object. This is used for persisting
the state of the graph (e.g., as chat memory) for a single thread (e.g., a single conversation).
store: An optional store object. This is used for persisting data
Expand Down Expand Up @@ -527,9 +577,11 @@ class Agent,Tools otherClass
"""

if state_schema is not None:
if missing_keys := {"messages", "is_last_step"} - set(
state_schema.__annotations__
):
required_keys = {"messages", "remaining_steps"}
if response_format is not None:
required_keys.add("structured_response")

if missing_keys := required_keys - set(state_schema.__annotations__):
raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")

if isinstance(tools, ToolExecutor):
Expand Down Expand Up @@ -633,11 +685,54 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
# We return a list, because this will get added to the existing list
return {"messages": [response]}

def generate_structured_response(
state: AgentState, config: RunnableConfig
) -> AgentState:
# NOTE: we exclude the last message because there is enough information
# for the LLM to generate the structured response
messages = state["messages"][:-1]
structured_response_schema = response_format
if isinstance(response_format, tuple):
system_prompt, structured_response_schema = response_format
messages = [SystemMessage(content=system_prompt)] + list(messages)

model_with_structured_output = _get_model(model).with_structured_output(
cast(StructuredResponseSchema, structured_response_schema)
)
response = model_with_structured_output.invoke(messages, config)
return {"structured_response": response}

async def agenerate_structured_response(
state: AgentState, config: RunnableConfig
) -> AgentState:
# NOTE: we exclude the last message because there is enough information
# for the LLM to generate the structured response
messages = state["messages"][:-1]
structured_response_schema = response_format
if isinstance(response_format, tuple):
system_prompt, structured_response_schema = response_format
messages = [SystemMessage(content=system_prompt)] + list(messages)

model_with_structured_output = _get_model(model).with_structured_output(
cast(StructuredResponseSchema, structured_response_schema)
)
response = await model_with_structured_output.ainvoke(messages, config)
return {"structured_response": response}

if not tool_calling_enabled:
# Define a new graph
workflow = StateGraph(state_schema or AgentState)
workflow.add_node("agent", RunnableCallable(call_model, acall_model))
workflow.set_entry_point("agent")
if response_format is not None:
workflow.add_node(
"generate_structured_response",
RunnableCallable(
generate_structured_response, agenerate_structured_response
),
)
workflow.add_edge("agent", "generate_structured_response")

return workflow.compile(
checkpointer=checkpointer,
store=store,
Expand All @@ -647,12 +742,12 @@ async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
def should_continue(state: AgentState) -> str:
messages = state["messages"]
last_message = messages[-1]
# If there is no function call, then we finish
if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
return "__end__"
return END if response_format is None else "generate_structured_response"
# Otherwise if there is, we continue
else:
return "tools"
Expand All @@ -668,21 +763,35 @@ def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
# This means that this node is the first one called
workflow.set_entry_point("agent")

# Add a structured output node if response_format is provided
if response_format is not None:
workflow.add_node(
"generate_structured_response",
RunnableCallable(
generate_structured_response, agenerate_structured_response
),
)
workflow.add_edge("generate_structured_response", END)
should_continue_destinations = ["tools", "generate_structured_response"]
else:
should_continue_destinations = ["tools", END]

# We now add a conditional edge
workflow.add_conditional_edges(
# First, we define the start node. We use `agent`.
# This means these are the edges taken after the `agent` node is called.
"agent",
# Next, we pass in the function that will determine which node is called next.
should_continue,
path_map=should_continue_destinations,
)

def route_tool_responses(state: AgentState) -> Literal["agent", "__end__"]:
for m in reversed(state["messages"]):
if not isinstance(m, ToolMessage):
break
if m.name in should_return_direct:
return "__end__"
return END
return "agent"

if should_return_direct:
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/tests/__snapshots__/test_large_cases.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -2832,10 +2832,10 @@
'''
# ---
# name: test_prebuilt_tool_chat
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}}, "required": ["messages"], "title": "LangGraphInput", "type": "object"}'
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}, "BaseModel": {"properties": {}, "title": "BaseModel", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}, "structured_response": {"anyOf": [{"type": "object"}, {"$ref": "#/$defs/BaseModel"}], "title": "Structured Response"}}, "required": ["messages", "structured_response"], "title": "LangGraphInput", "type": "object"}'
# ---
# name: test_prebuilt_tool_chat.1
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}}, "required": ["messages"], "title": "LangGraphOutput", "type": "object"}'
'{"$defs": {"BaseMessage": {"additionalProperties": true, "description": "Base abstract message class.\\n\\nMessages are the inputs and outputs of ChatModels.", "properties": {"content": {"anyOf": [{"type": "string"}, {"items": {"anyOf": [{"type": "string"}, {"type": "object"}]}, "type": "array"}], "title": "Content"}, "additional_kwargs": {"title": "Additional Kwargs", "type": "object"}, "response_metadata": {"title": "Response Metadata", "type": "object"}, "type": {"title": "Type", "type": "string"}, "name": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Name"}, "id": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": null, "title": "Id"}}, "required": ["content", "type"], "title": "BaseMessage", "type": "object"}, "BaseModel": {"properties": {}, "title": "BaseModel", "type": "object"}}, "properties": {"messages": {"items": {"$ref": "#/$defs/BaseMessage"}, "title": "Messages", "type": "array"}, "structured_response": {"anyOf": [{"type": "object"}, {"$ref": "#/$defs/BaseModel"}], "title": "Structured Response"}}, "required": ["messages", "structured_response"], "title": "LangGraphOutput", "type": "object"}'
# ---
# name: test_prebuilt_tool_chat.2
'''
Expand Down
45 changes: 43 additions & 2 deletions libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.tools import BaseTool, ToolException
from langchain_core.tools import tool as dec_tool
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic.v1 import ValidationError as ValidationErrorV1
from typing_extensions import TypedDict
Expand All @@ -47,7 +47,11 @@
create_react_agent,
tools_condition,
)
from langgraph.prebuilt.chat_agent_executor import AgentState, _validate_chat_history
from langgraph.prebuilt.chat_agent_executor import (
AgentState,
StructuredResponse,
_validate_chat_history,
)
from langgraph.prebuilt.tool_node import (
TOOL_CALL_ERROR_TEMPLATE,
InjectedState,
Expand All @@ -71,6 +75,7 @@

class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
structured_response: Optional[StructuredResponse] = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"

Expand Down Expand Up @@ -98,6 +103,14 @@ def _generate(
def _llm_type(self) -> str:
return "fake-tool-call-model"

def with_structured_output(
self, schema: Type[BaseModel]
) -> Runnable[LanguageModelInput, StructuredResponse]:
if self.structured_response is None:
raise ValueError("Structured response is not set")

return RunnableLambda(lambda x: self.structured_response)

def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
Expand Down Expand Up @@ -511,6 +524,34 @@ def handler(e: Union[str, int]):
_infer_handled_types(handler)


@pytest.mark.skipif(
not IS_LANGCHAIN_CORE_030_OR_GREATER,
reason="Pydantic v1 is required for this test to pass in langchain-core < 0.3",
)
def test_react_agent_with_structured_response() -> None:
class WeatherResponse(BaseModel):
temperature: float = Field(description="The temperature in fahrenheit")

tool_calls = [[{"args": {}, "id": "1", "name": "get_weather"}], []]

def get_weather():
"""Get the weather"""
return "The weather is sunny and 75°F."

expected_structured_response = WeatherResponse(temperature=75)
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=expected_structured_response
)
for response_format in (WeatherResponse, ("Meow", WeatherResponse)):
agent = create_react_agent(
model, [get_weather], response_format=response_format
)
response = agent.invoke({"messages": [HumanMessage("What's the weather?")]})
assert response["structured_response"] == expected_structured_response
assert len(response["messages"]) == 4
assert response["messages"][-2].content == "The weather is sunny and 75°F."


# tools for testing Too
def tool1(some_val: int, some_other_val: str) -> str:
"""Tool 1 docstring."""
Expand Down

0 comments on commit 10d46ac

Please sign in to comment.