From 73f56e7865f9903cb063c4dd3b8687d42244465d Mon Sep 17 00:00:00 2001 From: jinno Date: Wed, 24 Jul 2024 20:37:05 +0900 Subject: [PATCH] fix: update for lint about chat_models.py --- .../langchain_google_genai/chat_models.py | 8 ++-- .../integration_tests/test_chat_models.py | 11 +++-- .../integration_tests/test_embeddings.py | 3 +- .../tests/integration_tests/test_llms.py | 8 +++- .../tests/unit_tests/test_chat_models.py | 47 ++++++++++++------- .../genai/tests/unit_tests/test_embeddings.py | 14 +++--- 6 files changed, 54 insertions(+), 37 deletions(-) diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 7233f170..85032191 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -818,8 +818,8 @@ class Joke(BaseModel): client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private: google_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """Google AI API key. - + """Google AI API key. + If not specified will be read from env var ``GOOGLE_API_KEY``.""" default_metadata: Sequence[Tuple[str, str]] = Field( default_factory=list @@ -827,8 +827,8 @@ class Joke(BaseModel): convert_system_message_to_human: bool = False """Whether to merge any leading SystemMessage into the following HumanMessage. - - Gemini does not support system messages; any unsupported messages will + + Gemini does not support system messages; any unsupported messages will raise an error.""" class Config: diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py index 3e3afded..60bee7b5 100644 --- a/libs/genai/tests/integration_tests/test_chat_models.py +++ b/libs/genai/tests/integration_tests/test_chat_models.py @@ -1,9 +1,14 @@ """Test ChatGoogleGenerativeAI chat model.""" + import asyncio import json from typing import Generator, List, Optional, Type import pytest +from google.generativeai.types import ( # type: ignore[import] + HarmBlockThreshold, + HarmCategory, +) from langchain_core.language_models import BaseChatModel from langchain_core.messages import ( AIMessage, @@ -18,11 +23,7 @@ from langchain_core.tools import tool from langchain_standard_tests.integration_tests import ChatModelIntegrationTests -from langchain_google_genai import ( - ChatGoogleGenerativeAI, - HarmBlockThreshold, - HarmCategory, -) +from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError _MODEL = "models/gemini-1.0-pro-001" # TODO: Use nano when it's available. diff --git a/libs/genai/tests/integration_tests/test_embeddings.py b/libs/genai/tests/integration_tests/test_embeddings.py index 63b9a344..1412ab18 100644 --- a/libs/genai/tests/integration_tests/test_embeddings.py +++ b/libs/genai/tests/integration_tests/test_embeddings.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from langchain_core.pydantic_v1 import SecretStr from langchain_google_genai._common import GoogleGenerativeAIError from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings @@ -75,7 +76,7 @@ def test_invalid_api_key_error_handling() -> None: """Test error handling with an invalid API key.""" with pytest.raises(GoogleGenerativeAIError): GoogleGenerativeAIEmbeddings( - model=_MODEL, google_api_key="invalid_key" + model=_MODEL, google_api_key=SecretStr("invalid_key") ).embed_query("Hello world") diff --git a/libs/genai/tests/integration_tests/test_llms.py b/libs/genai/tests/integration_tests/test_llms.py index dc9d3309..5e35c382 100644 --- a/libs/genai/tests/integration_tests/test_llms.py +++ b/libs/genai/tests/integration_tests/test_llms.py @@ -7,9 +7,13 @@ from typing import Generator import pytest +from google.generativeai.types import ( # type: ignore[import] + HarmBlockThreshold, + HarmCategory, +) from langchain_core.outputs import LLMResult -from langchain_google_genai import GoogleGenerativeAI, HarmBlockThreshold, HarmCategory +from langchain_google_genai import GoogleGenerativeAI model_names = ["models/text-bison-001", "gemini-pro"] @@ -23,7 +27,7 @@ def test_google_generativeai_call(model_name: str) -> None: if model_name: llm = GoogleGenerativeAI(max_output_tokens=10, model=model_name) else: - llm = GoogleGenerativeAI(max_output_tokens=10) + llm = GoogleGenerativeAI(max_output_tokens=10, model=model_names[0]) output = llm("Say foo:") assert isinstance(output, str) assert llm._llm_type == "google_palm" diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index c6e40261..242d4401 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -61,7 +61,7 @@ def test_integration_initialization() -> None: """Test chat model initialization.""" llm = ChatGoogleGenerativeAI( model="gemini-nano", - google_api_key="...", + api_key=SecretStr("..."), top_k=2, top_p=1, temperature=0.7, @@ -77,7 +77,7 @@ def test_integration_initialization() -> None: llm = ChatGoogleGenerativeAI( model="gemini-nano", - google_api_key="...", + api_key=SecretStr("..."), max_output_tokens=10, ) ls_params = llm._get_ls_params() @@ -91,11 +91,10 @@ def test_integration_initialization() -> None: ChatGoogleGenerativeAI( model="gemini-nano", - google_api_key="...", + api_key=SecretStr("..."), top_k=2, top_p=1, temperature=0.7, - candidate_count=2, ) @@ -104,19 +103,23 @@ def test_initialization_inside_threadpool() -> None: # thread pool executor easiest way to create one with ThreadPoolExecutor() as executor: executor.submit( - ChatGoogleGenerativeAI, model="gemini-nano", google_api_key="secret-api-key" + ChatGoogleGenerativeAI, + model="gemini-nano", + api_key=SecretStr("secret-api-key"), ).result() def test_initalization_without_async() -> None: - chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") + chat = ChatGoogleGenerativeAI( + model="gemini-nano", api_key=SecretStr("secret-api-key") + ) assert chat.async_client is None def test_initialization_with_async() -> None: async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI: return ChatGoogleGenerativeAI( - model="gemini-nano", google_api_key="secret-api-key" + model="gemini-nano", api_key=SecretStr("secret-api-key") ) loop = asyncio.get_event_loop() @@ -125,12 +128,16 @@ async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI: def test_api_key_is_string() -> None: - chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") + chat = ChatGoogleGenerativeAI( + model="gemini-nano", api_key=SecretStr("secret-api-key") + ) assert isinstance(chat.google_api_key, SecretStr) def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: - chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") + chat = ChatGoogleGenerativeAI( + model="gemini-nano", api_key=SecretStr("secret-api-key") + ) print(chat.google_api_key, end="") # noqa: T201 captured = capsys.readouterr() @@ -271,18 +278,22 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: ) mock_client.return_value.generate_content = mock_generate_content api_endpoint = "http://127.0.0.1:8000/ai" - params = { - "google_api_key": "[secret]", - "client_options": {"api_endpoint": api_endpoint}, - "transport": "rest", - "additional_headers": headers, - } + param_api_key = "[secret]" + param_secret_api_key = SecretStr(param_api_key) + param_client_options = {"api_endpoint": api_endpoint} + param_transport = "rest" with patch( "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", mock_client, ): - chat = ChatGoogleGenerativeAI(model="gemini-pro", **params) + chat = ChatGoogleGenerativeAI( + model="gemini-pro", + api_key=param_secret_api_key, + client_options=param_client_options, + transport=param_transport, + additional_headers=headers, + ) expected_default_metadata: tuple = () if not headers: @@ -297,12 +308,12 @@ def test_additional_headers_support(headers: Optional[Dict[str, str]]) -> None: assert response.content == "test response" mock_client.assert_called_once_with( - transport=params["transport"], + transport=param_transport, client_options=ANY, client_info=ANY, ) call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == params["google_api_key"] + assert call_client_options.api_key == param_api_key assert call_client_options.api_endpoint == api_endpoint call_client_info = mock_client.call_args_list[0].kwargs["client_info"] assert "langchain-google-genai" in call_client_info.user_agent diff --git a/libs/genai/tests/unit_tests/test_embeddings.py b/libs/genai/tests/unit_tests/test_embeddings.py index 285f40d5..ef901bd9 100644 --- a/libs/genai/tests/unit_tests/test_embeddings.py +++ b/libs/genai/tests/unit_tests/test_embeddings.py @@ -21,7 +21,7 @@ def test_integration_initialization() -> None: ) as mock_prediction_service: _ = GoogleGenerativeAIEmbeddings( model="models/embedding-001", - google_api_key="...", + google_api_key=SecretStr("..."), ) mock_prediction_service.assert_called_once() client_info = mock_prediction_service.call_args.kwargs["client_info"] @@ -34,7 +34,7 @@ def test_integration_initialization() -> None: ) as mock_prediction_service: _ = GoogleGenerativeAIEmbeddings( model="models/embedding-001", - google_api_key="...", + google_api_key=SecretStr("..."), task_type="retrieval_document", ) mock_prediction_service.assert_called_once() @@ -43,7 +43,7 @@ def test_integration_initialization() -> None: def test_api_key_is_string() -> None: embeddings = GoogleGenerativeAIEmbeddings( model="models/embedding-001", - google_api_key="secret-api-key", + google_api_key=SecretStr("secret-api-key"), ) assert isinstance(embeddings.google_api_key, SecretStr) @@ -51,7 +51,7 @@ def test_api_key_is_string() -> None: def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: embeddings = GoogleGenerativeAIEmbeddings( model="models/embedding-001", - google_api_key="secret-api-key", + google_api_key=SecretStr("secret-api-key"), ) print(embeddings.google_api_key, end="") # noqa: T201 captured = capsys.readouterr() @@ -71,7 +71,7 @@ def test_embed_query() -> None: llm = GoogleGenerativeAIEmbeddings( model="models/embedding-test", - google_api_key="test-key", + google_api_key=SecretStr("test-key"), task_type="classification", ) @@ -102,7 +102,7 @@ def test_embed_documents() -> None: llm = GoogleGenerativeAIEmbeddings( model="models/embedding-test", - google_api_key="test-key", + google_api_key=SecretStr("test-key"), ) llm.embed_documents(["test text", "test text2"], titles=["title1", "title2"]) @@ -140,7 +140,7 @@ def test_embed_documents_with_numerous_texts() -> None: llm = GoogleGenerativeAIEmbeddings( model="models/embedding-test", - google_api_key="test-key", + google_api_key=SecretStr("test-key"), ) llm.embed_documents(