Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Image prompt template #14263

Merged
merged 24 commits into from
Jan 28, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions libs/core/langchain_core/prompt_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence

from typing_extensions import TypedDict

from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
Expand Down Expand Up @@ -82,6 +84,30 @@ def get_lc_namespace(cls) -> List[str]:
return ["langchain", "prompts", "chat"]


class ImageURL(TypedDict, total=False):
    """Dictionary describing an image referenced by a prompt.

    Declared with ``total=False``, so neither key is required by type checkers.
    """

    detail: Literal["auto", "low", "high"]
    """Specifies the detail level of the image."""

    url: str
    """Either a URL of the image or the base64 encoded image data."""


class ImagePromptValue(PromptValue):
    """Prompt value that carries a single image."""

    image_url: ImageURL
    """Prompt image."""
    type: Literal["ImagePromptValue"] = "ImagePromptValue"

    def to_string(self) -> str:
        """Return the ``url`` entry of the image as the prompt's string form.

        Raises:
            KeyError: if ``image_url`` has no ``"url"`` key.
        """
        url = self.image_url["url"]
        return url

    def to_messages(self) -> List[BaseMessage]:
        """Return the prompt as a one-element list holding a ``HumanMessage``.

        The message content is a list containing the ``image_url`` dict itself.
        """
        message = HumanMessage(content=[self.image_url])
        return [message]


class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
Expand Down
15 changes: 11 additions & 4 deletions libs/core/langchain_core/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)

Expand All @@ -30,7 +32,12 @@
from langchain_core.documents import Document


class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
# Return type of ``BasePromptTemplate.format``; chosen when the template class
# is parametrised, e.g. ``BasePromptTemplate[str]``.
FormatOutputType = TypeVar("FormatOutputType")


class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt."""

input_variables: List[str]
Expand Down Expand Up @@ -143,7 +150,7 @@ def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
return {**partial_kwargs, **kwargs}

@abstractmethod
def format(self, **kwargs: Any) -> str:
def format(self, **kwargs: Any) -> FormatOutputType:
"""Format the prompt with the inputs.

Args:
Expand Down Expand Up @@ -211,7 +218,7 @@ def save(self, file_path: Union[Path, str]) -> None:
raise ValueError(f"{save_path} must be json or yaml")


def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.

First, this pulls information from the document from two sources:
Expand All @@ -237,7 +244,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
Example:
.. code-block:: python

from langchain_core import Document
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

doc = Document(page_content="This is a joke", metadata={"page": "1"})
Expand Down
217 changes: 181 additions & 36 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
Set,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
overload,
)

Expand All @@ -28,10 +30,11 @@
HumanMessage,
SystemMessage,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator


Expand Down Expand Up @@ -256,34 +259,152 @@ def format(self, **kwargs: Any) -> BaseMessage:
)


