openai[patch]: support streaming with json_schema response format (#29044)

- Stream JSON string content. The final chunk includes the parsed representation (following the OpenAI [docs](https://platform.openai.com/docs/guides/structured-outputs#streaming)).
- Mildly breaking change: if you were using streaming with `response_format` before, usage metadata will no longer be included unless you set `stream_usage=True` (see the sketch below).
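
As a usage sketch (mirroring the new integration tests in this PR; the `Foo` model, prompt, and model name are illustrative only), streaming with `response_format` now looks roughly like this:

```python
from typing import Optional

from langchain_core.messages import AIMessageChunk, BaseMessageChunk
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


# stream_usage=True keeps usage metadata in the stream (see the breaking change above).
llm = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)

full: Optional[BaseMessageChunk] = None
for chunk in llm.stream("how are ya", response_format=Foo):
    # JSON string content arrives incrementally; chunks aggregate via `+`.
    full = chunk if full is None else full + chunk

assert isinstance(full, AIMessageChunk)
parsed = full.additional_kwargs["parsed"]  # parsed representation from the final chunk
assert isinstance(parsed, Foo)
print(parsed.response)
print(full.usage_metadata)  # populated because stream_usage=True
```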

## Response format

Before:

![Screenshot 2025-01-06 at 11 59 01 AM](https://github.com/user-attachments/assets/e54753f7-47d5-421d-b8f3-172f32b3364d)


After:

![Screenshot 2025-01-06 at 11 58 13 AM](https://github.com/user-attachments/assets/34882c6c-2284-45b4-92f7-5b5b69896903)


## with_structured_output

For Pydantic output, the behavior of `with_structured_output` is unchanged (apart from the warning disappearing), because we pluck the parsed representation straight from OpenAI, and OpenAI doesn't return it until the stream is complete. Open to alternatives (e.g., parsing from content or from the intermediate dict chunks generated by OpenAI). A sketch follows the screenshots below.

Before:

![Screenshot 2025-01-06 at 12 38 11 PM](https://github.com/user-attachments/assets/913d320d-f49e-4cbb-a800-b394ae817fd1)

After:

![Screenshot 2025-01-06 at 12 38 58 PM](https://github.com/user-attachments/assets/f7a45dd6-d886-48a6-8d76-d0e21ca767c6)
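
For reference, a minimal sketch of the (unchanged) `with_structured_output` streaming pattern described above; the model name and `method="json_schema"` choice are assumptions for illustration. Since the parsed object is plucked from OpenAI's final completion, expect the complete object only once the stream finishes rather than incremental partials:

```python
from langchain_openai import ChatOpenAI
from pydantic import BaseModel


class Foo(BaseModel):
    response: str


structured_llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    Foo, method="json_schema"
)

for parsed in structured_llm.stream("how are ya"):
    # The parsed Foo arrives only after OpenAI completes the stream,
    # so this effectively yields the finished object at the end.
    print(parsed)
```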
ccurme authored Jan 9, 2025
1 parent 858f655 commit 815bfa1
Showing 3 changed files with 146 additions and 35 deletions.
117 changes: 85 additions & 32 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -318,8 +318,14 @@ def _convert_delta_to_message_chunk(
def _convert_chunk_to_generation_chunk(
chunk: dict, default_chunk_class: Type, base_generation_info: Optional[Dict]
) -> Optional[ChatGenerationChunk]:
if chunk.get("type") == "content.delta": # from beta.chat.completions.stream
return None
token_usage = chunk.get("usage")
choices = chunk.get("choices", [])
choices = (
chunk.get("choices", [])
# from beta.chat.completions.stream
or chunk.get("chunk", {}).get("choices", [])
)

usage_metadata: Optional[UsageMetadata] = (
_create_usage_metadata(token_usage) if token_usage else None
@@ -660,13 +666,24 @@ def _stream(
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}

if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response_stream = self.root_client.beta.chat.completions.stream(**payload)
context_manager = response_stream
else:
response = self.client.create(**payload)
with response:
if self.include_response_headers:
raw_response = self.client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = self.client.create(**payload)
context_manager = response
with context_manager as response:
is_first_chunk = True
for chunk in response:
if not isinstance(chunk, dict):
@@ -686,6 +703,16 @@
)
is_first_chunk = False
yield generation_chunk
if hasattr(response, "get_final_completion") and "response_format" in payload:
final_completion = response.get_final_completion()
generation_chunk = self._get_generation_chunk_from_completion(
final_completion
)
if run_manager:
run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
yield generation_chunk

def _generate(
self,
@@ -794,13 +821,29 @@ async def _astream(
payload = self._get_request_payload(messages, stop=stop, **kwargs)
default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
base_generation_info = {}
if self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(**payload)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}

if "response_format" in payload:
if self.include_response_headers:
warnings.warn(
"Cannot currently include response headers when response_format is "
"specified."
)
payload.pop("stream")
response_stream = self.root_async_client.beta.chat.completions.stream(
**payload
)
context_manager = response_stream
else:
response = await self.async_client.create(**payload)
async with response:
if self.include_response_headers:
raw_response = await self.async_client.with_raw_response.create(
**payload
)
response = raw_response.parse()
base_generation_info = {"headers": dict(raw_response.headers)}
else:
response = await self.async_client.create(**payload)
context_manager = response
async with context_manager as response:
is_first_chunk = True
async for chunk in response:
if not isinstance(chunk, dict):
@@ -820,6 +863,16 @@
)
is_first_chunk = False
yield generation_chunk
if hasattr(response, "get_final_completion") and "response_format" in payload:
final_completion = await response.get_final_completion()
generation_chunk = self._get_generation_chunk_from_completion(
final_completion
)
if run_manager:
await run_manager.on_llm_new_token(
generation_chunk.text, chunk=generation_chunk
)
yield generation_chunk

async def _agenerate(
self,
@@ -1010,25 +1063,6 @@ def get_num_tokens_from_messages(
num_tokens += 3
return num_tokens

def _should_stream(
self,
*,
async_api: bool,
run_manager: Optional[
Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
] = None,
response_format: Optional[Union[dict, type]] = None,
**kwargs: Any,
) -> bool:
if isinstance(response_format, type) and is_basemodel_subclass(response_format):
# TODO: Add support for streaming with Pydantic response_format.
warnings.warn("Streaming with Pydantic response_format not yet supported.")
return False

return super()._should_stream(
async_api=async_api, run_manager=run_manager, **kwargs
)

@deprecated(
since="0.2.1",
alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
@@ -1531,6 +1565,25 @@ def _filter_disabled_params(self, **kwargs: Any) -> Dict[str, Any]:
filtered[k] = v
return filtered

def _get_generation_chunk_from_completion(
self, completion: openai.BaseModel
) -> ChatGenerationChunk:
"""Get chunk from completion (e.g., from final completion of a stream)."""
chat_result = self._create_chat_result(completion)
chat_message = chat_result.generations[0].message
if isinstance(chat_message, AIMessage):
usage_metadata = chat_message.usage_metadata
else:
usage_metadata = None
message = AIMessageChunk(
content="",
additional_kwargs=chat_message.additional_kwargs,
usage_metadata=usage_metadata,
)
return ChatGenerationChunk(
message=message, generation_info=chat_result.llm_output
)


class ChatOpenAI(BaseChatOpenAI): # type: ignore[override]
"""OpenAI chat model integration.
(second changed file: integration tests for the Azure chat model)
@@ -13,6 +13,7 @@
HumanMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from pydantic import BaseModel

from langchain_openai import AzureChatOpenAI
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
@@ -262,3 +263,37 @@ async def test_json_mode_async(llm: AzureChatOpenAI) -> None:
assert isinstance(full, AIMessageChunk)
assert isinstance(full.content, str)
assert json.loads(full.content) == {"a": 1}


class Foo(BaseModel):
response: str


def test_stream_response_format(llm: AzureChatOpenAI) -> None:
full: Optional[BaseMessageChunk] = None
chunks = []
for chunk in llm.stream("how are ya", response_format=Foo):
chunks.append(chunk)
full = chunk if full is None else full + chunk
assert len(chunks) > 1
assert isinstance(full, AIMessageChunk)
parsed = full.additional_kwargs["parsed"]
assert isinstance(parsed, Foo)
assert isinstance(full.content, str)
parsed_content = json.loads(full.content)
assert parsed.response == parsed_content["response"]


async def test_astream_response_format(llm: AzureChatOpenAI) -> None:
full: Optional[BaseMessageChunk] = None
chunks = []
async for chunk in llm.astream("how are ya", response_format=Foo):
chunks.append(chunk)
full = chunk if full is None else full + chunk
assert len(chunks) > 1
assert isinstance(full, AIMessageChunk)
parsed = full.additional_kwargs["parsed"]
assert isinstance(parsed, Foo)
assert isinstance(full.content, str)
parsed_content = json.loads(full.content)
assert parsed.response == parsed_content["response"]
(third changed file: integration tests for ChatOpenAI)
@@ -1092,14 +1092,37 @@ class Foo(BaseModel):


def test_stream_response_format() -> None:
list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
full: Optional[BaseMessageChunk] = None
chunks = []
for chunk in ChatOpenAI(model="gpt-4o-mini").stream(
"how are ya", response_format=Foo
):
chunks.append(chunk)
full = chunk if full is None else full + chunk
assert len(chunks) > 1
assert isinstance(full, AIMessageChunk)
parsed = full.additional_kwargs["parsed"]
assert isinstance(parsed, Foo)
assert isinstance(full.content, str)
parsed_content = json.loads(full.content)
assert parsed.response == parsed_content["response"]


async def test_astream_response_format() -> None:
async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
full: Optional[BaseMessageChunk] = None
chunks = []
async for chunk in ChatOpenAI(model="gpt-4o-mini").astream(
"how are ya", response_format=Foo
):
pass
chunks.append(chunk)
full = chunk if full is None else full + chunk
assert len(chunks) > 1
assert isinstance(full, AIMessageChunk)
parsed = full.additional_kwargs["parsed"]
assert isinstance(parsed, Foo)
assert isinstance(full.content, str)
parsed_content = json.loads(full.content)
assert parsed.response == parsed_content["response"]


@pytest.mark.parametrize("use_max_completion_tokens", [True, False])