From 815bfa1913d0f42eec2a6c39b18c7abe07295c6c Mon Sep 17 00:00:00 2001
From: ccurme
Date: Thu, 9 Jan 2025 10:32:30 -0500
Subject: [PATCH] openai[patch]: support streaming with json_schema response format (#29044)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Stream JSON string content. Final chunk includes the parsed representation
  (following OpenAI
  [docs](https://platform.openai.com/docs/guides/structured-outputs#streaming)).
- Mildly breaking change: if you were streaming with `response_format` before,
  usage metadata will no longer be included unless you set `stream_usage=True`.

## Response format

Before:

![Screenshot 2025-01-06 at 11 59 01 AM](https://github.com/user-attachments/assets/e54753f7-47d5-421d-b8f3-172f32b3364d)

After:

![Screenshot 2025-01-06 at 11 58 13 AM](https://github.com/user-attachments/assets/34882c6c-2284-45b4-92f7-5b5b69896903)

## with_structured_output

For pydantic output, the behavior of `with_structured_output` is unchanged
(except that the warning disappears), because we pluck the parsed
representation straight from OpenAI, and OpenAI doesn't return it until the
stream is completed. Open to alternatives (e.g., parsing from content or from
intermediate dict chunks generated by OpenAI).

Before:

![Screenshot 2025-01-06 at 12 38 11 PM](https://github.com/user-attachments/assets/913d320d-f49e-4cbb-a800-b394ae817fd1)

After:

![Screenshot 2025-01-06 at 12 38 58 PM](https://github.com/user-attachments/assets/f7a45dd6-d886-48a6-8d76-d0e21ca767c6)
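
## Example

Roughly, the new streaming behavior looks like this (a sketch mirroring the
integration tests added below; the `Foo` model and the prompt are
illustrative):

```python
from typing import Optional

from langchain_core.messages import BaseMessageChunk
from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Foo(BaseModel):
    response: str


# stream_usage=True is now required if you want usage metadata while streaming
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=Foo):
    # chunks stream the JSON string content incrementally
    full = chunk if full is None else full + chunk

# the final chunk carries the parsed representation
assert full is not None
parsed = full.additional_kwargs["parsed"]
assert isinstance(parsed, Foo)
print(parsed.response)
```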
---
 .../langchain_openai/chat_models/base.py      | 117 +++++++++++++-----
 .../chat_models/test_azure.py                 |  35 ++++++
 .../chat_models/test_base.py                  |  29 ++++-
 3 files changed, 146 insertions(+), 35 deletions(-)

diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 142e7eca1a84b..f4e26253484e5 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -318,8 +318,14 @@ def _convert_delta_to_message_chunk(
 def _convert_chunk_to_generation_chunk(
     chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
 ) -> Optional[ChatGenerationChunk]:
+    if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+        return None
     token_usage = chunk.get("usage")
-    choices = chunk.get("choices", [])
+    choices = (
+        chunk.get("choices", [])
+        # from beta.chat.completions.stream
+        or chunk.get("chunk", {}).get("choices", [])
+    )
 
     usage_metadata: Optional[UsageMetadata] = (
         _create_usage_metadata(token_usage) if token_usage else None
@@ -660,13 +666,24 @@ def _stream(
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
 
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_client.beta.chat.completions.stream(**payload)
+            context_manager = response_stream
         else:
-            response = self.client.create(**payload)
-        with response:
+            if self.include_response_headers:
+                raw_response = self.client.with_raw_response.create(**payload)
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = self.client.create(**payload)
+            context_manager = response
+        with context_manager as response:
             is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
@@ -686,6 +703,16 @@ def _stream(
                 )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     def _generate(
         self,
@@ -794,13 +821,29 @@ async def _astream(
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if self.include_response_headers:
-            raw_response = await self.async_client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
+
+        if "response_format" in payload:
+            if self.include_response_headers:
+                warnings.warn(
+                    "Cannot currently include response headers when response_format is "
+                    "specified."
+                )
+            payload.pop("stream")
+            response_stream = self.root_async_client.beta.chat.completions.stream(
+                **payload
+            )
+            context_manager = response_stream
         else:
-            response = await self.async_client.create(**payload)
-        async with response:
+            if self.include_response_headers:
+                raw_response = await self.async_client.with_raw_response.create(
+                    **payload
+                )
+                response = raw_response.parse()
+                base_generation_info = {"headers": dict(raw_response.headers)}
+            else:
+                response = await self.async_client.create(**payload)
+            context_manager = response
+        async with context_manager as response:
             is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
@@ -820,6 +863,16 @@ async def _astream(
                 )
                 is_first_chunk = False
                 yield generation_chunk
+        if hasattr(response, "get_final_completion") and "response_format" in payload:
+            final_completion = await response.get_final_completion()
+            generation_chunk = self._get_generation_chunk_from_completion(
+                final_completion
+            )
+            if run_manager:
+                await run_manager.on_llm_new_token(
+                    generation_chunk.text, chunk=generation_chunk
+                )
+            yield generation_chunk
 
     async def _agenerate(
         self,
@@ -1010,25 +1063,6 @@ def get_num_tokens_from_messages(
                 num_tokens += 3
         return num_tokens
 
-    def _should_stream(
-        self,
-        *,
-        async_api: bool,
-        run_manager: Optional[
-            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
-        ] = None,
-        response_format: Optional[Union[dict, type]] = None,
-        **kwargs: Any,
-    ) -> bool:
-        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
-            # TODO: Add support for streaming with Pydantic response_format.
- warnings.warn("Streaming with Pydantic response_format not yet supported.") - return False - - return super()._should_stream( - async_api=async_api, run_manager=run_manager, **kwargs - ) - @deprecated( since="0.2.1", alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools", @@ -1531,6 +1565,25 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]: filtered[k] = v return filtered + def _get_generation_chunk_from_completion( + self, completion: openai.BaseModel + ) -> ChatGenerationChunk: + """Get chunk from completion (e.g., from final completion of a stream).""" + chat_result = self._create_chat_result(completion) + chat_message = chat_result.generations[0].message + if isinstance(chat_message, AIMessage): + usage_metadata = chat_message.usage_metadata + else: + usage_metadata = None + message = AIMessageChunk( + content="", + additional_kwargs=chat_message.additional_kwargs, + usage_metadata=usage_metadata, + ) + return ChatGenerationChunk( + message=message, generation_info=chat_result.llm_output + ) + class ChatOpenAI(BaseChatOpenAI): # type: ignore[override] """OpenAI chat model integration. diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index 4ed531ad119ff..e9228df0730c5 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -13,6 +13,7 @@ HumanMessage, ) from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult +from pydantic import BaseModel from langchain_openai import AzureChatOpenAI from tests.unit_tests.fake.callbacks import FakeCallbackHandler @@ -262,3 +263,37 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None: assert isinstance(full, AIMessageChunk) assert isinstance(full.content, str) assert json.loads(full.content) == {"a": 1} + + +class Foo(BaseModel): + response: str + + +def test_stream_response_format(llm: AzureChatOpenAI) -> None: + full: Optional[BaseMessageChunk] = None + chunks = [] + for chunk in llm.stream("how are ya", response_format=Foo): + chunks.append(chunk) + full = chunk if full is None else full + chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"] + + +async def test_astream_response_format(llm: AzureChatOpenAI) -> None: + full: Optional[BaseMessageChunk] = None + chunks = [] + async for chunk in llm.astream("how are ya", response_format=Foo): + chunks.append(chunk) + full = chunk if full is None else full + chunk + assert len(chunks) > 1 + assert isinstance(full, AIMessageChunk) + parsed = full.additional_kwargs["parsed"] + assert isinstance(parsed, Foo) + assert isinstance(full.content, str) + parsed_content = json.loads(full.content) + assert parsed.response == parsed_content["response"] diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 93c08ce214178..506799aef4b59 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1092,14 +1092,37 @@ class Foo(BaseModel): def test_stream_response_format() -> None: - 
-    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    for chunk in ChatOpenAI(model="gpt-4o-mini").stream(
+        "how are ya", response_format=Foo
+    ):
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 async def test_astream_response_format() -> None:
-    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+    full: Optional[BaseMessageChunk] = None
+    chunks = []
+    async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
         "how are ya", response_format=Foo
     ):
-        pass
+        chunks.append(chunk)
+        full = chunk if full is None else full + chunk
+    assert len(chunks) > 1
+    assert isinstance(full, AIMessageChunk)
+    parsed = full.additional_kwargs["parsed"]
+    assert isinstance(parsed, Foo)
+    assert isinstance(full.content, str)
+    parsed_content = json.loads(full.content)
+    assert parsed.response == parsed_content["response"]
 
 
 @pytest.mark.parametrize("use_max_completion_tokens", [True, False])