Skip to content

Commit

Permalink
rm output generic
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Feb 15, 2024
1 parent 49ca7d0 commit be3b084
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
from typing import Any, Dict, Generic, TypeVar, Union

from langchain_core.language_models import LanguageModelInput
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable

_OutputSchema = TypeVar("_OutputSchema")
_FormattedOutput = TypeVar("_FormattedOutput")


class FormattedOutputMixin(Generic[_OutputSchema, _FormattedOutput], ABC):
class StructuredOutputMixin(Generic[_OutputSchema], ABC):
"""Mixin for language models that offer native output formatting."""

@abstractmethod
def with_output_format(
def with_structured_output(
self, schema: _OutputSchema, **kwargs: Any
) -> Runnable[LanguageModelInput, _FormattedOutput]:
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Implement this if there is a way of steering the model to generate responses that match a given schema.""" # noqa: E501
8 changes: 3 additions & 5 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.language_models.output_format import FormattedOutputMixin
from langchain_core.language_models.structured_output import StructuredOutputMixin
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -214,9 +214,7 @@ class _AllReturnType(TypedDict):
parsing_error: Optional[BaseException]


class ChatOpenAI(
BaseChatModel, FormattedOutputMixin[_DictOrPydanticClass, _DictOrPydantic]
):
class ChatOpenAI(BaseChatModel, StructuredOutputMixin[_DictOrPydanticClass]):
"""`OpenAI` Chat large language models API.
To use, you should have the
Expand Down Expand Up @@ -773,7 +771,7 @@ def with_output_format(
) -> Runnable[LanguageModelInput, _DictOrPydantic]:
...

def with_output_format(
def with_structured_output(
self,
schema: _DictOrPydanticClass,
*,
Expand Down

0 comments on commit be3b084

Please sign in to comment.