Skip to content

Commit

Permalink
mistral[minor]: Added Retrying Mechanism in case of Request Rate Limit Error for `MistralAIEmbeddings` (#27818)
Browse files Browse the repository at this point in the history

- **Description:** In the event of a Rate Limit Error from the
MistralAI server, parsing the response JSON raises a KeyError. To address
this, a simple retry mechanism has been implemented to handle cases where
the request rate limit is exceeded.
  - **Issue:** #27790

---------

Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
keenborder786 and eyurtsev authored Dec 11, 2024
1 parent df5008f commit a37afbe
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions libs/partners/mistralai/langchain_mistralai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Iterable, List

import httpx
from httpx import Response
from langchain_core.embeddings import Embeddings
from langchain_core.utils import (
secret_from_env,
Expand All @@ -15,6 +16,7 @@
SecretStr,
model_validator,
)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
from tokenizers import Tokenizer # type: ignore
from typing_extensions import Self

Expand Down Expand Up @@ -58,6 +60,8 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
The number of times to retry a request if it fails.
timeout: int
The number of seconds to wait for a response before timing out.
wait_time: int
The number of seconds to wait before retrying a request in case of 429 error.
max_concurrent_requests: int
The maximum number of concurrent requests to make to the Mistral API.
Expand Down Expand Up @@ -128,6 +132,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
endpoint: str = "https://api.mistral.ai/v1/"
max_retries: int = 5
timeout: int = 120
wait_time: int = 30
max_concurrent_requests: int = 64
tokenizer: Tokenizer = Field(default=None)

Expand Down Expand Up @@ -215,16 +220,26 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
List of embeddings, one for each text.
"""
try:
batch_responses = (
self.client.post(
batch_responses = []

@retry(
retry=retry_if_exception_type(httpx.TimeoutException),
wait=wait_fixed(self.wait_time),
stop=stop_after_attempt(self.max_retries),
)
def _embed_batch(batch: List[str]) -> Response:
response = self.client.post(
url="/embeddings",
json=dict(
model=self.model,
input=batch,
),
)
for batch in self._get_batches(texts)
)
response.raise_for_status()
return response

for batch in self._get_batches(texts):
batch_responses.append(_embed_batch(batch))
return [
list(map(float, embedding_obj["embedding"]))
for response in batch_responses
Expand Down

0 comments on commit a37afbe

Please sign in to comment.