Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core:Add optional max_messages to MessagePlaceholder #16098

Merged
merged 9 commits into from
Jun 19, 2024
29 changes: 27 additions & 2 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env

Expand Down Expand Up @@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
# AIMessage(content="5 + 2 is 7"),
# HumanMessage(content="now multiply that by 4"),
# ])

Limiting the number of messages:

.. code-block:: python

from langchain_core.prompts import MessagesPlaceholder

prompt = MessagesPlaceholder("history", n_messages=1)

prompt.format_messages(
history=[
("system", "You are an AI assistant."),
("human", "Hello!"),
]
)
# -> [
# HumanMessage(content="Hello!"),
# ]
"""

variable_name: str
Expand All @@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
list. If False then a named argument with name `variable_name` must be passed
in, even if the value is an empty list."""

n_messages: Optional[PositiveInt] = None
"""Maximum number of messages to include. If None, then will include all.
Defaults to None."""

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
Expand Down Expand Up @@ -197,7 +219,10 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
return convert_to_messages(value)
value = convert_to_messages(value)
if self.n_messages:
value = value[-self.n_messages :]
return value

@property
def input_variables(self) -> List[str]:
Expand Down
15 changes: 15 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,21 @@ def test_messages_placeholder() -> None:
]


def test_messages_placeholder_with_max() -> None:
history = [
AIMessage(content="1"),
AIMessage(content="2"),
AIMessage(content="3"),
]
prompt = MessagesPlaceholder("history")
assert prompt.format_messages(history=history) == history
prompt = MessagesPlaceholder("history", n_messages=2)
assert prompt.format_messages(history=history) == [
AIMessage(content="2"),
AIMessage(content="3"),
]


def test_chat_prompt_message_placeholder_partial() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
prompt = prompt.partial(history=[("system", "foo")])
Expand Down
Loading