-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Examples for online inference for LLaVA (#2671)
* Rough version of ITT notebook. * Working version of the notebook. * Minor improvements. * Rough version of the CLI demo. * Reformat with black. * Remove mode parameter in call to prepare_data.py. * Make code cells easier to read. * Minor readability improvement. * Format with black. * Minor fixes and reformatting. * Reformat with black.
- Loading branch information
1 parent
aee8bfe
commit 3ce6d0a
Showing
4 changed files
with
649 additions
and
0 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
cli/foundation-models/system/inference/image-text-to-text/deploy-online.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Azure ML managed online deployment spec for the image-text-to-text (LLaVA) sample.
$schema: https://azuremlschemas.azureedge.net/latest/managedOnlineDeployment.schema.json
name: image-text-to-text-mlflow-deploy
# GPU SKU used to host the model.
instance_type: Standard_NC6s_v3
instance_count: 1
# Generous probe settings: large models can take a long time to load.
liveness_probe:
  initial_delay: 180
  period: 180
  failure_threshold: 49
  timeout: 299
request_settings:
  # 90 s per scoring request.
  request_timeout_ms: 90000
80 changes: 80 additions & 0 deletions
80
...undation-models/system/inference/image-text-to-text/image-text-to-text-online-endpoint.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
set -x
# The commands in this file map to steps in this notebook: <TODO>
# The sample scoring file available in the same folder as the above notebook

# script inputs
registry_name="azureml"
subscription_id="<SUBSCRIPTION_ID>"
resource_group_name="<RESOURCE_GROUP>"
workspace_name="<WORKSPACE_NAME>"

# This is the model from system registry that needs to be deployed
model_name="llava-7"
model_label="latest"

# Timestamp suffix keeps endpoint names unique across runs.
version=$(date +%s)
endpoint_name="image-text-to-text-$version"

# Todo: fetch deployment_sku from the min_inference_sku tag of the model
deployment_sku="Standard_NC6s_v3"

# Prepare data for deployment
data_path="./data_online"
python ./prepare_data.py --data_path $data_path
# sample_request_data
sample_request_data="$data_path/fridgeObjects/sample_request_data.json"

# 1. Setup pre-requisites
# NOTE: fixed missing space after `[` in the second test — without it the
# shell treats `["$resource_group_name"` as a command name and the check fails.
if [ "$subscription_id" = "<SUBSCRIPTION_ID>" ] || \
   [ "$resource_group_name" = "<RESOURCE_GROUP>" ] || \
   [ "$workspace_name" = "<WORKSPACE_NAME>" ]; then
    echo "Please update the script with the subscription_id, resource_group_name and workspace_name"
    exit 1
fi

az account set -s $subscription_id
workspace_info="--resource-group $resource_group_name --workspace-name $workspace_name"

# 2. Check if the model exists in the registry
# Need to confirm model show command works for registries outside the tenant (aka system registry)
if ! az ml model show --name $model_name --label $model_label --registry-name $registry_name
then
    echo "Model $model_name:$model_label does not exist in registry $registry_name"
    exit 1
fi

# Get the latest model version
model_version=$(az ml model show --name $model_name --label $model_label --registry-name $registry_name --query version --output tsv)

# 3. Deploy the model to an endpoint
# Create online endpoint
az ml online-endpoint create --name $endpoint_name $workspace_info || {
    echo "endpoint create failed"; exit 1;
}

# Deploy model from registry to endpoint in workspace
az ml online-deployment create --file deploy-online.yaml $workspace_info --all-traffic --set \
  endpoint_name=$endpoint_name model=azureml://registries/$registry_name/models/$model_name/versions/$model_version \
  instance_type=$deployment_sku || {
    echo "deployment create failed"; exit 1;
}

# 4. Try a sample scoring request

# Check if scoring data file exists (quoted: the path could contain spaces)
if [ -f "$sample_request_data" ]; then
    # -e makes echo interpret the \n escapes; plain echo would print them literally
    echo -e "Invoking endpoint $endpoint_name with $sample_request_data\n\n"
else
    echo "Scoring file $sample_request_data does not exist"
    exit 1
fi

az ml online-endpoint invoke --name $endpoint_name --request-file $sample_request_data $workspace_info || {
    echo "endpoint invoke failed"; exit 1;
}

# 5. Delete the endpoint and sample_request_data.json
az ml online-endpoint delete --name $endpoint_name $workspace_info --yes || {
    echo "endpoint delete failed"; exit 1;
}

rm "$sample_request_data"
108 changes: 108 additions & 0 deletions
108
cli/foundation-models/system/inference/image-text-to-text/prepare_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import argparse | ||
import base64 | ||
import json | ||
import os | ||
import shutil | ||
import urllib.request | ||
from zipfile import ZipFile | ||
|
||
|
||
# Direct question to ask the model; sent in the "direct_question" column of
# the sample scoring request built by prepare_data_for_online_inference.
DIRECT_QUESTION = "What's in this image?"
|
||
|
||
def download_and_unzip(dataset_parent_dir: str) -> str:
    """Download the fridgeObjects image dataset and unzip it.

    :param dataset_parent_dir: dataset parent directory to which dataset will be downloaded
    :type dataset_parent_dir: str
    :return: path of the extracted dataset directory
    :rtype: str
    """
    # Create directory, if it does not exist
    os.makedirs(dataset_parent_dir, exist_ok=True)

    # download data
    download_url = "https://cvbp-secondary.z19.web.core.windows.net/datasets/image_classification/fridgeObjects.zip"
    print(f"Downloading data from {download_url}")

    # Extract current dataset name from dataset url
    dataset_name = os.path.basename(download_url).split(".")[0]
    # Get dataset path for later use
    dataset_dir = os.path.join(dataset_parent_dir, dataset_name)

    # Remove any stale copy so extraction starts clean.
    if os.path.exists(dataset_dir):
        shutil.rmtree(dataset_dir)

    # Get the name of zip file
    data_file = os.path.join(dataset_parent_dir, f"{dataset_name}.zip")

    # Download data from public url
    urllib.request.urlretrieve(download_url, filename=data_file)

    # extract files (renamed handle: `zip` shadowed the builtin)
    with ZipFile(data_file, "r") as zip_file:
        print("extracting files...")
        zip_file.extractall(path=dataset_parent_dir)
        print("done")
    # delete zip file to free disk space
    os.remove(data_file)
    return dataset_dir
|
||
|
||
def read_image(image_path: str) -> bytes:
    """Read image from path.

    :param image_path: image path
    :type image_path: str
    :return: image in bytes format
    :rtype: bytes
    """
    with open(image_path, "rb") as image_file:
        contents = image_file.read()
    return contents
|
||
|
||
def prepare_data_for_online_inference(dataset_dir: str) -> None:
    """Prepare request json for online inference.

    :param dataset_dir: dataset directory
    :type dataset_dir: str
    """
    # Use a single known image from the dataset as the sample payload.
    sample_image = os.path.join(dataset_dir, "milk_bottle", "99.jpg")
    encoded_image = base64.encodebytes(read_image(sample_image)).decode("utf-8")

    # Single-row MLflow-style request: empty "prompt", fixed "direct_question".
    request_json = {
        "input_data": {
            "columns": ["image", "prompt", "direct_question"],
            "index": [0],
            "data": [[encoded_image, "", DIRECT_QUESTION]],
        }
    }

    request_file_name = os.path.join(dataset_dir, "sample_request_data.json")
    with open(request_file_name, "w") as request_file:
        json.dump(request_json, request_file)
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: download the dataset next to this script, then write
    # the sample scoring request file into it.
    parser = argparse.ArgumentParser(
        description="Prepare data for text generation based on images and text."
    )
    parser.add_argument(
        "--data_path", type=str, default="data", help="Dataset location"
    )

    # parse_known_args tolerates extra args passed by calling scripts.
    args, unknown = parser.parse_known_args()

    # Resolve data_path relative to this script, not the current directory.
    dataset_dir = download_and_unzip(
        dataset_parent_dir=os.path.join(
            os.path.dirname(os.path.realpath(__file__)), args.data_path
        ),
    )

    prepare_data_for_online_inference(dataset_dir=dataset_dir)
Oops, something went wrong.