From fcf4cd4c1683e724848c6d8a4c315bafa38f088f Mon Sep 17 00:00:00 2001 From: gaspardBT Date: Fri, 8 Nov 2024 14:45:53 +0100 Subject: [PATCH] update the gcp name logic --- .../mistralai_gcp/src/mistralai_gcp/sdk.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/packages/mistralai_gcp/src/mistralai_gcp/sdk.py b/packages/mistralai_gcp/src/mistralai_gcp/sdk.py index bb4c1de..7e7adbd 100644 --- a/packages/mistralai_gcp/src/mistralai_gcp/sdk.py +++ b/packages/mistralai_gcp/src/mistralai_gcp/sdk.py @@ -1,7 +1,7 @@ """Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT.""" import json -from typing import Optional, Union +from typing import Optional, Tuple, Union import google.auth import google.auth.credentials @@ -20,6 +20,19 @@ from .utils.logger import Logger, NoOpLogger from .utils.retries import RetryConfig +LEGACY_MODEL_ID_FORMAT = { + "codestral-2405": "codestral@2405", + "mistral-large-2407": "mistral-large@2407", + "mistral-nemo-2407": "mistral-nemo@2407", +} + +def get_model_info(model: str) -> Tuple[str,str]: + # if the model requires the legacy format, use it, else do nothing. + model_id = LEGACY_MODEL_ID_FORMAT.get(model, model) + model = "-".join(model.split("-")[:-1]) + return model, model_id + + class MistralGoogleCloud(BaseSDK): r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. 
Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it.""" @@ -140,28 +153,24 @@ def __init__(self, region: str, project_id: str): def before_request( self, hook_ctx, request: httpx.Request ) -> Union[httpx.Request, Exception]: - # The goal of this function is to template in the region, project, model, and model_version into the URL path + # The goal of this function is to template in the region, project and model into the URL path # We do this here so that the API remains more user-friendly - model = None - model_version = None + model_id = None new_content = None if request.content: parsed = json.loads(request.content.decode("utf-8")) model_raw = parsed.get("model") - model = "-".join(model_raw.split("-")[:-1]) - model_version = model_raw.split("-")[-1] - parsed["model"] = model + model_name, model_id = get_model_info(model_raw) + parsed["model"] = model_name new_content = json.dumps(parsed).encode("utf-8") - if model == "": + if model_id == "": raise models.SDKError("model must be provided") - if model_version is None: - raise models.SDKError("model_version must be provided") stream = "streamRawPredict" in request.url.path specifier = "streamRawPredict" if stream else "rawPredict" - url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model}@{model_version}:{specifier}" + url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}" headers = dict(request.headers) # Delete content-length header as it will need to be recalculated