Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw committed Dec 12, 2023
1 parent cb5077e commit ce3339b
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 61 deletions.
75 changes: 31 additions & 44 deletions libs/partners/google/langchain_google/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,31 +51,33 @@

if TYPE_CHECKING:
import google.generativeai as genai

IMAGE_TYPES = ()
try:
import PIL.Image
from PIL import Image

IMAGE_TYPES = IMAGE_TYPES + (PIL.Image.Image,)
except ImportError:
PIL = None
Image = None

try:
import IPython.display

IMAGE_TYPES = IMAGE_TYPES + (IPython.display.Image,)
except ImportError:
IPython = None
pass


class ChatGoogleGeminiError(Exception):
class ChatGoogleGenerativeAIError(Exception):
"""
Custom exception class for errors associated with the `Google Gemini` API.
Custom exception class for errors associated with the `Google GenAI` API.
This exception is raised when there are specific issues related to the
Google Gemini API usage in the ChatGoogleGemini class, such as unsupported
Google genai API usage in the ChatGoogleGenerativeAI class, such as unsupported
message types or roles.
"""

pass


def _create_retry_decorator() -> Callable[[Any], Any]:
"""
Expand Down Expand Up @@ -125,42 +127,27 @@ def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
Any: The result from the chat generation method.
"""
retry_decorator = _create_retry_decorator()
from google.api_core.exceptions import InvalidArgument # type: ignore

@retry_decorator
def _chat_with_retry(**kwargs: Any) -> Any:
return generation_method(**kwargs)
try:
return generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
f"Invalid argument provided to Gemini: {e}"
) from e
except Exception as e:
raise e

return _chat_with_retry(**kwargs)


async def achat_with_retry(*, generation_method: Awaitable, **kwargs: Any) -> Any:
"""
Asynchronously executes a chat generation method with retry logic.
Similar to `chat_with_retry`, this function applies a retry decorator for
asynchronous chat generation methods. It handles retries for tasks like
generating responses from a language model.
Args:
generation_method (Awaitable): The async chat generation method to be executed.
**kwargs (Any): Additional keyword arguments to pass to the generation method.
Returns:
Any: The result from the async chat generation method.
"""
retry_decorator = _create_retry_decorator()

@retry_decorator
async def _achat_with_retry(**kwargs: Any) -> Any:
return await generation_method(**kwargs)

return await _achat_with_retry(**kwargs)


def _get_role(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
if message.role not in ("user", "model"):
raise ChatGoogleGeminiError(
raise ChatGoogleGenerativeAIError(
"Gemini only supports user and model roles when"
" providing it with Chat messages."
)
Expand All @@ -171,7 +158,7 @@ def _get_role(message: BaseMessage) -> str:
return "model"
else:
# TODO: Gemini doesn't seem to have a concept of system messages yet.
raise ChatGoogleGeminiError(
raise ChatGoogleGenerativeAIError(
f"Message of '{message.type}' type not supported by Gemini."
" Please only provide it with Human or AI (user/assistant) messages."
)
Expand All @@ -181,6 +168,10 @@ def _is_openai_parts_format(part: dict) -> bool:
return "type" in part


def _is_vision_model(model: str):
return "vision" in model


def _is_url(s: str) -> bool:
try:
result = urlparse(s)
Expand Down Expand Up @@ -219,7 +210,7 @@ def _url_to_pil(image_source: str) -> Image:
"with `pip install pillow`"
)
try:
if isinstance(image_source, (Image.Image, IPython.display.Image)):
if isinstance(image_source, IMAGE_TYPES):
return image_source
elif _is_url(image_source):
if image_source.startswith("gs://"):
Expand Down Expand Up @@ -264,7 +255,6 @@ def _convert_to_parts(
f"Unrecognized message image format: {img_url}"
)
img_url = img_url["url"]

parts.append({"inline_data": _url_to_pil(img_url)})
else:
raise ValueError(f"Unrecognized message part type: {part['type']}")
Expand All @@ -277,7 +267,7 @@ def _convert_to_parts(
else:
# TODO: Maybe some of Google's native stuff
# would hit this branch.
raise ChatGoogleGeminiError(
raise ChatGoogleGenerativeAIError(
"Gemini only supports text and inline_data parts."
)
return parts
Expand All @@ -289,7 +279,6 @@ def _messages_to_genai_contents(
"""Converts a list of messages into a Gemini API google content dicts."""

messages: List[genai.types.MessageDict] = []

for i, message in enumerate(input_messages):
role = _get_role(message)
if isinstance(message.content, str):
Expand All @@ -300,7 +289,7 @@ def _messages_to_genai_contents(
if i > 0:
# Cannot have multiple messages from the same role in a row.
if role == messages[-2]["role"]:
raise ChatGoogleGeminiError(
raise ChatGoogleGenerativeAIError(
"Cannot have multiple messages from the same role in a row."
" Consider merging them into a single message with multiple"
f" parts.\nReceived: {messages}"
Expand All @@ -327,7 +316,7 @@ def _parts_to_content(parts: List[genai.types.PartType]) -> Union[List[dict], st
)
else:
# TODO: Handle inline_data if that's a thing?
raise ChatGoogleGeminiError(f"Unexpected part type. {part}")
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
return messages


Expand Down Expand Up @@ -450,14 +439,12 @@ def validate_environment(cls, values: Dict) -> Dict:

genai.configure(api_key=google_api_key)
except ImportError:
raise ChatGoogleGeminiError(
raise ChatGoogleGenerativeAIError(
"Could not import google.generativeai python package. "
"Please install it with `pip install google-generativeai`"
)

values["client"] = genai
genai.count_text_tokens()

if (
values.get("temperature") is not None
and not 0 <= values["temperature"] <= 1
Expand All @@ -470,7 +457,7 @@ def validate_environment(cls, values: Dict) -> Dict:
if values.get("top_k") is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
model = values["model"]
values["_generative_model"] = genai.GenerativeModel(model=model)
values["_generative_model"] = genai.GenerativeModel(model_name=model)
return values

@property
Expand Down
69 changes: 68 additions & 1 deletion libs/partners/google/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/partners/google/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = {path = "../../core", develop = true}
pillow = "^10.1.0"

[tool.ruff]
select = [
Expand Down
Loading

0 comments on commit ce3339b

Please sign in to comment.