From edbe7d5f5e0dcc771c1f53a49bb784a3960ce448 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Sat, 28 Dec 2024 15:46:51 -0500 Subject: [PATCH] core,anthropic[patch]: fix with_structured_output typing (#28950) --- libs/core/langchain_core/language_models/base.py | 2 +- .../core/langchain_core/language_models/chat_models.py | 10 +++++----- libs/core/langchain_core/prompts/structured.py | 4 ++-- .../anthropic/langchain_anthropic/chat_models.py | 5 ++--- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index a6db99a495111..051550dfe7f85 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -233,7 +233,7 @@ async def agenerate_prompt( """ def with_structured_output( - self, schema: Union[dict, type[BaseModel]], **kwargs: Any + self, schema: Union[dict, type], **kwargs: Any ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: """Not implemented on this class.""" # Implement this on child class if there is a way of steering the model to diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 39fd11c247f9f..516485654ce37 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -1128,7 +1128,7 @@ def with_structured_output( The output schema. Can be passed in as: - an OpenAI function/tool schema, - a JSON Schema, - - a TypedDict class (support added in 0.2.26), + - a TypedDict class, - or a Pydantic class. If ``schema`` is a Pydantic class then the model output will be a Pydantic instance of that class, and the model-generated fields will be @@ -1137,10 +1137,6 @@ def with_structured_output( for more on how to properly specify types and descriptions of schema fields when specifying a Pydantic or TypedDict class. - .. versionchanged:: 0.2.26 - - Added support for TypedDict class. - include_raw: If False then only the parsed structured output is returned. If an error occurs during model output parsing it will be raised. If True @@ -1222,6 +1218,10 @@ class AnswerWithJustification(BaseModel): # 'answer': 'They weigh the same', # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' # } + + .. versionchanged:: 0.2.26 + + Added support for TypedDict class. """ # noqa: E501 if kwargs: msg = f"Received unsupported arguments {kwargs}" diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 543dbb57e5d90..a3a01cd68f3e1 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -28,7 +28,7 @@ class StructuredPrompt(ChatPromptTemplate): """Structured prompt template for a language model.""" - schema_: Union[dict, type[BaseModel]] + schema_: Union[dict, type] """Schema for the structured prompt.""" structured_output_kwargs: dict[str, Any] = Field(default_factory=dict) @@ -66,7 +66,7 @@ def get_lc_namespace(cls) -> list[str]: def from_messages_and_schema( cls, messages: Sequence[MessageLikeRepresentation], - schema: Union[dict, type[BaseModel]], + schema: Union[dict, type], **kwargs: Any, ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 6eb9dc4bca61b..fd64b824a8df7 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -16,7 +16,6 @@ Sequence, Tuple, Type, - TypedDict, Union, cast, ) @@ -72,7 +71,7 @@ SecretStr, model_validator, ) -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypedDict from langchain_anthropic.output_parsers import extract_tool_calls @@ -973,7 +972,7 @@ class GetPrice(BaseModel): def with_structured_output( self, - schema: Union[Dict, Type[BaseModel]], + schema: Union[Dict, type], *, include_raw: bool = False, **kwargs: Any,