Skip to content

Commit

Permalink
core[patch]: Fix runnable with message history (#14629)
Browse files Browse the repository at this point in the history
Fix bug shown in #14458. Namely, that saving inputs to history fails
when the input to base runnable is a list of messages
  • Loading branch information
baskaryan authored Dec 13, 2023
1 parent 9974353 commit 4745195
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 40 deletions.
71 changes: 35 additions & 36 deletions libs/core/langchain_core/runnables/history.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from __future__ import annotations

import asyncio
import inspect
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Type,
Union,
)

from langchain_core.chat_history import BaseChatMessageHistory
Expand All @@ -25,8 +29,6 @@
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tracers.schemas import Run

import inspect
from typing import Callable, Dict, Union

MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], Dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
Expand All @@ -35,13 +37,9 @@
class RunnableWithMessageHistory(RunnableBindingBase):
"""A runnable that manages chat message history for another runnable.
Base runnable must have inputs and outputs that can be converted to a list of
BaseMessages.
Base runnable must have inputs and outputs that can be converted to a list of BaseMessages.
RunnableWithMessageHistory must always be called with a config that contains
session_id, e.g.:
``{"configurable": {"session_id": "<SESSION_ID>"}}`
RunnableWithMessageHistory must always be called with a config that contains session_id, e.g. ``{"configurable": {"session_id": "<SESSION_ID>"}}`.
Example (dict input):
.. code-block:: python
Expand Down Expand Up @@ -82,9 +80,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
# -> "The inverse of cosine is called arccosine ..."
Here's an example that uses an in memory chat history, and a factory that
takes in two keys (user_id and conversation id) to create a chat history instance.
Example (get_session_history takes two keys, user_id and conversation id):
.. code-block:: python
store = {}
Expand Down Expand Up @@ -164,46 +160,43 @@ def __init__(
"""Initialize RunnableWithMessageHistory.
Args:
runnable: The base Runnable to be wrapped.
Must take as input one of:
- A sequence of BaseMessages
- A dict with one key for all messages
- A dict with one key for the current input string/message(s) and
runnable: The base Runnable to be wrapped. Must take as input one of:
1. A sequence of BaseMessages
2. A dict with one key for all messages
3. A dict with one key for the current input string/message(s) and
a separate key for historical messages. If the input key points
to a string, it will be treated as a HumanMessage in history.
Must return as output one of:
- A string which can be treated as an AIMessage
- A BaseMessage or sequence of BaseMessages
- A dict with a key for a BaseMessage or sequence of BaseMessages
1. A string which can be treated as an AIMessage
2. A BaseMessage or sequence of BaseMessages
3. A dict with a key for a BaseMessage or sequence of BaseMessages
get_session_history: Function that returns a new BaseChatMessageHistory.
This function should either take a single positional argument
`session_id` of type string and return a corresponding
chat message history instance.
.. code-block:: python
```python
def get_session_history(
session_id: str,
*,
user_id: Optional[str]=None
) -> BaseChatMessageHistory:
...
```
def get_session_history(
session_id: str,
*,
user_id: Optional[str]=None
) -> BaseChatMessageHistory:
...
Or it should take keyword arguments that match the keys of
`session_history_config_specs` and return a corresponding
chat message history instance.
```python
def get_session_history(
*,
user_id: str,
thread_id: str,
) -> BaseChatMessageHistory:
...
```
.. code-block:: python
def get_session_history(
*,
user_id: str,
thread_id: str,
) -> BaseChatMessageHistory:
...
input_messages_key: Must be specified if the base runnable accepts a dict
as input.
Expand Down Expand Up @@ -350,6 +343,12 @@ def _exit_history(self, run: Run, config: RunnableConfig) -> None:
input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val)

# If historic messages were prepended to the input messages, remove them to
# avoid adding duplicate messages to history.
if not self.history_messages_key:
historic_messages = config["configurable"]["message_history"].messages
input_messages = input_messages[len(historic_messages) :]

# Get the output messages
output_val = load(run.outputs)
output_messages = self._get_output_messages(output_val)
Expand Down
21 changes: 17 additions & 4 deletions libs/core/tests/unit_tests/runnables/test_history.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import BaseModel
Expand All @@ -18,8 +18,10 @@ def test_interfaces() -> None:
assert str(history) == "System: system\nHuman: human 1\nAI: ai\nHuman: human 2"


def _get_get_session_history() -> Callable[..., ChatMessageHistory]:
chat_history_store = {}
def _get_get_session_history(
*, store: Optional[Dict[str, Any]] = None
) -> Callable[..., ChatMessageHistory]:
chat_history_store = store if store is not None else {}

def get_session_history(session_id: str, **kwargs: Any) -> ChatMessageHistory:
if session_id not in chat_history_store:
Expand All @@ -34,13 +36,24 @@ def test_input_messages() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
get_session_history = _get_get_session_history()
store: Dict = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}
output = with_history.invoke([HumanMessage(content="hello")], config)
assert output == "you said: hello"
output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
assert store == {
"1": ChatMessageHistory(
messages=[
HumanMessage(content="hello"),
AIMessage(content="you said: hello"),
HumanMessage(content="good bye"),
AIMessage(content="you said: hello\ngood bye"),
]
)
}


def test_input_dict() -> None:
Expand Down

0 comments on commit 4745195

Please sign in to comment.