From 88976563623f5a33fc2360272d2732aebead9cf0 Mon Sep 17 00:00:00 2001 From: Johannes Date: Wed, 13 Nov 2024 19:33:28 +0100 Subject: [PATCH] feat: add clients initialization in validator --- .../langchain_community/chat_models/writer.py | 27 ++++++++++++++++--- .../unit_tests/chat_models/test_writer.py | 10 ++----- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/libs/community/langchain_community/chat_models/writer.py b/libs/community/langchain_community/chat_models/writer.py index 6dc2becb54444..4d76a6a5f4dea 100644 --- a/libs/community/langchain_community/chat_models/writer.py +++ b/libs/community/langchain_community/chat_models/writer.py @@ -38,8 +38,9 @@ ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable +from langchain_core.utils import get_from_dict_or_env from langchain_core.utils.function_calling import convert_to_openai_tool -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator logger = logging.getLogger(__name__) @@ -66,8 +67,10 @@ class ChatWriter(BaseChatModel): ) """ - client: Any = Field(exclude=True) #: :meta private: - async_client: Any = Field(exclude=True) #: :meta private: + client: Any = Field(default=None, exclude=True) #: :meta private: + async_client: Any = Field(default=None, exclude=True) #: :meta private: + writer_api_key: Optional[SecretStr] = Field(default=None) + """Writer API key.""" model_name: str = Field(default="palmyra-x-004", alias="model") """Model name to use.""" temperature: float = 0.7 @@ -106,6 +109,24 @@ def _default_params(self) -> Dict[str, Any]: **self.model_kwargs, } + @model_validator(mode="before") + def validate_environment(self, values: Dict) -> Any: + """Validates that api key is passed and creates Writer clients.""" + try: + from writerai import AsyncClient, Client + except ImportError as e: + raise ImportError( + "Could not import writerai python package. " + "Please install it with `pip install writerai`." + ) from e + + if not (values["client"] and values["async_client"]): + api_key = get_from_dict_or_env(values, "api_key", "WRITER_API_KEY") + values["client"] = Client(api_key=api_key) + values["async_client"] = AsyncClient(api_key=api_key) + + return values + def _create_chat_result(self, response: Any) -> ChatResult: generations = [] for choice in response.choices: diff --git a/libs/community/tests/unit_tests/chat_models/test_writer.py b/libs/community/tests/unit_tests/chat_models/test_writer.py index 21a942f7a5f7a..8b20a10ef7adf 100644 --- a/libs/community/tests/unit_tests/chat_models/test_writer.py +++ b/libs/community/tests/unit_tests/chat_models/test_writer.py @@ -106,6 +106,7 @@ def __init__( self.choices = choices +@pytest.mark.requires("writer-sdk") class TestChatWriterCustom: """Test case for ChatWriter""" @@ -114,24 +115,16 @@ def test_writer_model_param(self) -> None: test_cases: List[dict] = [ { "model_name": "palmyra-x-004", - "client": MagicMock(), - "async_client": AsyncMock(), }, { "model": "palmyra-x-004", - "client": MagicMock(), - "async_client": AsyncMock(), }, { "model_name": "palmyra-x-004", - "client": MagicMock(), - "async_client": AsyncMock(), }, { "model": "palmyra-x-004", "temperature": 0.5, - "client": MagicMock(), - "async_client": AsyncMock(), }, ] @@ -423,6 +416,7 @@ class GetWeather(BaseModel): assert response.tool_calls[0]["args"]["location"] == "London" +@pytest.mark.requires("writer-sdk") class TestChatWriterStandart(ChatModelUnitTests): """Test case for ChatWriter that inherits from standard LangChain tests."""