Skip to content

Commit

Permalink
Changes following review
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Feb 1, 2024
1 parent 105bc67 commit 944a049
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):

messages: List[BaseMessage] = Field(default_factory=list)

async def amessages(self) -> List[BaseMessage]:
async def aget_messages(self) -> List[BaseMessage]:
return self.messages

def add_message(self, message: BaseMessage) -> None:
Expand Down
40 changes: 33 additions & 7 deletions libs/core/langchain_core/chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,26 @@
class BaseChatMessageHistory(ABC):
"""Abstract base class for storing chat message history.
Implementations should over-ride the add_messages method to handle bulk addition
of messages.
Implementations guidelines:
The default implementation of add_message will correctly call add_messages, so
it is not necessary to implement both methods.
Implementations are expected to over-ride all or some of the following methods:
* add_messages: sync variant for bulk addition of messages
* aadd_messages: async variant for bulk addition of messages
* messages: sync variant for getting messages
* aget_messages: async variant for getting messages
* clear: sync variant for clearing messages
* aclear: async variant for clearing messages
add_messages contains a default implementation that calls add_message
for each message in the sequence. This is provided for backwards compatibility
with existing implementations which only had add_message.
Async variants all have default implementations that call the sync variants.
Implementers can choose to over-ride the async implementations to provide
truly async implementations.
Usage guidelines:
When used for updating history, users should favor usage of `add_messages`
over `add_message` or other variants like `add_user_message` and `add_ai_message`
Expand Down Expand Up @@ -55,10 +70,21 @@ def clear(self):
"""

messages: List[BaseMessage]
"""A list of Messages stored in-memory."""
"""A property or attribute that returns a list of messages.
In general, getting the messages may involve IO to the underlying
persistence layer, so this operation is expected to incur some
latency.
"""

async def aget_messages(self) -> List[BaseMessage]:
"""Async version of getting messages.
Can over-ride this method to provide an efficient async implementation.
async def amessages(self) -> List[BaseMessage]:
"""Return messages stored in memory."""
In general, fetching messages may involve IO to the underlying
persistence layer.
"""
return await run_in_executor(None, lambda: self.messages)

def add_user_message(self, message: Union[HumanMessage, str]) -> None:
Expand Down
33 changes: 33 additions & 0 deletions libs/core/tests/unit_tests/chat_history/test_chat_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,36 @@ def clear(self) -> None:
assert len(store) == 4
assert store[2] == HumanMessage(content="Hello")
assert store[3] == HumanMessage(content="World")


async def test_async_interface() -> None:
"""Test async interface for BaseChatMessageHistory."""

class BulkAddHistory(BaseChatMessageHistory):
def __init__(self) -> None:
self.messages = []

def add_messages(self, message: Sequence[BaseMessage]) -> None:
"""Add a message to the store."""
self.messages.extend(message)

def clear(self) -> None:
"""Clear the store."""
self.messages.clear()

chat_history = BulkAddHistory()
await chat_history.aadd_messages(
[HumanMessage(content="Hello"), HumanMessage(content="World")]
)
assert await chat_history.aget_messages() == [
HumanMessage(content="Hello"),
HumanMessage(content="World"),
]
await chat_history.aadd_messages([HumanMessage(content="!")])
assert await chat_history.aget_messages() == [
HumanMessage(content="Hello"),
HumanMessage(content="World"),
HumanMessage(content="!"),
]
await chat_history.aclear()
assert await chat_history.aget_messages() == []
4 changes: 2 additions & 2 deletions libs/langchain/langchain/memory/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def buffer_as_str(self) -> str:

async def abuffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
messages = await self.chat_memory.amessages()
messages = await self.chat_memory.aget_messages()
return self._buffer_as_str(messages)

@property
Expand All @@ -51,7 +51,7 @@ def buffer_as_messages(self) -> List[BaseMessage]:

async def abuffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return await self.chat_memory.amessages()
return await self.chat_memory.aget_messages()

@property
def memory_variables(self) -> List[str]:
Expand Down

0 comments on commit 944a049

Please sign in to comment.