diff --git a/libs/core/langchain_core/prompt_values.py b/libs/core/langchain_core/prompt_values.py index d0d1a1047336e..4c599f9f6a037 100644 --- a/libs/core/langchain_core/prompt_values.py +++ b/libs/core/langchain_core/prompt_values.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod from typing import List, Literal, Sequence +from typing_extensions import TypedDict + from langchain_core.load.serializable import Serializable from langchain_core.messages import ( AnyMessage, @@ -82,6 +84,30 @@ def get_lc_namespace(cls) -> List[str]: return ["langchain", "prompts", "chat"] +class ImageURL(TypedDict, total=False): + detail: Literal["auto", "low", "high"] + """Specifies the detail level of the image.""" + + url: str + """Either a URL of the image or the base64 encoded image data.""" + + +class ImagePromptValue(PromptValue): + """Image prompt value.""" + + image_url: ImageURL + """Prompt image.""" + type: Literal["ImagePromptValue"] = "ImagePromptValue" + + def to_string(self) -> str: + """Return prompt as string.""" + return self.image_url["url"] + + def to_messages(self) -> List[BaseMessage]: + """Return prompt as messages.""" + return [HumanMessage(content=[self.image_url])] + + class ChatPromptValueConcrete(ChatPromptValue): """Chat prompt value which explicitly lists out the message types it accepts. For use in external schemas.""" diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 2e1878ed27e3c..07ac72225547c 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -8,10 +8,12 @@ Any, Callable, Dict, + Generic, List, Mapping, Optional, Type, + TypeVar, Union, ) @@ -30,7 +32,12 @@ from langchain_core.documents import Document -class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): +FormatOutputType = TypeVar("FormatOutputType") + + +class BasePromptTemplate( + RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC +): """Base class for all prompt templates, returning a prompt.""" input_variables: List[str] @@ -142,7 +149,7 @@ def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: return {**partial_kwargs, **kwargs} @abstractmethod - def format(self, **kwargs: Any) -> str: + def format(self, **kwargs: Any) -> FormatOutputType: """Format the prompt with the inputs. Args: @@ -210,7 +217,7 @@ def save(self, file_path: Union[Path, str]) -> None: raise ValueError(f"{save_path} must be json or yaml") -def format_document(doc: Document, prompt: BasePromptTemplate) -> str: +def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str: """Format a document into a string based on a prompt template. First, this pulls information from the document from two sources: @@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str: Example: .. code-block:: python - from langchain_core import Document + from langchain_core.documents import Document from langchain_core.prompts import PromptTemplate doc = Document(page_content="This is a joke", metadata={"page": "1"}) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index b03e0be291325..6553614d102d8 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -13,8 +13,10 @@ Set, Tuple, Type, + TypedDict, TypeVar, Union, + cast, overload, ) @@ -30,10 +32,11 @@ convert_to_messages, ) from langchain_core.messages.base import get_msg_title_repr -from langchain_core.prompt_values import ChatPromptValue, PromptValue +from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.prompts.string import StringPromptTemplate +from langchain_core.prompts.string import StringPromptTemplate, get_template_variables from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.utils import get_colored_text from langchain_core.utils.interactive_env import is_interactive_env @@ -288,34 +291,152 @@ def format(self, **kwargs: Any) -> BaseMessage: ) -class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): +_StringImageMessagePromptTemplateT = TypeVar( + "_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate" +) + + +class _TextTemplateParam(TypedDict, total=False): + text: Union[str, Dict] + + +class _ImageTemplateParam(TypedDict, total=False): + image_url: Union[str, Dict] + + +class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" + prompt: Union[ + StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]] + ] + """Prompt template.""" + additional_kwargs: dict = Field(default_factory=dict) + """Additional keyword arguments to pass to the prompt template.""" + + _msg_class: Type[BaseMessage] + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. + @classmethod + def from_template( + cls: Type[_StringImageMessagePromptTemplateT], + template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]], + template_format: str = "f-string", + **kwargs: Any, + ) -> _StringImageMessagePromptTemplateT: + """Create a class from a string template. Args: - **kwargs: Keyword arguments to use for formatting. + template: a template. + template_format: format of the template. + **kwargs: keyword arguments to pass to the constructor. Returns: - Formatted message. + A new instance of this class. """ - text = self.prompt.format(**kwargs) - return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) + if isinstance(template, str): + prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template( + template, template_format=template_format + ) + return cls(prompt=prompt, **kwargs) + elif isinstance(template, list): + prompt = [] + for tmpl in template: + if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl: + if isinstance(tmpl, str): + text: str = tmpl + else: + text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501 + prompt.append( + PromptTemplate.from_template( + text, template_format=template_format + ) + ) + elif isinstance(tmpl, dict) and "image_url" in tmpl: + img_template = cast(_ImageTemplateParam, tmpl)["image_url"] + if isinstance(img_template, str): + vars = get_template_variables(img_template, "f-string") + if vars: + if len(vars) > 1: + raise ValueError( + "Only one format variable allowed per image" + f" template.\nGot: {vars}" + f"\nFrom: {tmpl}" + ) + input_variables = [vars[0]] + else: + input_variables = None + img_template = {"url": img_template} + img_template_obj = ImagePromptTemplate( + input_variables=input_variables, template=img_template + ) + elif isinstance(img_template, dict): + img_template = dict(img_template) + if "url" in img_template: + input_variables = get_template_variables( + img_template["url"], "f-string" + ) + else: + input_variables = None + img_template_obj = ImagePromptTemplate( + input_variables=input_variables, template=img_template + ) + else: + raise ValueError() + prompt.append(img_template_obj) + else: + raise ValueError() + return cls(prompt=prompt, **kwargs) + else: + raise ValueError() + @classmethod + def from_template_file( + cls: Type[_StringImageMessagePromptTemplateT], + template_file: Union[str, Path], + input_variables: List[str], + **kwargs: Any, + ) -> _StringImageMessagePromptTemplateT: + """Create a class from a template file. -class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): - """AI message prompt template. This is a message sent from the AI.""" + Args: + template_file: path to a template file. String or Path. + input_variables: list of input variables. + **kwargs: keyword arguments to pass to the constructor. - @classmethod - def get_lc_namespace(cls) -> List[str]: - """Get the namespace of the langchain object.""" - return ["langchain", "prompts", "chat"] + Returns: + A new instance of this class. + """ + with open(str(template_file), "r") as f: + template = f.read() + return cls.from_template(template, input_variables=input_variables, **kwargs) + + def format_messages(self, **kwargs: Any) -> List[BaseMessage]: + """Format messages from kwargs. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + List of BaseMessages. + """ + return [self.format(**kwargs)] + + @property + def input_variables(self) -> List[str]: + """ + Input variables for this prompt template. + + Returns: + List of input variable names. + """ + prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt] + input_variables = [iv for prompt in prompts for iv in prompt.input_variables] + return input_variables def format(self, **kwargs: Any) -> BaseMessage: """Format the prompt template. @@ -326,31 +447,54 @@ def format(self, **kwargs: Any) -> BaseMessage: Returns: Formatted message. """ - text = self.prompt.format(**kwargs) - return AIMessage(content=text, additional_kwargs=self.additional_kwargs) + if isinstance(self.prompt, StringPromptTemplate): + text = self.prompt.format(**kwargs) + return self._msg_class( + content=text, additional_kwargs=self.additional_kwargs + ) + else: + content = [] + for prompt in self.prompt: + inputs = {var: kwargs[var] for var in prompt.input_variables} + if isinstance(prompt, StringPromptTemplate): + formatted: Union[str, ImageURL] = prompt.format(**inputs) + content.append({"type": "text", "text": formatted}) + elif isinstance(prompt, ImagePromptTemplate): + formatted = prompt.format(**inputs) + content.append({"type": "image_url", "image_url": formatted}) + return self._msg_class( + content=content, additional_kwargs=self.additional_kwargs + ) -class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): - """System message prompt template. - This is a message that is not sent to the user. - """ +class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate): + """Human message prompt template. This is a message sent from the user.""" + + _msg_class: Type[BaseMessage] = HumanMessage + + +class AIMessagePromptTemplate(_StringImageMessagePromptTemplate): + """AI message prompt template. This is a message sent from the AI.""" + + _msg_class: Type[BaseMessage] = AIMessage @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "prompts", "chat"] - def format(self, **kwargs: Any) -> BaseMessage: - """Format the prompt template. - Args: - **kwargs: Keyword arguments to use for formatting. +class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate): + """System message prompt template. + This is a message that is not sent to the user. + """ - Returns: - Formatted message. - """ - text = self.prompt.format(**kwargs) - return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) + _msg_class: Type[BaseMessage] = SystemMessage + + @classmethod + def get_lc_namespace(cls) -> List[str]: + """Get the namespace of the langchain object.""" + return ["langchain", "prompts", "chat"] class BaseChatPromptTemplate(BasePromptTemplate, ABC): @@ -405,8 +549,7 @@ def pretty_print(self) -> None: MessageLikeRepresentation = Union[ MessageLike, - Tuple[str, str], - Tuple[Type, str], + Tuple[Union[str, Type], Union[str, List[dict], List[object]]], str, ] @@ -738,7 +881,7 @@ def pretty_repr(self, html: bool = False) -> str: def _create_template_from_message_type( - message_type: str, template: str + message_type: str, template: Union[str, list] ) -> BaseMessagePromptTemplate: """Create a message prompt template from a message type and template string. @@ -754,9 +897,9 @@ def _create_template_from_message_type( template ) elif message_type in ("ai", "assistant"): - message = AIMessagePromptTemplate.from_template(template) + message = AIMessagePromptTemplate.from_template(cast(str, template)) elif message_type == "system": - message = SystemMessagePromptTemplate.from_template(template) + message = SystemMessagePromptTemplate.from_template(cast(str, template)) else: raise ValueError( f"Unexpected message type: {message_type}. Use one of 'human'," @@ -799,7 +942,9 @@ def _convert_to_message( if isinstance(message_type_str, str): _message = _create_template_from_message_type(message_type_str, template) else: - _message = message_type_str(prompt=PromptTemplate.from_template(template)) + _message = message_type_str( + prompt=PromptTemplate.from_template(cast(str, template)) + ) else: raise NotImplementedError(f"Unsupported message type: {type(message)}") diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py new file mode 100644 index 0000000000000..d3d2d94da13dc --- /dev/null +++ b/libs/core/langchain_core/prompts/image.py @@ -0,0 +1,76 @@ +from typing import Any + +from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import Field +from langchain_core.utils import image as image_utils + + +class ImagePromptTemplate(BasePromptTemplate[ImageURL]): + """An image prompt template for a multimodal model.""" + + template: dict = Field(default_factory=dict) + """Template for the prompt.""" + + def __init__(self, **kwargs: Any) -> None: + if "input_variables" not in kwargs: + kwargs["input_variables"] = [] + + overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail")) + if overlap: + raise ValueError( + "input_variables for the image template cannot contain" + " any of 'url', 'path', or 'detail'." + f" Found: {overlap}" + ) + super().__init__(**kwargs) + + @property + def _prompt_type(self) -> str: + """Return the prompt type key.""" + return "image-prompt" + + def format_prompt(self, **kwargs: Any) -> PromptValue: + """Create Chat Messages.""" + return ImagePromptValue(image_url=self.format(**kwargs)) + + def format( + self, + **kwargs: Any, + ) -> ImageURL: + """Format the prompt with the inputs. + + Args: + kwargs: Any arguments to be passed to the prompt template. + + Returns: + A formatted string. + + Example: + + .. code-block:: python + + prompt.format(variable1="foo") + """ + formatted = {} + for k, v in self.template.items(): + if isinstance(v, str): + formatted[k] = v.format(**kwargs) + else: + formatted[k] = v + url = kwargs.get("url") or formatted.get("url") + path = kwargs.get("path") or formatted.get("path") + detail = kwargs.get("detail") or formatted.get("detail") + if not url and not path: + raise ValueError("Must provide either url or path.") + if not url: + if not isinstance(path, str): + raise ValueError("path must be a string.") + url = image_utils.image_to_data_url(path) + if not isinstance(url, str): + raise ValueError("url must be a string.") + output: ImageURL = {"url": url} + if detail: + # Don't check literal values here: let the API check them + output["detail"] = detail # type: ignore[typeddict-item] + return output diff --git a/libs/core/langchain_core/utils/__init__.py b/libs/core/langchain_core/utils/__init__.py index 6491a85f17fb7..92f919bac399b 100644 --- a/libs/core/langchain_core/utils/__init__.py +++ b/libs/core/langchain_core/utils/__init__.py @@ -4,6 +4,7 @@ These functions do not depend on any other LangChain module. """ +from langchain_core.utils import image from langchain_core.utils.env import get_from_dict_or_env, get_from_env from langchain_core.utils.formatting import StrictFormatter, formatter from langchain_core.utils.input import ( @@ -41,6 +42,7 @@ "xor_args", "try_load_from_hub", "build_extra_kwargs", + "image", "get_from_env", "get_from_dict_or_env", "stringify_dict", diff --git a/libs/core/langchain_core/utils/image.py b/libs/core/langchain_core/utils/image.py new file mode 100644 index 0000000000000..b59682bd37f1b --- /dev/null +++ b/libs/core/langchain_core/utils/image.py @@ -0,0 +1,14 @@ +import base64 +import mimetypes + + +def encode_image(image_path: str) -> str: + """Get base64 string from image URI.""" + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def image_to_data_url(image_path: str) -> str: + encoding = encode_image(image_path) + mime_type = mimetypes.guess_type(image_path)[0] + return f"data:{mime_type};base64,{encoding}" diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 0f3198bf26243..029244afe8c25 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -3,6 +3,9 @@ import pytest +from langchain_core._api.deprecation import ( + LangChainPendingDeprecationWarning, +) from langchain_core.messages import ( AIMessage, BaseMessage, @@ -243,14 +246,15 @@ def test_chat_valid_infer_variables() -> None: def test_chat_from_role_strings() -> None: """Test instantiation of chat template from role strings.""" - template = ChatPromptTemplate.from_role_strings( - [ - ("system", "You are a bot."), - ("assistant", "hello!"), - ("human", "{question}"), - ("other", "{quack}"), - ] - ) + with pytest.warns(LangChainPendingDeprecationWarning): + template = ChatPromptTemplate.from_role_strings( + [ + ("system", "You are a bot."), + ("assistant", "hello!"), + ("human", "{question}"), + ("other", "{quack}"), + ] + ) messages = template.format_messages(question="How are you?", quack="duck") assert messages == [ @@ -363,6 +367,136 @@ def test_chat_message_partial() -> None: assert template2.format(input="hello") == get_buffer_string(expected) +def test_chat_tmpl_from_messages_multipart_text() -> None: + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + {"type": "text", "text": "Oh nvm"}, + ], + ), + ] + ) + messages = template.format_messages(name="R2D2") + expected = [ + SystemMessage(content="You are an AI assistant named R2D2."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + {"type": "text", "text": "Oh nvm"}, + ] + ), + ] + assert messages == expected + + +def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this {object_name}?"}, + {"type": "text", "text": "Oh nvm"}, + ], + ), + ] + ) + messages = template.format_messages(name="R2D2", object_name="image") + expected = [ + SystemMessage(content="You are an AI assistant named R2D2."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + {"type": "text", "text": "Oh nvm"}, + ] + ), + ] + assert messages == expected + + +def test_chat_tmpl_from_messages_multipart_image() -> None: + base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" + other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{my_image}", + }, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{my_image}"}, + }, + {"type": "image_url", "image_url": "{my_other_image}"}, + { + "type": "image_url", + "image_url": {"url": "{my_other_image}", "detail": "medium"}, + }, + { + "type": "image_url", + "image_url": {"url": "https://www.langchain.com/image.png"}, + }, + { + "type": "image_url", + "image_url": {"url": ""}, + }, + ], + ), + ] + ) + messages = template.format_messages( + name="R2D2", my_image=base64_image, my_other_image=other_base64_image + ) + expected = [ + SystemMessage(content="You are an AI assistant named R2D2."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{other_base64_image}" + }, + }, + { + "type": "image_url", + "image_url": {"url": f"{other_base64_image}"}, + }, + { + "type": "image_url", + "image_url": { + "url": f"{other_base64_image}", + "detail": "medium", + }, + }, + { + "type": "image_url", + "image_url": {"url": "https://www.langchain.com/image.png"}, + }, + { + "type": "image_url", + "image_url": {"url": ""}, + }, + ] + ), + ] + assert messages == expected + + def test_messages_placeholder() -> None: prompt = MessagesPlaceholder("history") with pytest.raises(KeyError): diff --git a/libs/core/tests/unit_tests/utils/test_imports.py b/libs/core/tests/unit_tests/utils/test_imports.py index ce56c02026f30..64528cfd521b2 100644 --- a/libs/core/tests/unit_tests/utils/test_imports.py +++ b/libs/core/tests/unit_tests/utils/test_imports.py @@ -16,6 +16,7 @@ "xor_args", "try_load_from_hub", "build_extra_kwargs", + "image", "get_from_dict_or_env", "get_from_env", "stringify_dict",