class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
# Self-type for the alternate constructors of _StringImageMessagePromptTemplate,
# so that calling them on a subclass is typed as returning that subclass.
_StringImageMessagePromptTemplateT = TypeVar(
    "_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)


class _TextTemplateParam(TypedDict, total=False):
    """Dict form of a text part in a list-style message template."""

    # Template for the text part.
    text: Union[str, Dict]


class _ImageTemplateParam(TypedDict, total=False):
    """Dict form of an image part in a list-style message template."""

    # Either a URL template string or a dict that may carry a "url" key.
    image_url: Union[str, Dict]


class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
    """Base class for message prompt templates made of text and/or image parts.

    Subclasses set ``_msg_class`` to the message type that ``format`` builds.
    """

    # A single string template, or a list of string/image templates where each
    # entry becomes one content part of the formatted message.
    prompt: Union[
        StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
    ]
    """Prompt template."""
    # Passed as ``additional_kwargs`` to the message created by ``format``.
    additional_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the prompt template."""

    # Message class instantiated by ``format``; provided by each subclass.
    _msg_class: Type[BaseMessage]

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@classmethod
def from_template(
    cls: Type[_StringImageMessagePromptTemplateT],
    template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
    template_format: str = "f-string",
    **kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
    """Create an instance from a template string or a list of template parts.

    Args:
        template: either a single template string, or a list whose items are
            template strings, ``{"text": ...}`` dicts, or
            ``{"image_url": ...}`` dicts. The image value is a URL template
            string or a dict that may carry a ``"url"`` key.
        template_format: format of the text templates.
        **kwargs: keyword arguments to pass to the constructor.

    Returns:
        A new instance of this class.

    Raises:
        ValueError: if ``template`` is neither a string nor a list, if a list
            item is not a recognised text/image part, or if a string image
            template contains more than one format variable.
    """
    if isinstance(template, str):
        single: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
            template, template_format=template_format
        )
        return cls(prompt=single, **kwargs)
    if not isinstance(template, list):
        raise ValueError(
            f"Expected template to be a str or a list, got {type(template)}"
        )

    parts: List[Union[StringPromptTemplate, ImagePromptTemplate]] = []
    for tmpl in template:
        # Text is checked first, so a dict carrying both "text" and
        # "image_url" is treated as a text part.
        if isinstance(tmpl, str) or (isinstance(tmpl, dict) and "text" in tmpl):
            if isinstance(tmpl, str):
                text: str = tmpl
            else:
                text = cast(_TextTemplateParam, tmpl)["text"]  # type: ignore[assignment] # noqa: E501
            parts.append(
                PromptTemplate.from_template(text, template_format=template_format)
            )
        elif isinstance(tmpl, dict) and "image_url" in tmpl:
            img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
            # Always a list: BasePromptTemplate declares
            # ``input_variables: List[str]``, so "no variables" is [] and
            # never None.
            input_variables: List[str]
            # NOTE(review): image variables are always extracted with the
            # "f-string" syntax, regardless of ``template_format`` -- confirm
            # this matches how ImagePromptTemplate formats its template.
            if isinstance(img_template, str):
                url_vars = get_template_variables(img_template, "f-string")
                if len(url_vars) > 1:
                    raise ValueError(
                        "Only one format variable allowed per image"
                        f" template.\nGot: {url_vars}"
                        f"\nFrom: {tmpl}"
                    )
                input_variables = list(url_vars[:1])
                img_template = {"url": img_template}
            elif isinstance(img_template, dict):
                # Copy so the caller's dict is not shared with the template.
                img_template = dict(img_template)
                if "url" in img_template:
                    input_variables = get_template_variables(
                        img_template["url"], "f-string"
                    )
                else:
                    input_variables = []
            else:
                raise ValueError(
                    "Expected 'image_url' to be a str or a dict, got "
                    f"{type(img_template)}\nFrom: {tmpl}"
                )
            parts.append(
                ImagePromptTemplate(
                    input_variables=input_variables, template=img_template
                )
            )
        else:
            raise ValueError(
                "Each template part must be a str or a dict with a 'text' or "
                f"'image_url' key.\nGot: {tmpl}"
            )
    return cls(prompt=parts, **kwargs)

@classmethod
def from_template_file(
    cls: Type[_StringImageMessagePromptTemplateT],
    template_file: Union[str, Path],
    input_variables: List[str],
    **kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
    """Build an instance from a template stored in a file.

    The file is read as text and handed to ``from_template``.

    Args:
        template_file: path to a template file. String or Path.
        input_variables: list of input variables; forwarded to
            ``from_template`` as a keyword argument.
        **kwargs: further keyword arguments forwarded to ``from_template``.

    Returns:
        A new instance of this class.
    """
    template_text = Path(str(template_file)).read_text()
    return cls.from_template(
        template_text, input_variables=input_variables, **kwargs
    )

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
    """Format this template and return the result as a one-element list.

    Args:
        **kwargs: Keyword arguments to use for formatting.

    Returns:
        List of BaseMessages.
    """
    message = self.format(**kwargs)
    return [message]

@property
def input_variables(self) -> List[str]:
    """
    Input variables for this prompt template.

    Collects the variables of every underlying prompt, in order, without
    removing duplicates.

    Returns:
        List of input variable names.
    """
    templates = self.prompt if isinstance(self.prompt, list) else [self.prompt]
    names = []
    for template in templates:
        for name in template.input_variables:
            names.append(name)
    return names

def format(self, **kwargs: Any) -> BaseMessage:
    """Format the prompt template.

    A single string prompt yields a message with string content. A list of
    prompts yields a message whose content is a list of
    ``{"type": "text", ...}`` and ``{"type": "image_url", ...}`` parts.

    Args:
        **kwargs: Keyword arguments to use for formatting.

    Returns:
        Formatted message.
    """
    if isinstance(self.prompt, StringPromptTemplate):
        return self._msg_class(
            content=self.prompt.format(**kwargs),
            additional_kwargs=self.additional_kwargs,
        )

    parts = []
    for template in self.prompt:
        # Each part only receives the variables it declares.
        selected = {name: kwargs[name] for name in template.input_variables}
        if isinstance(template, StringPromptTemplate):
            rendered: Union[str, ImageURL] = template.format(**selected)
            parts.append({"type": "text", "text": rendered})
        elif isinstance(template, ImagePromptTemplate):
            rendered = template.format(**selected)
            parts.append({"type": "image_url", "image_url": rendered})
    return self._msg_class(content=parts, additional_kwargs=self.additional_kwargs)


class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""

    # Formatting this template produces a HumanMessage.
    _msg_class: Type[BaseMessage] = HumanMessage


class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
    """AI message prompt template. This is a message sent from the AI."""

    # Formatting this template produces an AIMessage.
    _msg_class: Type[BaseMessage] = AIMessage

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]

def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.

Args:
**kwargs: Keyword arguments to use for formatting.
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
    """System message prompt template.
    This is a message that is not sent to the user.
    """

    # Formatting this template produces a SystemMessage.
    _msg_class: Type[BaseMessage] = SystemMessage

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the langchain namespace of this class as a list of parts."""
        namespace = ["langchain", "prompts", "chat"]
        return namespace


class BaseChatPromptTemplate(BasePromptTemplate, ABC):
Expand Down Expand Up @@ -366,8 +510,7 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:

# Value shapes that can stand in for a message: a MessageLike, a
# (message type name or class, template) tuple, or a bare string.
MessageLikeRepresentation = Union[
    MessageLike,
    Tuple[str, str],
    Tuple[Type, str],
    Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
    str,
]

Expand Down Expand Up @@ -700,7 +843,7 @@ def save(self, file_path: Union[Path, str]) -> None:


def _create_template_from_message_type(
message_type: str, template: str
message_type: str, template: Union[str, list]
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.

Expand All @@ -716,9 +859,9 @@ def _create_template_from_message_type(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
message = AIMessagePromptTemplate.from_template(cast(str, template))
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
message = SystemMessagePromptTemplate.from_template(cast(str, template))
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
Expand Down Expand Up @@ -761,7 +904,9 @@ def _convert_to_message(
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
else:
_message = message_type_str(prompt=PromptTemplate.from_template(template))
_message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")

Expand Down
Loading
Loading