Add async methods to BaseChatMessageHistory and BaseMemory
cbornet committed Jan 31, 2024
1 parent 2e5949b commit 105bc67
Showing 6 changed files with 133 additions and 10 deletions.
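Taken together, the new a-prefixed methods mirror the existing sync interface, and their default implementations dispatch the sync code to an executor. A rough usage sketch follows (the ConversationBufferMemory import path and the asyncio driver are assumptions, not part of this commit):

import asyncio

from langchain.memory import ConversationBufferMemory  # import path assumed


async def main() -> None:
    memory = ConversationBufferMemory()
    # asave_context / aload_memory_variables / aclear are the new async mirrors;
    # by default they run the sync methods through run_in_executor.
    await memory.asave_context({"input": "hi"}, {"output": "hello"})
    print(await memory.aload_memory_variables({}))  # {'history': 'Human: hi\nAI: hello'}
    await memory.aclear()


asyncio.run(main())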
libs/community/langchain_community/chat_message_histories/in_memory.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Sequence

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
@@ -13,9 +13,19 @@ class ChatMessageHistory(BaseChatMessageHistory, BaseModel):

messages: List[BaseMessage] = Field(default_factory=list)

async def amessages(self) -> List[BaseMessage]:
return self.messages

def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add messages to the store"""
self.add_messages(messages)

def clear(self) -> None:
self.messages = []

async def aclear(self) -> None:
self.clear()
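For illustration, a minimal sketch of driving this in-memory history through the new coroutines (the import path and the asyncio driver are assumptions):

import asyncio

from langchain_community.chat_message_histories import ChatMessageHistory  # path assumed
from langchain_core.messages import AIMessage, HumanMessage


async def demo() -> None:
    history = ChatMessageHistory()
    # For this store the async variants simply delegate to their sync counterparts.
    await history.aadd_messages(
        [HumanMessage(content="hi"), AIMessage(content="hello")]
    )
    print(await history.amessages())  # [HumanMessage(...), AIMessage(...)]
    await history.aclear()


asyncio.run(demo())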
19 changes: 18 additions & 1 deletion libs/core/langchain_core/chat_history.py
@@ -9,6 +9,7 @@
HumanMessage,
get_buffer_string,
)
from langchain_core.runnables import run_in_executor


class BaseChatMessageHistory(ABC):
@@ -56,6 +57,10 @@ def clear(self):
messages: List[BaseMessage]
"""A list of Messages stored in-memory."""

async def amessages(self) -> List[BaseMessage]:
"""Return messages stored in memory."""
return await run_in_executor(None, lambda: self.messages)

def add_user_message(self, message: Union[HumanMessage, str]) -> None:
"""Convenience method for adding a human message string to the store.
@@ -98,7 +103,7 @@ def add_message(self, message: BaseMessage) -> None:
"""
if type(self).add_messages != BaseChatMessageHistory.add_messages:
# This means that the sub-class has implemented an efficient add_messages
-# method, so we should usage of add_message to that.
+# method, so we should use it.
self.add_messages([message])
else:
raise NotImplementedError(
@@ -118,10 +123,22 @@ def add_messages(self, messages: Sequence[BaseMessage]) -> None:
for message in messages:
self.add_message(message)

async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
"""Add a list of messages.
Args:
messages: A list of BaseMessage objects to store.
"""
await run_in_executor(None, self.add_messages, messages)

@abstractmethod
def clear(self) -> None:
"""Remove all messages from the store"""

async def aclear(self) -> None:
"""Remove all messages from the store"""
await run_in_executor(None, self.clear)

def __str__(self) -> str:
"""Return a string representation of the chat history."""
return get_buffer_string(self.messages)
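Because the a-prefixed defaults fall back to run_in_executor, a hypothetical sync-only subclass (sketch below, not part of this commit) immediately gains a usable async interface; a natively async store would instead override aadd_messages and aclear directly:

from typing import List

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage


class ListBackedHistory(BaseChatMessageHistory):
    """Hypothetical store that only implements the sync interface."""

    def __init__(self) -> None:
        self.messages: List[BaseMessage] = []

    def add_message(self, message: BaseMessage) -> None:
        self.messages.append(message)

    def clear(self) -> None:
        self.messages = []

# await ListBackedHistory().aadd_messages([...]), .amessages() and .aclear()
# all work out of the box by running the sync methods in an executor.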
15 changes: 15 additions & 0 deletions libs/core/langchain_core/memory.py
@@ -4,6 +4,7 @@
from typing import Any, Dict, List

from langchain_core.load.serializable import Serializable
from langchain_core.runnables import run_in_executor


class BaseMemory(Serializable, ABC):
@@ -50,10 +51,24 @@ def memory_variables(self) -> List[str]:
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""

async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
return await run_in_executor(None, self.load_memory_variables, inputs)

@abstractmethod
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save the context of this chain run to memory."""

async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save the context of this chain run to memory."""
await run_in_executor(None, self.save_context, inputs, outputs)

@abstractmethod
def clear(self) -> None:
"""Clear memory contents."""

async def aclear(self) -> None:
"""Clear memory contents."""
await run_in_executor(None, self.clear)
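Likewise for BaseMemory: a hypothetical subclass that only implements the sync hooks (a sketch, not from this commit) can already be awaited through the new defaults:

from typing import Any, Dict, List

from langchain_core.memory import BaseMemory


class LastOutputMemory(BaseMemory):
    """Hypothetical memory that remembers only the most recent outputs."""

    last: Dict[str, str] = {}

    @property
    def memory_variables(self) -> List[str]:
        return ["last"]

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return {"last": dict(self.last)}

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        self.last = dict(outputs)

    def clear(self) -> None:
        self.last = {}

# await LastOutputMemory().asave_context(...), .aload_memory_variables({}) and
# .aclear() now work without any async-specific code in the subclass.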
46 changes: 42 additions & 4 deletions libs/langchain/langchain/memory/buffer.py
@@ -19,20 +19,40 @@ def buffer(self) -> Any:
"""String buffer of memory."""
return self.buffer_as_messages if self.return_messages else self.buffer_as_str

-@property
-def buffer_as_str(self) -> str:
-"""Exposes the buffer as a string in case return_messages is True."""
async def abuffer(self) -> Any:
"""String buffer of memory."""
return (
await self.abuffer_as_messages()
if self.return_messages
else await self.abuffer_as_str()
)

def _buffer_as_str(self, messages: List[BaseMessage]) -> str:
return get_buffer_string(
-self.chat_memory.messages,
+messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)

@property
def buffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
return self._buffer_as_str(self.chat_memory.messages)

async def abuffer_as_str(self) -> str:
"""Exposes the buffer as a string in case return_messages is True."""
messages = await self.chat_memory.amessages()
return self._buffer_as_str(messages)

@property
def buffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return self.chat_memory.messages

async def abuffer_as_messages(self) -> List[BaseMessage]:
"""Exposes the buffer as a list of messages in case return_messages is False."""
return await self.chat_memory.amessages()

@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables.
@@ -45,6 +65,11 @@ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
buffer = await self.abuffer()
return {self.memory_key: buffer}


class ConversationStringBufferMemory(BaseMemory):
"""Buffer for storing conversation memory."""
@@ -77,6 +102,10 @@ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

async def aload_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Return history buffer."""
return self.load_memory_variables(inputs)

def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
if self.input_key is None:
@@ -93,6 +122,15 @@ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
ai = f"{self.ai_prefix}: " + outputs[output_key]
self.buffer += "\n" + "\n".join([human, ai])

async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
return self.save_context(inputs, outputs)

def clear(self) -> None:
"""Clear memory contents."""
self.buffer = ""

async def aclear(self) -> None:
self.clear()
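A short sketch of the new abuffer() accessor (asyncio driver assumed):

import asyncio

from langchain.memory import ConversationBufferMemory


async def show() -> None:
    memory = ConversationBufferMemory(ai_prefix="Assistant")
    await memory.asave_context({"input": "ping"}, {"output": "pong"})
    # return_messages is False by default, so abuffer() awaits abuffer_as_str().
    print(await memory.abuffer())  # Human: ping\nAssistant: pong


asyncio.run(show())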
19 changes: 17 additions & 2 deletions libs/langchain/langchain/memory/chat_memory.py
@@ -4,6 +4,7 @@
from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.memory import BaseMemory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import Field

from langchain.memory.utils import get_prompt_input_key
@@ -35,9 +36,23 @@ def _get_input_output(
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
-self.chat_memory.add_user_message(input_str)
-self.chat_memory.add_ai_message(output_str)
+self.chat_memory.add_messages(
+[HumanMessage(content=input_str), AIMessage(content=output_str)]
+)

async def asave_context(
self, inputs: Dict[str, Any], outputs: Dict[str, str]
) -> None:
"""Save context from this conversation to buffer."""
input_str, output_str = self._get_input_output(inputs, outputs)
await self.chat_memory.aadd_messages(
[HumanMessage(content=input_str), AIMessage(content=output_str)]
)

def clear(self) -> None:
"""Clear memory contents."""
self.chat_memory.clear()

async def aclear(self) -> None:
"""Clear memory contents."""
await self.chat_memory.aclear()
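Note that the context is now written as a single (a)add_messages call holding the human/AI pair, which lets history backends batch the write; a quick sketch of observing the stored pair (driver assumed):

import asyncio

from langchain.memory import ConversationBufferMemory


async def inspect() -> None:
    memory = ConversationBufferMemory()
    await memory.asave_context({"input": "2 + 2?"}, {"output": "4"})
    stored = await memory.chat_memory.amessages()
    print([type(m).__name__ for m in stored])  # ['HumanMessage', 'AIMessage']


asyncio.run(inspect())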
32 changes: 30 additions & 2 deletions libs/langchain/tests/unit_tests/chains/test_conversation.py
@@ -14,14 +14,22 @@ def test_memory_ai_prefix() -> None:
"""Test that ai_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
memory.save_context({"input": "bar"}, {"output": "foo"})
-assert memory.buffer == "Human: bar\nAssistant: foo"
+assert memory.load_memory_variables({}) == {"foo": "Human: bar\nAssistant: foo"}


def test_memory_human_prefix() -> None:
"""Test that human_prefix in the memory component works."""
memory = ConversationBufferMemory(memory_key="foo", human_prefix="Friend")
memory.save_context({"input": "bar"}, {"output": "foo"})
-assert memory.buffer == "Friend: bar\nAI: foo"
+assert memory.load_memory_variables({}) == {"foo": "Friend: bar\nAI: foo"}


async def test_memory_async() -> None:
memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
await memory.asave_context({"input": "bar"}, {"output": "foo"})
assert await memory.aload_memory_variables({}) == {
"foo": "Human: bar\nAssistant: foo"
}


def test_conversation_chain_works() -> None:
@@ -100,3 +108,23 @@ def test_clearing_conversation_memory(memory: BaseMemory) -> None:

memory.clear()
assert memory.load_memory_variables({}) == {"baz": ""}


@pytest.mark.parametrize(
"memory",
[
ConversationBufferMemory(memory_key="baz"),
ConversationSummaryMemory(llm=FakeLLM(), memory_key="baz"),
ConversationBufferWindowMemory(memory_key="baz"),
],
)
async def test_clearing_conversation_memory_async(memory: BaseMemory) -> None:
"""Test clearing the conversation memory."""
# This is a good input because the input is not the same as baz.
good_inputs = {"foo": "bar", "baz": "foo"}
# This is a good output because there is one variable.
good_outputs = {"bar": "foo"}
await memory.asave_context(good_inputs, good_outputs)

await memory.aclear()
assert await memory.aload_memory_variables({}) == {"baz": ""}
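These coroutine tests carry no explicit async marker, so they presumably rely on the repository's pytest async configuration (for example pytest-asyncio in auto mode, which is an assumption here); outside pytest, the same check can be driven directly:

import asyncio

from langchain.memory import ConversationBufferMemory


async def standalone_check() -> None:
    memory = ConversationBufferMemory(memory_key="foo", ai_prefix="Assistant")
    await memory.asave_context({"input": "bar"}, {"output": "foo"})
    assert await memory.aload_memory_variables({}) == {
        "foo": "Human: bar\nAssistant: foo"
    }


asyncio.run(standalone_check())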
