From ecd8ae5c53e05c09cae237215acaf60839e37edd Mon Sep 17 00:00:00 2001
From: Miseon
Date: Fri, 8 Sep 2023 15:35:23 -0700
Subject: [PATCH] notebook to support llama model versions with ds mii
 container (#2633)

* notebook to support llama model versions with ds mii container

* update formatting

---------

Co-authored-by: Miseon Park
---
 .../llama-safe-online-deployment.ipynb        | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/sdk/python/foundation-models/system/inference/text-generation/llama-safe-online-deployment.ipynb b/sdk/python/foundation-models/system/inference/text-generation/llama-safe-online-deployment.ipynb
index 5ab2370012d..065fc7e901e 100644
--- a/sdk/python/foundation-models/system/inference/text-generation/llama-safe-online-deployment.ipynb
+++ b/sdk/python/foundation-models/system/inference/text-generation/llama-safe-online-deployment.ipynb
@@ -53,8 +53,6 @@
     "deployment_name = \"llama\"  # Replace with your deployment name, lower case only!!!\n",
     "sku_name = \"Standard_NC24s_v3\"  # Name of the sku(instance type) Check the model-list(can be found in the parent folder(inference)) to get the most optimal sku for your model (Default: Standard_DS2_v2)\n",
     "\n",
-    "environment_name = f\"{endpoint_name}-env\"  # Replace with your environment name\n",
-    "\n",
     "# The severity level that will trigger response be blocked\n",
     "# Please reference Azure AI content documentation for more details\n",
     "# https://learn.microsoft.com/en-us/azure/cognitive-services/content-safety/concepts/harm-categories\n",
@@ -298,16 +296,20 @@
     "    reg_client.models.list(model_name)\n",
     ")  # list available versions of the model\n",
     "llama_model = None\n",
-    "hf_tgi = False  # If text-generation-inference (hf container) is supported for model\n",
+    "\n",
+    "# If specific inference environments are tagged for the model\n",
+    "inference_envs_exist = False\n",
     "\n",
     "if len(version_list) == 0:\n",
     "    raise Exception(f\"No model named {model_name} found in registry\")\n",
     "else:\n",
     "    model_version = version_list[0].version\n",
     "    llama_model = reg_client.models.get(model_name, model_version)\n",
-    "    if \"inference_supported_envs\" in llama_model.tags:\n",
-    "        if \"hf_tgi\" in llama_model.tags[\"inference_supported_envs\"]:\n",
-    "            hf_tgi = True\n",
+    "    if (\n",
+    "        \"inference_supported_envs\" in llama_model.tags\n",
+    "        and len(llama_model.tags[\"inference_supported_envs\"]) >= 1\n",
+    "    ):\n",
+    "        inference_envs_exist = True\n",
     "    print(\n",
     "        f\"Using model name: {llama_model.name}, version: {llama_model.version}, id: {llama_model.id} for inferencing\"\n",
     "    )"
@@ -457,10 +459,10 @@
     "    ProbeSettings,\n",
     ")\n",
     "\n",
-    "# For HF TGI inferencing, the scoring script is baked into the container\n",
+    "# For inference environments HF TGI and DS MII, the scoring script is baked into the container\n",
     "code_configuration = (\n",
     "    CodeConfiguration(code=\"./llama-files/score/default/\", scoring_script=\"score.py\")\n",
-    "    if not hf_tgi\n",
+    "    if not inference_envs_exist\n",
     "    else None\n",
     ")\n",
     "\n",
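
For readers of this patch, here is a minimal, self-contained Python sketch of the gating
logic it introduces. The tag name "inference_supported_envs", the flag
"inference_envs_exist", and the scoring-script path come from the diff above; "FakeModel"
and the plain-dict stand-in for azure.ai.ml.entities.CodeConfiguration are hypothetical,
used only so the snippet runs without an Azure ML workspace:

    from dataclasses import dataclass, field

    # Hypothetical stand-in for the azure.ai.ml Model entity; only the
    # attributes the patched notebook cell touches are modeled here.
    @dataclass
    class FakeModel:
        name: str
        version: str
        id: str
        tags: dict = field(default_factory=dict)

    def has_inference_envs(model) -> bool:
        # Mirrors the patched cell: the flag is set whenever the model carries
        # a non-empty "inference_supported_envs" tag, regardless of which
        # container (HF TGI, DS MII, ...) the tag lists.
        return (
            "inference_supported_envs" in model.tags
            and len(model.tags["inference_supported_envs"]) >= 1
        )

    # When a supported inference environment is tagged, the scoring script is
    # baked into the container, so no code configuration is passed; otherwise
    # the deployment points at the notebook's local score.py.
    model = FakeModel(
        name="Llama-2-7b",
        version="1",
        id="azureml://registries/example/models/Llama-2-7b/versions/1",
        tags={"inference_supported_envs": ["hf_tgi", "ds_mii"]},
    )
    code_configuration = (
        None
        if has_inference_envs(model)
        else {"code": "./llama-files/score/default/", "scoring_script": "score.py"}
    )
    print(has_inference_envs(model), code_configuration)

The effect of renaming hf_tgi to inference_envs_exist is visible here: the deployment no
longer checks for one specific container tag, it only asks whether any supported
inference environment is tagged on the model version.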