standard-tests[patch]: fix oai usage metadata test (#27122)

supported_usage_metadata_details now returns a Dict keyed by "invoke" and "stream" instead of a flat List, so each integration can declare separately which usage-metadata details it reports when invoking versus streaming. The OpenAI standard tests declare no detail support for streaming, and the test_chat_openai_extra_kwargs test is removed.
baskaryan authored Oct 4, 2024
1 parent 827bdf4 commit bd5b335
Showing 5 changed files with 52 additions and 64 deletions.
26 changes: 16 additions & 10 deletions libs/partners/anthropic/tests/integration_tests/test_standard.py
@@ -1,7 +1,7 @@
 """Standard LangChain interface tests"""
 
 from pathlib import Path
-from typing import List, Literal, Type, cast
+from typing import Dict, List, Literal, Type, cast
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage
@@ -36,16 +36,22 @@ def supports_anthropic_inputs(self) -> bool:
     @property
     def supported_usage_metadata_details(
         self,
-    ) -> List[
-        Literal[
-            "audio_input",
-            "audio_output",
-            "reasoning_output",
-            "cache_read_input",
-            "cache_creation_input",
-        ]
+    ) -> Dict[
+        Literal["invoke", "stream"],
+        List[
+            Literal[
+                "audio_input",
+                "audio_output",
+                "reasoning_output",
+                "cache_read_input",
+                "cache_creation_input",
+            ]
+        ],
     ]:
-        return ["cache_read_input", "cache_creation_input"]
+        return {
+            "invoke": ["cache_read_input", "cache_creation_input"],
+            "stream": ["cache_read_input", "cache_creation_input"],
+        }
 
     def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
         llm = ChatAnthropic(
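For orientation, here is a minimal sketch of how a provider's test class satisfies the new contract. The subclass name and returned values are illustrative assumptions, not part of this commit; a real subclass must also implement the suite's other required properties (e.g. chat_model_class):

from typing import Dict, List, Literal

from langchain_standard_tests.integration_tests import ChatModelIntegrationTests


class TestMyProviderStandard(ChatModelIntegrationTests):  # hypothetical provider
    @property
    def supported_usage_metadata_details(
        self,
    ) -> Dict[
        Literal["invoke", "stream"],
        List[
            Literal[
                "audio_input",
                "audio_output",
                "reasoning_output",
                "cache_read_input",
                "cache_creation_input",
            ]
        ],
    ]:
        # Declare cache-read support for invoke only; the stream-side
        # checks for this detail will then be skipped.
        return {"invoke": ["cache_read_input"], "stream": []}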
File 2 of 5:
@@ -238,30 +238,6 @@ class Person(BaseModel):
     assert isinstance(generation, AIMessage)
 
 
-def test_chat_openai_extra_kwargs() -> None:
-    """Test extra kwargs to chat openai."""
-    # Check that foo is saved in extra_kwargs.
-    llm = ChatOpenAI(foo=3, max_tokens=10)  # type: ignore[call-arg]
-    assert llm.max_tokens == 10
-    assert llm.model_kwargs == {"foo": 3}
-
-    # Test that if extra_kwargs are provided, they are added to it.
-    llm = ChatOpenAI(foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
-    assert llm.model_kwargs == {"foo": 3, "bar": 2}
-
-    # Test that if provided twice it errors
-    with pytest.raises(ValueError):
-        ChatOpenAI(foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
-
-    # Test that if explicit param is specified in kwargs it errors
-    with pytest.raises(ValueError):
-        ChatOpenAI(model_kwargs={"temperature": 0.2})
-
-    # Test that "model" cannot be specified in kwargs
-    with pytest.raises(ValueError):
-        ChatOpenAI(model_kwargs={"model": "gpt-3.5-turbo-instruct"})
-
-
 @pytest.mark.scheduled
 def test_openai_streaming() -> None:
     """Test streaming tokens from OpenAI."""
File 3 of 5:
@@ -1,7 +1,7 @@
 """Standard LangChain interface tests"""
 
 from pathlib import Path
-from typing import List, Literal, Type, cast
+from typing import Dict, List, Literal, Type, cast
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import AIMessage
@@ -28,16 +28,19 @@ def supports_image_inputs(self) -> bool:
     @property
     def supported_usage_metadata_details(
         self,
-    ) -> List[
-        Literal[
-            "audio_input",
-            "audio_output",
-            "reasoning_output",
-            "cache_read_input",
-            "cache_creation_input",
-        ]
+    ) -> Dict[
+        Literal["invoke", "stream"],
+        List[
+            Literal[
+                "audio_input",
+                "audio_output",
+                "reasoning_output",
+                "cache_read_input",
+                "cache_creation_input",
+            ]
+        ],
     ]:
-        return ["reasoning_output", "cache_read_input"]
+        return {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
 
     def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
         with open(REPO_ROOT_DIR / "README.md", "r") as f:
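For context on what these declarations gate: the details live on AIMessage.usage_metadata, populated by the provider. A rough sketch of the invoke-side lookup (the model name is an assumption, not part of this commit):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o1-mini")  # assumed reasoning-capable model
msg = llm.invoke("What is 3^3?")

usage = msg.usage_metadata  # may be None if the provider reports no usage
if usage is not None:
    # Per-detail breakdowns are optional nested dicts.
    reasoning = usage.get("output_token_details", {}).get("reasoning")
    cache_read = usage.get("input_token_details", {}).get("cache_read")
    print(f"reasoning: {reasoning}, cache_read: {cache_read}")

The empty "stream" list above encodes that these detail fields are not expected when the same model is streamed.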
File 4 of 5:
@@ -151,25 +151,25 @@ def test_usage_metadata(self, model: BaseChatModel) -> None:
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)
 
-        if "audio_input" in self.supported_usage_metadata_details:
+        if "audio_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_input()
             assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)  # type: ignore[index]
-        if "audio_output" in self.supported_usage_metadata_details:
+        if "audio_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_audio_output()
             assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)  # type: ignore[index]
-        if "reasoning_output" in self.supported_usage_metadata_details:
+        if "reasoning_output" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_reasoning_output()
             assert isinstance(
                 msg.usage_metadata["output_token_details"]["reasoning"],  # type: ignore[index]
                 int,
             )
-        if "cache_read_input" in self.supported_usage_metadata_details:
+        if "cache_read_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_read_input()
             assert isinstance(
                 msg.usage_metadata["input_token_details"]["cache_read"],  # type: ignore[index]
                 int,
             )
-        if "cache_creation_input" in self.supported_usage_metadata_details:
+        if "cache_creation_input" in self.supported_usage_metadata_details["invoke"]:
             msg = self.invoke_with_cache_creation_input()
             assert isinstance(
                 msg.usage_metadata["input_token_details"]["cache_creation"],  # type: ignore[index]
@@ -189,25 +189,25 @@ def test_usage_metadata_streaming(self, model: BaseChatModel) -> None:
         assert isinstance(full.usage_metadata["output_tokens"], int)
         assert isinstance(full.usage_metadata["total_tokens"], int)
 
-        if "audio_input" in self.supported_usage_metadata_details:
+        if "audio_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_input(stream=True)
             assert isinstance(msg.usage_metadata["input_token_details"]["audio"], int)  # type: ignore[index]
-        if "audio_output" in self.supported_usage_metadata_details:
+        if "audio_output" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_audio_output(stream=True)
             assert isinstance(msg.usage_metadata["output_token_details"]["audio"], int)  # type: ignore[index]
-        if "reasoning_output" in self.supported_usage_metadata_details:
+        if "reasoning_output" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_reasoning_output(stream=True)
             assert isinstance(
                 msg.usage_metadata["output_token_details"]["reasoning"],  # type: ignore[index]
                 int,
             )
-        if "cache_read_input" in self.supported_usage_metadata_details:
+        if "cache_read_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_cache_read_input(stream=True)
             assert isinstance(
                 msg.usage_metadata["input_token_details"]["cache_read"],  # type: ignore[index]
                 int,
             )
-        if "cache_creation_input" in self.supported_usage_metadata_details:
+        if "cache_creation_input" in self.supported_usage_metadata_details["stream"]:
             msg = self.invoke_with_cache_creation_input(stream=True)
             assert isinstance(
                 msg.usage_metadata["input_token_details"]["cache_creation"],  # type: ignore[index]
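The `full` message asserted on here is accumulated from stream chunks; a minimal, self-contained sketch of that pattern (the model choice and the stream_usage flag are assumptions):

from langchain_openai import ChatOpenAI

model = ChatOpenAI(model="gpt-4o-mini", stream_usage=True)  # placeholder model
full = None
for chunk in model.stream("Hello"):
    # AIMessageChunk implements "+", merging content and usage_metadata.
    full = chunk if full is None else full + chunk

assert full is not None and full.usage_metadata is not None
print(full.usage_metadata["total_tokens"])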
File 5 of 5:
@@ -2,7 +2,7 @@
 
 import os
 from abc import abstractmethod
-from typing import Any, List, Literal, Optional, Tuple, Type
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type
 from unittest import mock
 
 import pytest
@@ -141,16 +141,19 @@ def supports_image_tool_message(self) -> bool:
     @property
     def supported_usage_metadata_details(
         self,
-    ) -> List[
-        Literal[
-            "audio_input",
-            "audio_output",
-            "reasoning_output",
-            "cache_read_input",
-            "cache_creation_input",
-        ]
+    ) -> Dict[
+        Literal["invoke", "stream"],
+        List[
+            Literal[
+                "audio_input",
+                "audio_output",
+                "reasoning_output",
+                "cache_read_input",
+                "cache_creation_input",
+            ]
+        ],
     ]:
-        return []
+        return {"invoke": [], "stream": []}
 
 
 class ChatModelUnitTests(ChatModelTests):
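To make the motivation concrete: the old flat list could not express "supported on invoke but not on stream", which is exactly the OpenAI case in this commit; the per-mode dict can. A self-contained illustration:

# Old contract: one flat list shared by the invoke and stream tests.
old = ["reasoning_output", "cache_read_input"]
assert "reasoning_output" in old  # the same answer applied to both paths

# New contract: per-mode lists make invoke-only support expressible.
new = {"invoke": ["reasoning_output", "cache_read_input"], "stream": []}
assert "reasoning_output" in new["invoke"]
assert "reasoning_output" not in new["stream"]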
