From bbc3e3b2cfc20381b212f8fcb463cab56946ab0c Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 10 Jan 2025 18:24:11 -0800 Subject: [PATCH] openai: disable streaming for o1 by default (#29147) Currently 400s https://community.openai.com/t/streaming-support-for-o1-o1-2024-12-17-resulting-in-400-unsupported-value/1085043 o1-mini and o1-preview stream fine --- .../openai/langchain_openai/chat_models/base.py | 9 +++++++++ .../integration_tests/chat_models/test_base.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index 624e465ded9c5..3e77e6463391a 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -562,6 +562,15 @@ def validate_temperature(cls, values: Dict[str, Any]) -> Any: values["temperature"] = 1 return values + @model_validator(mode="before") + @classmethod + def validate_disable_streaming(cls, values: Dict[str, Any]) -> Any: + """Disable streaming if n > 1.""" + model = values.get("model_name") or values.get("model") or "" + if model == "o1" and values.get("disable_streaming") is None: + values["disable_streaming"] = True + return values + @model_validator(mode="after") def validate_environment(self) -> Self: """Validate that api key and python package exists in environment.""" diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index d116688ebef61..49d894603d726 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -1192,3 +1192,19 @@ def test_o1(use_max_completion_tokens: bool) -> None: assert isinstance(response, AIMessage) assert isinstance(response.content, str) assert response.content.upper() == response.content + + +@pytest.mark.scheduled +def test_o1_doesnt_stream() -> None: + """ + When this starts failing, remove the `disable_streaming` validator in + `BaseChatOpenAI` + """ + with pytest.raises(openai.BadRequestError): + list(ChatOpenAI(model="o1", disable_streaming=False).stream("how are you")) + + +@pytest.mark.scheduled +def test_o1_stream_default_works() -> None: + result = list(ChatOpenAI(model="o1").stream("say 'hi'")) + assert len(result) > 0