Skip to content

Commit

Permalink
fix: update for lint about chat_models.py
Browse files Browse the repository at this point in the history
  • Loading branch information
nobu007 committed Aug 18, 2024
1 parent 4c49a20 commit 73f56e7
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 37 deletions.
8 changes: 4 additions & 4 deletions libs/genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,17 +818,17 @@ class Joke(BaseModel):
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
google_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Google AI API key.
"""Google AI API key.
If not specified will be read from env var ``GOOGLE_API_KEY``."""
default_metadata: Sequence[Tuple[str, str]] = Field(
default_factory=list
) #: :meta private:

convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
Gemini does not support system messages; any unsupported messages will
raise an error."""

class Config:
Expand Down
11 changes: 6 additions & 5 deletions libs/genai/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Test ChatGoogleGenerativeAI chat model."""

import asyncio
import json
from typing import Generator, List, Optional, Type

import pytest
from google.generativeai.types import ( # type: ignore[import]
HarmBlockThreshold,
HarmCategory,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
Expand All @@ -18,11 +23,7 @@
from langchain_core.tools import tool
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_google_genai import (
ChatGoogleGenerativeAI,
HarmBlockThreshold,
HarmCategory,
)
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError

_MODEL = "models/gemini-1.0-pro-001" # TODO: Use nano when it's available.
Expand Down
3 changes: 2 additions & 1 deletion libs/genai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from langchain_core.pydantic_v1 import SecretStr

from langchain_google_genai._common import GoogleGenerativeAIError
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_invalid_api_key_error_handling() -> None:
"""Test error handling with an invalid API key."""
with pytest.raises(GoogleGenerativeAIError):
GoogleGenerativeAIEmbeddings(
model=_MODEL, google_api_key="invalid_key"
model=_MODEL, google_api_key=SecretStr("invalid_key")
).embed_query("Hello world")


Expand Down
8 changes: 6 additions & 2 deletions libs/genai/tests/integration_tests/test_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from typing import Generator

import pytest
from google.generativeai.types import ( # type: ignore[import]
HarmBlockThreshold,
HarmCategory,
)
from langchain_core.outputs import LLMResult

from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory
from langchain_google_genai import GoogleGenerativeAI

model_names = ["models/text-bison-001", "gemini-pro"]

Expand All @@ -23,7 +27,7 @@ def test_google_generativeai_call(model_name: str) -> None:
if model_name:
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name)
else:
llm = GoogleGenerativeAI(max_output_tokens=10)
llm = GoogleGenerativeAI(max_output_tokens=10, model=model_names[0])
output = llm("Say foo:")
assert isinstance(output, str)
assert llm._llm_type == "google_palm"
Expand Down
47 changes: 29 additions & 18 deletions libs/genai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_integration_initialization() -> None:
"""Test chat model initialization."""
llm = ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
api_key=SecretStr("..."),
top_k=2,
top_p=1,
temperature=0.7,
Expand All @@ -77,7 +77,7 @@ def test_integration_initialization() -> None:

llm = ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
api_key=SecretStr("..."),
max_output_tokens=10,
)
ls_params = llm._get_ls_params()
Expand All @@ -91,11 +91,10 @@ def test_integration_initialization() -> None:

ChatGoogleGenerativeAI(
model="gemini-nano",
google_api_key="...",
api_key=SecretStr("..."),
top_k=2,
top_p=1,
temperature=0.7,
candidate_count=2,
)


Expand All @@ -104,19 +103,23 @@ def test_initialization_inside_threadpool() -> None:
# thread pool executor easiest way to create one
with ThreadPoolExecutor() as executor:
executor.submit(
ChatGoogleGenerativeAI, model="gemini-nano", google_api_key="secret-api-key"
ChatGoogleGenerativeAI,
model="gemini-nano",
api_key=SecretStr("secret-api-key"),
).result()


def test_initalization_without_async() -> None:
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
chat = ChatGoogleGenerativeAI(
model="gemini-nano", api_key=SecretStr("secret-api-key")
)
assert chat.async_client is None


