diff --git a/libs/partners/mistralai/langchain_mistralai/embeddings.py b/libs/partners/mistralai/langchain_mistralai/embeddings.py
index bb792139fc8d8..7c14536a31503 100644
--- a/libs/partners/mistralai/langchain_mistralai/embeddings.py
+++ b/libs/partners/mistralai/langchain_mistralai/embeddings.py
@@ -4,6 +4,7 @@
 from typing import Iterable, List
 
 import httpx
+from httpx import Response
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import (
     secret_from_env,
@@ -15,6 +16,7 @@
     SecretStr,
     model_validator,
 )
+from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
 from tokenizers import Tokenizer  # type: ignore
 from typing_extensions import Self
 
@@ -58,6 +60,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             The number of times to retry a request if it fails.
         timeout: int
             The number of seconds to wait for a response before timing out.
+        wait_time: int
+            The number of seconds to wait before retrying a request in case of 429 error.
         max_concurrent_requests: int
             The maximum number of concurrent requests to make to the Mistral API.
 
@@ -128,6 +132,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     endpoint: str = "https://api.mistral.ai/v1/"
     max_retries: int = 5
     timeout: int = 120
+    wait_time: int = 30
    max_concurrent_requests: int = 64
 
     tokenizer: Tokenizer = Field(default=None)
@@ -215,16 +220,26 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
             List of embeddings, one for each text.
         """
         try:
-            batch_responses = (
-                self.client.post(
+            batch_responses = []
+
+            @retry(
+                retry=retry_if_exception_type(httpx.TimeoutException),
+                wait=wait_fixed(self.wait_time),
+                stop=stop_after_attempt(self.max_retries),
+            )
+            def _embed_batch(batch: List[str]) -> Response:
+                response = self.client.post(
                     url="/embeddings",
                     json=dict(
                         model=self.model,
                         input=batch,
                     ),
                 )
-                for batch in self._get_batches(texts)
-            )
+                response.raise_for_status()
+                return response
+
+            for batch in self._get_batches(texts):
+                batch_responses.append(_embed_batch(batch))
             return [
                 list(map(float, embedding_obj["embedding"]))
                 for response in batch_responses
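
The sketch below is not part of the patch; it only illustrates how the new field is intended to be used. It assumes the public MistralAIEmbeddings export from langchain_mistralai and a MISTRAL_API_KEY set in the environment:

    from langchain_mistralai import MistralAIEmbeddings

    # With this change, a batch request that raises httpx.TimeoutException is
    # retried up to `max_retries` times, waiting `wait_time` seconds between attempts.
    embeddings = MistralAIEmbeddings(
        model="mistral-embed",
        max_retries=5,   # number of retry attempts (existing field)
        wait_time=10,    # seconds to wait between retries (new field)
    )
    vectors = embeddings.embed_documents(["first document", "second document"])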