Enable unit test for wasi-nn WinML backend. #8442

Merged · 1 commit · Jun 11, 2024
12 changes: 12 additions & 0 deletions .github/workflows/main.yml
@@ -682,6 +682,9 @@ jobs:
       matrix:
         feature: ["openvino", "onnx"]
         os: ["ubuntu-latest", "windows-latest"]
+        include:
+          - os: windows-latest
+            feature: winml
     name: Test wasi-nn (${{ matrix.feature }}, ${{ matrix.os }})
     runs-on: ${{ matrix.os }}
     needs: determine
@@ -696,6 +699,15 @@
     - uses: abrown/install-openvino-action@v8
       if: runner.arch == 'X64'

+    # Install WinML for testing the wasi-nn WinML backend. WinML is only
+    # available on Windows clients and on Windows Server with the desktop
+    # experience enabled. The GitHub Actions Windows Server image doesn't
+    # have the desktop experience enabled, so we download the standalone
+    # library from the ONNX Runtime project.
+    - uses: nuget/setup-nuget@v2
+      if: (matrix.os == 'windows-latest') && (matrix.feature == 'winml')
+    - run: nuget install Microsoft.AI.MachineLearning
+      if: (matrix.os == 'windows-latest') && (matrix.feature == 'winml')
+
     # Install Rust targets.
     - run: rustup target add wasm32-wasi
@@ -8,10 +8,13 @@ pub fn main() -> Result<()> {
         .expect("the model file to be mapped to the fixture directory");
     let graph =
         GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;
-    let tensor = fs::read("fixture/tensor.bgr")
+    let tensor = fs::read("fixture/000000062808.rgb")
         .expect("the tensor file to be mapped to the fixture directory");
     let results = classify(graph, tensor)?;
     let top_five = &sort_results(&results)[..5];
+    // 963 is meat loaf, meatloaf.
+    // https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
+    assert_eq!(top_five[0].class_id(), 963);
     println!("found results, sorted top 5: {:?}", top_five);
     Ok(())
 }
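For context, a minimal sketch of the result type the `class_id()` assertion above relies on. The real definition lives in the shared test-programs nn helpers (see the `crates/test-programs/src/nn.rs` diff below); the tuple layout and `class_id` are taken from this diff, while the `probability` accessor is an assumption:

```
#[derive(Debug)]
pub struct InferenceResult(usize, f32);

impl InferenceResult {
    /// The class index into the model's label set (e.g. the ImageNet synset).
    pub fn class_id(&self) -> usize {
        self.0
    }
    /// The probability the model assigned to this class.
    pub fn probability(&self) -> f32 {
        self.1
    }
}
```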
4 changes: 0 additions & 4 deletions crates/test-programs/src/nn.rs
@@ -42,12 +42,8 @@ pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
 /// placing the match probability for each class at the index for that class
 /// (the probability of class `N` is stored at `probabilities[N]`).
 pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
-    // It is unclear why the MobileNet output indices are "off by one" but the
-    // `.skip(1)` below seems necessary to get results that make sense (e.g. 763
-    // = "revolver" vs 762 = "restaurant").
     let mut results: Vec<InferenceResult> = probabilities
         .iter()
-        .skip(1)
> PR author: Removing this line because it's likely a workaround for this specific OpenVINO model only. If mobilenet-v1-0.25-128 is the model used in the OpenVINO test, it may have an additional class 0 for background; its output shape (1, 1001) also shows it has one more value than the ONNX model's (1, 1000).
         .enumerate()
         .map(|(c, p)| InferenceResult(c, *p))
         .collect();
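With the ONNX MobileNetV2 model the output has exactly 1000 values, so index `c` already equals class `c` and no skip is needed. The tail of the function is collapsed in the diff above; a sketch of the complete post-change function, assuming the final step sorts by descending probability:

```
pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = probabilities
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    // Highest probability first; callers take the leading entries as top-N.
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    results
}
```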
2 changes: 1 addition & 1 deletion crates/wasi-nn/Cargo.toml
@@ -50,7 +50,7 @@ wasi-common = { workspace = true, features = ["sync"] }
 wasmtime = { workspace = true, features = ["cranelift"] }

 [features]
-default = ["openvino"]
+default = ["openvino", "winml"]
 # openvino is available on all platforms, it requires openvino installed.
 openvino = ["dep:openvino"]
 # onnx is available on all platforms.
55 changes: 17 additions & 38 deletions crates/wasi-nn/src/testing.rs
@@ -8,7 +8,12 @@

 #[allow(unused_imports)]
 use anyhow::{anyhow, Context, Result};
-use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};
+use std::{
+    env, fs,
+    path::{Path, PathBuf},
+    process::Command,
+    sync::Mutex,
+};

 #[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
 use windows::AI::MachineLearning::{LearningModelDevice, LearningModelDeviceKind};
@@ -50,13 +55,12 @@ pub fn check() -> Result<()> {
     #[cfg(feature = "openvino")]
     check_openvino_artifacts_are_available()?;

-    #[cfg(feature = "onnx")]
+    #[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
     check_onnx_artifacts_are_available()?;

-    #[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
+    #[cfg(all(feature = "winml", target_os = "windows"))]
     {
         check_winml_is_available()?;
-        check_winml_artifacts_are_available()?;
     }
     Ok(())
 }
@@ -108,7 +112,7 @@ fn check_openvino_artifacts_are_available() -> Result<()> {
     Ok(())
 }

-#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
+#[cfg(all(feature = "winml", target_os = "windows"))]
 fn check_winml_is_available() -> Result<()> {
     match std::panic::catch_unwind(|| {
         println!(
@@ -121,59 +125,34 @@ fn check_winml_is_available() -> Result<()> {
     }
 }

-#[cfg(feature = "onnx")]
+#[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
 fn check_onnx_artifacts_are_available() -> Result<()> {
     let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();

-    const OPENVINO_BASE_URL: &str =
-        "https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet";
     const ONNX_BASE_URL: &str =
-        "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx?download=";
+        "https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-10.onnx?download=";

     let artifacts_dir = artifacts_dir();
     if !artifacts_dir.is_dir() {
         fs::create_dir(&artifacts_dir)?;
     }

-    for (from, to) in [
-        (
-            [OPENVINO_BASE_URL, "tensor-1x224x224x3-f32.bgr"].join("/"),
-            "tensor.bgr",
-        ),
-        (ONNX_BASE_URL.to_string(), "model.onnx"),
-    ] {
+    for (from, to) in [(ONNX_BASE_URL.to_string(), "model.onnx")] {
         let local_path = artifacts_dir.join(to);
         if !local_path.is_file() {
             download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
         } else {
             println!("> using cached artifact: {}", local_path.display())
         }
     }
-    Ok(())
-}
-
-#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
-fn check_winml_artifacts_are_available() -> Result<()> {
-    let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
-    let artifacts_dir = artifacts_dir();
-    if !artifacts_dir.is_dir() {
-        fs::create_dir(&artifacts_dir)?;
-    }
-    const MODEL_URL: &str="https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mobilenet/model/mobilenetv2-12.onnx";
-    for (from, to) in [(MODEL_URL, "model.onnx")] {
-        let local_path = artifacts_dir.join(to);
-        if !local_path.is_file() {
-            download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
-        } else {
-            println!("> using cached artifact: {}", local_path.display())
-        }
-    }
-    // kitten.rgb is converted from https://github.com/microsoft/Windows-Machine-Learning/blob/master/SharedContent/media/kitten_224.png?raw=true.
-    let tensor_path = env::current_dir()?
+    // Copy image from source tree to artifact directory.
+    let image_path = env::current_dir()?
         .join("tests")
         .join("fixtures")
-        .join("kitten.rgb");
-    fs::copy(tensor_path, artifacts_dir.join("kitten.rgb"))?;
+        .join("000000062808.rgb");
+    let dest_path = artifacts_dir.join("000000062808.rgb");
+    fs::copy(&image_path, &dest_path)?;
     Ok(())
 }

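The diff collapses the middle of `check_winml_is_available`, so here is a minimal sketch of the complete probe under the same imports (the `windows` crate's `LearningModelDevice` bindings shown earlier in this file); the exact `println!` text and error wording are assumptions:

```
#[cfg(all(feature = "winml", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
    // Creating a WinML device panics or errors when the runtime is missing,
    // so catch_unwind converts that into a checkable result.
    match std::panic::catch_unwind(|| {
        println!(
            "> WinML device: {:?}",
            LearningModelDevice::Create(LearningModelDeviceKind::Default)
        );
    }) {
        Ok(()) => Ok(()),
        Err(panic) => Err(anyhow!("WinML is not available: {:?}", panic)),
    }
}
```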
4 changes: 2 additions & 2 deletions crates/wasi-nn/tests/all.rs
@@ -97,10 +97,10 @@ fn nn_image_classification_named() {
 #[cfg_attr(not(all(feature = "winml", target_os = "windows")), ignore)]
 #[test]
 fn nn_image_classification_winml() {
-    #[cfg(feature = "winml")]
+    #[cfg(all(feature = "winml", target_os = "windows"))]
     {
         let backend = Backend::from(backend::winml::WinMLBackend::default());
-        run(NN_IMAGE_CLASSIFICATION_WINML, backend, true).unwrap()
+        run(NN_IMAGE_CLASSIFICATION_ONNX, backend, true).unwrap()
     }
 }

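One detail worth calling out in the hunk above: the gating happens twice. The outer `#[cfg_attr(..., ignore)]` keeps the test visible (but skipped) in builds without the feature, while the inner `#[cfg]` block is what stops `backend::winml` from being referenced at all on platforms where it doesn't compile. A minimal standalone sketch of the pattern, with a hypothetical test name and body:

```
#[cfg_attr(not(all(feature = "winml", target_os = "windows")), ignore)]
#[test]
fn winml_only_test() {
    // Without this inner cfg, the body would still be type-checked on
    // non-Windows targets and fail to compile.
    #[cfg(all(feature = "winml", target_os = "windows"))]
    {
        // WinML-specific test body goes here.
    }
}
```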
Binary file added: crates/wasi-nn/tests/fixtures/000000062808.rgb (not shown)
Binary file removed: crates/wasi-nn/tests/fixtures/kitten.rgb (not shown)
9 changes: 9 additions & 0 deletions crates/wasi-nn/tests/fixtures/readme.md
@@ -0,0 +1,9 @@
The original image behind 000000062808.rgb is 000000062808.jpg from the
ImageNet database. It was processed with the following Python code, using the
`preprocess_mxnet` helper from
https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/imagenet_preprocess.py

```
import mxnet
from imagenet_preprocess import preprocess_mxnet  # helper from the script linked above

image = mxnet.image.imread('000000062808.jpg')
image = preprocess_mxnet(image)
image.asnumpy().tofile('000000062808.rgb')
```
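On the Rust side, the test program reads this file back verbatim with `fs::read` (see the diff near the top of this PR). A small sketch of what those bytes should look like, assuming the standard MobileNet input of f32 values in NCHW 1x3x224x224 layout — an assumption based on the preprocessing script, not stated in this readme:

```
use std::fs;

fn load_fixture_tensor() -> std::io::Result<Vec<u8>> {
    let bytes = fs::read("fixture/000000062808.rgb")?;
    // 1 batch x 3 channels x 224 x 224 pixels x 4 bytes per f32 = 602,112.
    assert_eq!(bytes.len(), 1 * 3 * 224 * 224 * 4);
    Ok(bytes)
}
```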