def test_initialization_with_async() -> None:
async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
return ChatGoogleGenerativeAI(
model="gemini-nano", google_api_key="secret-api-key"
model="gemini-nano", api_key=SecretStr("secret-api-key")
)

loop = asyncio.get_event_loop()
Expand All @@ -125,12 +128,16 @@ async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:


def test_api_key_is_string() -> None:
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
chat = ChatGoogleGenerativeAI(
model="gemini-nano", api_key=SecretStr("secret-api-key")
)
assert isinstance(chat.google_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key")
chat = ChatGoogleGenerativeAI(
model="gemini-nano", api_key=SecretStr("secret-api-key")
)
print(chat.google_api_key, end="") # noqa: T201
captured = capsys.readouterr()

Expand Down Expand Up @@ -271,18 +278,22 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None:
)
mock_client.return_value.generate_content = mock_generate_content
api_endpoint = "http://127.0.0.1:8000/ai"
params = {
"google_api_key": "[secret]",
"client_options": {"api_endpoint": api_endpoint},
"transport": "rest",
"additional_headers": headers,
}
param_api_key = "[secret]"
param_secret_api_key = SecretStr(param_api_key)
param_client_options = {"api_endpoint": api_endpoint}
param_transport = "rest"

with patch(
"langchain_google_genai._genai_extension.v1betaGenerativeServiceClient",
mock_client,
):
chat = ChatGoogleGenerativeAI(model="gemini-pro", **params)
chat = ChatGoogleGenerativeAI(
model="gemini-pro",
api_key=param_secret_api_key,
client_options=param_client_options,
transport=param_transport,
additional_headers=headers,
)

expected_default_metadata: tuple = ()
if not headers:
Expand All @@ -297,12 +308,12 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None:
assert response.content == "test response"

mock_client.assert_called_once_with(
transport=params["transport"],
transport=param_transport,
client_options=ANY,
client_info=ANY,
)
call_client_options = mock_client.call_args_list[0].kwargs["client_options"]
assert call_client_options.api_key == params["google_api_key"]
assert call_client_options.api_key == param_api_key
assert call_client_options.api_endpoint == api_endpoint
call_client_info = mock_client.call_args_list[0].kwargs["client_info"]
assert "langchain-google-genai" in call_client_info.user_agent
Expand Down
14 changes: 7 additions & 7 deletions libs/genai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_integration_initialization() -> None:
) as mock_prediction_service:
_ = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="...",
google_api_key=SecretStr("..."),
)
mock_prediction_service.assert_called_once()
client_info = mock_prediction_service.call_args.kwargs["client_info"]
Expand All @@ -34,7 +34,7 @@ def test_integration_initialization() -> None:
) as mock_prediction_service:
_ = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="...",
google_api_key=SecretStr("..."),
task_type="retrieval_document",
)
mock_prediction_service.assert_called_once()
Expand All @@ -43,15 +43,15 @@ def test_integration_initialization() -> None:
def test_api_key_is_string() -> None:
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="secret-api-key",
google_api_key=SecretStr("secret-api-key"),
)
assert isinstance(embeddings.google_api_key, SecretStr)


def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
embeddings = GoogleGenerativeAIEmbeddings(
model="models/embedding-001",
google_api_key="secret-api-key",
google_api_key=SecretStr("secret-api-key"),
)
print(embeddings.google_api_key, end="") # noqa: T201
captured = capsys.readouterr()
Expand All @@ -71,7 +71,7 @@ def test_embed_query() -> None:

llm = GoogleGenerativeAIEmbeddings(
model="models/embedding-test",
google_api_key="test-key",
google_api_key=SecretStr("test-key"),
task_type="classification",
)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_embed_documents() -> None:

llm = GoogleGenerativeAIEmbeddings(
model="models/embedding-test",
google_api_key="test-key",
google_api_key=SecretStr("test-key"),
)

llm.embed_documents(["test text", "test text2"], titles=["title1", "title2"])
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_embed_documents_with_numerous_texts() -> None:

llm = GoogleGenerativeAIEmbeddings(
model="models/embedding-test",
google_api_key="test-key",
google_api_key=SecretStr("test-key"),
)

llm.embed_documents(
Expand Down

0 comments on commit 73f56e7

Please sign in to comment.