Skip to content

Commit

Permalink
core[minor]: Image prompt template (#14263)
Browse files Browse the repository at this point in the history
Builds on Bagatur's (#13227). See unit test for example usage (below)

```python
def test_chat_tmpl_from_messages_multipart_image() -> None:
    base64_image = "abcd123"
    other_base64_image = "abcd123"
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            (
                "human",
                [
                    {"type": "text", "text": "What's in this image?"},
                    # OAI supports all these structures today
                    {
                        "type": "image_url",
                        "image_url": "data:image/jpeg;base64,{my_image}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,{my_image}"},
                    },
                    {"type": "image_url", "image_url": "{my_other_image}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "{my_other_image}", "detail": "medium"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.langchain.com/image.png"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": ""},
                    },
                ],
            ),
        ]
    )
    messages = template.format_messages(
        name="R2D2", my_image=base64_image, my_other_image=other_base64_image
    )
    expected = [
        SystemMessage(content="You are an AI assistant named R2D2."),
        HumanMessage(
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{other_base64_image}"
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"{other_base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{other_base64_image}",
                        "detail": "medium",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "https://www.langchain.com/image.png"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": ""},
                },
            ]
        ),
    ]
    assert messages == expected
```

---------

Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Brace Sproul <[email protected]>
  • Loading branch information
3 people authored Jan 28, 2024
1 parent 3c387bc commit 38425c9
Show file tree
Hide file tree
Showing 8 changed files with 453 additions and 48 deletions.
26 changes: 26 additions & 0 deletions libs/core/langchain_core/prompt_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence

from typing_extensions import TypedDict

from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
Expand Down Expand Up @@ -82,6 +84,30 @@ def get_lc_namespace(cls) -> List[str]:
return ["langchain", "prompts", "chat"]


class ImageURL(TypedDict, total=False):
detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image."""

url: str
"""Either a URL of the image or the base64 encoded image data."""


class ImagePromptValue(PromptValue):
"""Image prompt value."""

image_url: ImageURL
"""Prompt image."""
type: Literal["ImagePromptValue"] = "ImagePromptValue"

def to_string(self) -> str:
"""Return prompt as string."""
return self.image_url["url"]

def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=[self.image_url])]


class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
Expand Down
15 changes: 11 additions & 4 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)

Expand All @@ -30,7 +32,12 @@
from langchain_core.documents import Document


class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
FormatOutputType = TypeVar("FormatOutputType")


class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt."""

input_variables: List[str]
Expand Down Expand Up @@ -142,7 +149,7 @@ def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
return {**partial_kwargs, **kwargs}

@abstractmethod
def format(self, **kwargs: Any) -> str:
def format(self, **kwargs: Any) -> FormatOutputType:
"""Format the prompt with the inputs.
Args:
Expand Down Expand Up @@ -210,7 +217,7 @@ def save(self, file_path: Union[Path, str]) -> None:
raise ValueError(f"{save_path} must be json or yaml")


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:
Expand All @@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
Example:
.. code-block:: python
from langchain_core import Document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"})
Expand Down
217 changes: 181 additions & 36 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
overload,
)

Expand All @@ -30,10 +32,11 @@
convert_to_messages,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
Expand Down Expand Up @@ -288,34 +291,152 @@ def format(self, **kwargs: Any) -> BaseMessage:
)


class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)


class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]


class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, Dict]


class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""

prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""

_msg_class: Type[BaseMessage]

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@classmethod
def from_template(
cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a string template.
Args:
**kwargs: Keyword arguments to use for formatting.
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
Formatted message.
A new instance of this class.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
if isinstance(template, str):
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
template, template_format=template_format
)
return cls(prompt=prompt, **kwargs)
elif isinstance(template, list):
prompt = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
prompt.append(
PromptTemplate.from_template(
text, template_format=template_format
)
)
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
if vars:
if len(vars) > 1:
raise ValueError(
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f"\nFrom: {tmpl}"
)
input_variables = [vars[0]]
else:
input_variables = None
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
input_variables = get_template_variables(
img_template["url"], "f-string"
)
else:
input_variables = None
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
else:
raise ValueError()
prompt.append(img_template_obj)
else:
raise ValueError()
return cls(prompt=prompt, **kwargs)
else:
raise ValueError()

@classmethod
def from_template_file(
cls: Type[_StringImageMessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a template file.
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
Returns:
A new instance of this class.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]

@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
return input_variables

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Expand All @@ -326,31 +447,54 @@ def format(self, **kwargs: Any) -> BaseMessage:
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
if isinstance(self.prompt, StringPromptTemplate):
text = self.prompt.format(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)


class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""

_msg_class: Type[BaseMessage] = HumanMessage


class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""

_msg_class: Type[BaseMessage] = AIMessage

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.

Args:
**kwargs: Keyword arguments to use for formatting.
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""

Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
_msg_class: Type[BaseMessage] = SystemMessage

@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]


class BaseChatPromptTemplate(BasePromptTemplate, ABC):
Expand Down Expand Up @@ -405,8 +549,7 @@ def pretty_print(self) -> None:

MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
str,
]

Expand Down Expand Up @@ -738,7 +881,7 @@ def pretty_repr(self, html: bool = False) -> str:


def _create_template_from_message_type(
message_type: str, template: str
message_type: str, template: Union[str, list]
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
Expand All @@ -754,9 +897,9 @@ def _create_template_from_message_type(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
message = AIMessagePromptTemplate.from_template(cast(str, template))
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
message = SystemMessagePromptTemplate.from_template(cast(str, template))
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
Expand Down Expand Up @@ -799,7 +942,9 @@ def _convert_to_message(
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
_message = message_type_str(prompt=PromptTemplate.from_template(template))
_message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")

Expand Down
Loading

0 comments on commit 38425c9

Please sign in to comment.