update the gcp name logic
GaspardBT committed Nov 8, 2024
1 parent dee4be6 commit fcf4cd4
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions packages/mistralai_gcp/src/mistralai_gcp/sdk.py
@@ -1,7 +1,7 @@
"""Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT."""

import json
from typing import Optional, Union
from typing import Optional, Tuple, Union

import google.auth
import google.auth.credentials
@@ -20,6 +20,19 @@
from .utils.logger import Logger, NoOpLogger
from .utils.retries import RetryConfig

LEGACY_MODEL_ID_FORMAT = {
    "codestral-2405": "codestral@2405",
    "mistral-large-2407": "mistral-large@2407",
    "mistral-nemo-2407": "mistral-nemo@2407",
}

def get_model_info(model: str) -> Tuple[str, str]:
    # If the model requires the legacy format, use it; otherwise keep the name as-is.
    model_id = LEGACY_MODEL_ID_FORMAT.get(model, model)
    # The bare model name drops the trailing version segment (e.g. "codestral-2405" -> "codestral").
    model = "-".join(model.split("-")[:-1])
    return model, model_id



class MistralGoogleCloud(BaseSDK):
    r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""
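
For reference, a rough sketch of what the new helper returns, assuming the get_model_info function added above ("mistral-large-2411" is only an illustrative name that is not in LEGACY_MODEL_ID_FORMAT):

get_model_info("codestral-2405")      # -> ("codestral", "codestral@2405"): legacy "@" id from the table
get_model_info("mistral-large-2407")  # -> ("mistral-large", "mistral-large@2407")
get_model_info("mistral-large-2411")  # -> ("mistral-large", "mistral-large-2411"): names outside the table pass through unchanged
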
@@ -140,28 +153,24 @@ def __init__(self, region: str, project_id: str):
    def before_request(
        self, hook_ctx, request: httpx.Request
    ) -> Union[httpx.Request, Exception]:
        # The goal of this function is to template in the region, project, model, and model_version into the URL path
        # The goal of this function is to template in the region, project and model into the URL path
        # We do this here so that the API remains more user-friendly
        model = None
        model_version = None
        model_id = None
        new_content = None
        if request.content:
            parsed = json.loads(request.content.decode("utf-8"))
            model_raw = parsed.get("model")
            model = "-".join(model_raw.split("-")[:-1])
            model_version = model_raw.split("-")[-1]
            parsed["model"] = model
            model_name, model_id = get_model_info(model_raw)
            parsed["model"] = model_name
            new_content = json.dumps(parsed).encode("utf-8")

        if model == "":
        if model_id == "":
            raise models.SDKError("model must be provided")

        if model_version is None:
            raise models.SDKError("model_version must be provided")

        stream = "streamRawPredict" in request.url.path
        specifier = "streamRawPredict" if stream else "rawPredict"
        url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model}@{model_version}:{specifier}"
        url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"

        headers = dict(request.headers)
        # Delete content-length header as it will need to be recalculated
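
A minimal sketch of the path the hook now builds, assuming hypothetical project and region values (placeholders, not values from this commit) and the get_model_info helper added above:

project_id, region = "my-project", "europe-west4"  # placeholder values for illustration
model_name, model_id = get_model_info("codestral-2405")  # -> ("codestral", "codestral@2405")
specifier = "streamRawPredict"  # the hook picks "rawPredict" for non-streaming endpoints
url = f"/v1/projects/{project_id}/locations/{region}/publishers/mistralai/models/{model_id}:{specifier}"
# -> /v1/projects/my-project/locations/europe-west4/publishers/mistralai/models/codestral@2405:streamRawPredict
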
