Enable unit test for wasi-nn WinML backend.
This test was disabled because the GitHub Actions Windows Server image
doesn't include the desktop experience. However, a standalone WinML binary
can be downloaded from the ONNX Runtime project.

The wasi-nn WinML backend and the ONNX Runtime backend now share the same
test code, as they accept the same input and are expected to produce the
same result.

This change also makes the wasi-nn WinML backend a default feature.
jianjunz committed Apr 29, 2024
1 parent 1cf3a9d commit cd8785a
Showing 10 changed files with 52 additions and 106 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/main.yml
@@ -564,6 +564,15 @@ jobs:
- uses: abrown/install-openvino-action@v8
if: runner.arch == 'X64'

# Install WinML for testing wasi-nn WinML backend. WinML is only available
# on Windows clients and Windows Server with desktop experience enabled.
# The GitHub Actions Windows Server image doesn't have desktop experience
# enabled, so we download the standalone library from the ONNX Runtime project.
- uses: nuget/setup-nuget@v2
if: matrix.os == 'windows-latest'
- run: nuget install Microsoft.AI.MachineLearning
if: matrix.os == 'windows-latest'

# Fix an ICE for now in gcc when compiling zstd with debuginfo (??)
- run: echo CFLAGS=-g0 >> $GITHUB_ENV
if: matrix.target == 'x86_64-pc-windows-gnu'
2 changes: 1 addition & 1 deletion crates/test-programs/Cargo.toml
@@ -19,4 +19,4 @@ getrandom = "0.2.9"
futures = { workspace = true, default-features = false, features = ['alloc'] }
url = { workspace = true }
sha2 = "0.10.2"
base64 = "0.21.0"
base64 = "0.21.0"
19 changes: 13 additions & 6 deletions crates/test-programs/src/bin/nn_image_classification_onnx.rs
@@ -15,23 +15,30 @@ pub fn main() -> Result<()> {
context
);

// Prepare WASI-NN tensor - Tensor data is always a bytes vector
// Prepare the WASI-NN tensor. Tensor data is a bytes vector converted from
// 000000062808.jpg from the ImageNet database, pre-processed as described in
// https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing.
// Load a tensor that precisely matches the graph input tensor
let data = fs::read("fixture/tensor.bgr").unwrap();
let data = fs::read("fixture/000000062808.rgb").unwrap();
println!("[ONNX] Read input tensor, size in bytes: {}", data.len());

context.set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &data)?;

// Execute the inferencing
context.compute()?;
println!("[ONNX] Executed graph inference");

// Retrieve the output.
// To simplify the test and avoid unnecessary dependencies, the probability is
// not processed with the softmax function.
let mut output_buffer = vec![0f32; 1000];
context.get_output(0, &mut output_buffer[..])?;
println!(
"[ONNX] Found results, sorted top 5: {:?}",
&sort_results(&output_buffer)[..5]
);
let sorted = sort_results(&output_buffer);
println!("[ONNX] Found results, sorted top 5: {:?}", &sorted[..5]);

// Index 963 is meat loaf, meatloaf.
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(sorted[0].0, 963);

Ok(())
}
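
For reference, here is a minimal sketch of what the `sort_results` helper used above could look like, assuming it pairs each class index with its score and sorts by score in descending order (hypothetical; the real helper lives in the shared test-programs code). Since softmax is monotonic, skipping it does not change the ranking, so the top-1 assertion is unaffected.

```
// Hypothetical sketch of `sort_results`, consistent with how the test uses it:
// pair each class index with its score and sort by score, descending.
// Softmax is monotonic, so omitting it does not change this ordering.
fn sort_results(buffer: &[f32]) -> Vec<(usize, f32)> {
    let mut results: Vec<(usize, f32)> = buffer.iter().copied().enumerate().collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}
```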
58 changes: 0 additions & 58 deletions crates/test-programs/src/bin/nn_image_classification_winml.rs

This file was deleted.

2 changes: 1 addition & 1 deletion crates/wasi-nn/Cargo.toml
@@ -50,7 +50,7 @@ wasi-common = { workspace = true, features = ["sync"] }
wasmtime = { workspace = true, features = ["cranelift"] }

[features]
default = ["openvino"]
default = ["openvino", "winml"]
# openvino is available on all platforms; it requires OpenVINO to be installed.
openvino = ["dep:openvino"]
# onnx is available on all platforms.
55 changes: 17 additions & 38 deletions crates/wasi-nn/src/testing.rs
@@ -8,7 +8,12 @@

#[allow(unused_imports)]
use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};
use std::{
env, fs,
path::{Path, PathBuf},
process::Command,
sync::Mutex,
};

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
use windows::AI::MachineLearning::{LearningModelDevice, LearningModelDeviceKind};
@@ -50,13 +55,12 @@ pub fn check() -> Result<()> {
#[cfg(feature = "openvino")]
check_openvino_artifacts_are_available()?;

#[cfg(feature = "onnx")]
#[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
check_onnx_artifacts_are_available()?;

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
#[cfg(all(feature = "winml", target_os = "windows"))]
{
check_winml_is_available()?;
check_winml_artifacts_are_available()?;
}
Ok(())
}
@@ -108,7 +112,7 @@ fn check_openvino_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
#[cfg(all(feature = "winml", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
match std::panic::catch_unwind(|| {
println!(
@@ -121,59 +125,34 @@ fn check_winml_is_available() -> Result<()> {
}
}

#[cfg(feature = "onnx")]
#[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
fn check_onnx_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();

const OPENVINO_BASE_URL: &str =
"https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet";
const ONNX_BASE_URL: &str =
"https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx?download=";
"https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-10.onnx?download=";

let artifacts_dir = artifacts_dir();
if !artifacts_dir.is_dir() {
fs::create_dir(&artifacts_dir)?;
}

for (from, to) in [
(
[OPENVINO_BASE_URL, "tensor-1x224x224x3-f32.bgr"].join("/"),
"tensor.bgr",
),
(ONNX_BASE_URL.to_string(), "model.onnx"),
] {
for (from, to) in [(ONNX_BASE_URL.to_string(), "model.onnx")] {
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
Ok(())
}

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
let artifacts_dir = artifacts_dir();
if !artifacts_dir.is_dir() {
fs::create_dir(&artifacts_dir)?;
}
const MODEL_URL: &str="https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mobilenet/model/mobilenetv2-12.onnx";
for (from, to) in [(MODEL_URL, "model.onnx")] {
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
// kitten.rgb is converted from https://github.com/microsoft/Windows-Machine-Learning/blob/master/SharedContent/media/kitten_224.png?raw=true.
let tensor_path = env::current_dir()?
// Copy image from source tree to artifact directory.
let image_path = env::current_dir()?
.join("tests")
.join("fixtures")
.join("kitten.rgb");
fs::copy(tensor_path, artifacts_dir.join("kitten.rgb"))?;
.join("000000062808.rgb");
let dest_path = artifacts_dir.join("000000062808.rgb");
fs::copy(&image_path, &dest_path)?;
Ok(())
}
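
The `download` helper referenced above is not shown in this hunk. Below is a minimal sketch of how such a helper could work, assuming it shells out to `curl` (consistent with the `process::Command` import, but hypothetical in its details):

```
use anyhow::{anyhow, Result};
use std::{path::Path, process::Command};

// Hypothetical sketch: fetch `from` into the local file `to` via curl.
// The real `download` helper in testing.rs may differ in flags and error handling.
fn download(from: &str, to: &Path) -> Result<()> {
    let status = Command::new("curl")
        .arg("--location")
        .arg(from)
        .arg("--output")
        .arg(to)
        .status()?;
    if !status.success() {
        return Err(anyhow!("curl failed to fetch {}", from));
    }
    Ok(())
}
```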

4 changes: 2 additions & 2 deletions crates/wasi-nn/tests/all.rs
@@ -97,10 +97,10 @@ fn nn_image_classification_named() {
#[cfg_attr(not(all(feature = "winml", target_os = "windows")), ignore)]
#[test]
fn nn_image_classification_winml() {
#[cfg(feature = "winml")]
#[cfg(all(feature = "winml", target_os = "windows"))]
{
let backend = Backend::from(backend::winml::WinMLBackend::default());
run(NN_IMAGE_CLASSIFICATION_WINML, backend, true).unwrap()
run(NN_IMAGE_CLASSIFICATION_ONNX, backend, true).unwrap()
}
}
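
For comparison, the ONNX Runtime backend's test presumably drives the same NN_IMAGE_CLASSIFICATION_ONNX binary with its own backend. A hypothetical sketch follows; the backend module path and the final `run` argument are assumptions, not part of this diff:

```
#[cfg_attr(not(feature = "onnx"), ignore)]
#[test]
fn nn_image_classification_onnx() {
    #[cfg(feature = "onnx")]
    {
        // Hypothetical: same test program as the WinML test above, but backed by
        // ONNX Runtime, so both backends receive identical input and are expected
        // to produce the same result.
        let backend = Backend::from(backend::onnxruntime::OnnxBackend::default());
        run(NN_IMAGE_CLASSIFICATION_ONNX, backend, false).unwrap()
    }
}
```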

Binary file added crates/wasi-nn/tests/fixtures/000000062808.rgb
Binary file removed crates/wasi-nn/tests/fixtures/kitten.rgb
9 changes: 9 additions & 0 deletions crates/wasi-nn/tests/fixtures/readme.md
@@ -0,0 +1,9 @@
000000062808.rgb is converted from 000000062808.jpg in the ImageNet
database. The image is produced by the following Python code, using
https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/imagenet_preprocess.py:

```
import mxnet
# preprocess_mxnet is provided by the imagenet_preprocess.py script linked above.
from imagenet_preprocess import preprocess_mxnet

image = mxnet.image.imread('000000062808.jpg')
image = preprocess_mxnet(image)
image.asnumpy().tofile('000000062808.rgb')
```
