Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ort V2 integration #51

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/surrealml_core_onnx_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
deactivate

- name: Run Core Unit Tests
run: cd modules/core && cargo test --features onnx-tests
run: cd modules/core && cargo test --features onnx-tests --lib
2 changes: 1 addition & 1 deletion .github/workflows/surrealml_core_tensorflow_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ jobs:
deactivate

- name: Run Core Unit Tests
run: cd modules/core && cargo test --features tensorflow-tests
run: cd modules/core && cargo test --features tensorflow-tests --lib
2 changes: 1 addition & 1 deletion .github/workflows/surrealml_core_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
deactivate

- name: Run Core Unit Tests
run: cd modules/core && cargo test --features sklearn-tests
run: cd modules/core && cargo test --features sklearn-tests --lib

- name: Run HTTP Transfer Tests
run: cargo test
2 changes: 1 addition & 1 deletion .github/workflows/surrealml_core_torch_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,4 @@ jobs:
deactivate

- name: Run Core Unit Tests
run: cd modules/core && cargo test --features torch-tests
run: cd modules/core && cargo test --features torch-tests --lib
1 change: 1 addition & 0 deletions modules/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ tensorflow-tests = []
[dependencies]
regex = "1.9.3"
ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false }
ort-v2 = { version = "2.0.0-rc.1", package = "ort" }
ndarray = "0.15.6"
once_cell = "1.18.0"
bytes = "1.5.0"
Expand Down
11 changes: 11 additions & 0 deletions modules/core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ let data: ArrayD<f32> = ndarray::arr1(&x).into_dyn();
let output = compute_unit.raw_compute(data, None).unwrap();
```

### V2 Ort
Version 2 of Ort is now supported as well, but this support is in beta because V2 of Ort
itself is still in beta. Everything is the same apart from adding `_v2` to the end of your
execution calls like so:

```rust
let output = compute_unit.buffered_compute_v2(&mut input_values).unwrap();

let output = compute_unit.raw_compute_v2(data, None).unwrap();
```

## ONNX runtime assets

We can find the ONNX assets with the following link:
Expand Down
170 changes: 170 additions & 0 deletions modules/core/src/execution/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::storage::surml_file::SurMlFile;
use std::collections::HashMap;
use ndarray::{ArrayD, CowArray};
use ort::{SessionBuilder, Value, session::Input};
use ort_v2::Session as SessionV2;

use super::onnx_environment::ENVIRONMENT;
use crate::safe_eject;
Expand Down Expand Up @@ -50,6 +51,33 @@ impl <'a>ModelComputation<'a> {
buffer
}

/// Creates a vector of dimensions for the input tensor from the loaded model.
///
/// # Arguments
/// * `input_dims` - The input dimensions from the loaded model, where `-1` marks a
///   dynamic axis (e.g. the batch dimension).
///
/// # Returns
/// A vector of dimensions for the input tensor to be reshaped into from the loaded model.
/// Dynamic axes (`-1`) are collapsed to `1` so a single sample can be fed through.
fn process_input_dims_v2(input_dims: Option<&Vec<i64>>) -> Result<Vec<usize>, SurrealError> {
    let input_dims = match input_dims {
        Some(dims) => dims,
        None => return Err(SurrealError::new(
            String::from("No input dimensions found for the model (V2)"),
            SurrealErrorStatus::Unknown
        ))
    };
    // map dynamic axes (-1) to 1; all other dimensions pass through unchanged
    let buffer = input_dims
        .iter()
        .map(|&dim| if dim == -1 { 1 } else { dim as usize })
        .collect();
    Ok(buffer)
}

/// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values.
///
/// # Arguments
Expand Down Expand Up @@ -106,6 +134,60 @@ impl <'a>ModelComputation<'a> {
return Ok(buffer)
}

/// Performs a raw computation against the loaded model using the V2 (2.0.0-rc) `ort` API.
///
/// # Notes
/// A new session is built from the in-memory model bytes on every call; if this becomes
/// a hotspot the session could be cached alongside the file.
///
/// # Arguments
/// * `tensor` - The input tensor that will be reshaped to the model's input dimensions.
/// * `_dims` - Deprecated: dimensions are now read from the ONNX model itself. Kept so
///   the signature mirrors `raw_compute`.
///
/// # Returns
/// The computed output tensor from the loaded model flattened into a `Vec<f32>`.
pub fn raw_compute_v2(&self, tensor: ArrayD<f32>, _dims: Option<(i32, i32)>) -> Result<Vec<f32>, SurrealError> {
    let session = match SessionV2::builder() {
        Ok(builder) => builder,
        Err(_) => return Err(SurrealError::new(
            String::from("Failed to create a session builder for the model (V2)"),
            SurrealErrorStatus::Unknown
        ))
    }.commit_from_memory(&self.surml_file.model).map_err(|e| {
        SurrealError::new(
            format!("Failed to commit the model to the session (V2): {}", e),
            SurrealErrorStatus::Unknown
        )
    })?;
    // the expected input shape comes from the ONNX file itself, not from the caller
    let unwrapped_dims = ModelComputation::process_input_dims_v2(
        session.inputs[0].input_type.tensor_dimensions()
    )?;

    let tensor = match tensor.into_shape(unwrapped_dims) {
        Ok(tensor) => tensor,
        Err(e) => return Err(SurrealError::new(
            format!("Failed to reshape the input tensor for the model (V2): {}", e),
            SurrealErrorStatus::Unknown
        ))
    };
    let input_values = match ort_v2::inputs![tensor] {
        Ok(inputs) => inputs,
        Err(e) => return Err(SurrealError::new(
            format!("Failed to create input values for the model (V2): {}", e),
            SurrealErrorStatus::Unknown
        ))
    };
    let outputs = safe_eject!(
        session.run(input_values),
        SurrealErrorStatus::Unknown
    );

    // try f32 first, then fall back to i64 outputs (e.g. classifier labels) cast to f32
    let buffer: Vec<f32> = match outputs[0].try_extract_tensor::<f32>() {
        Ok(y) => y.view().iter().copied().collect(),
        Err(_) => {
            let ints = safe_eject!(outputs[0].try_extract_tensor::<i64>(), SurrealErrorStatus::Unknown);
            ints.view().iter().map(|&i| i as f32).collect()
        }
    };
    Ok(buffer)
}

/// Checks the header applying normalisers if present and then performs a raw computation on the loaded model. Will
/// also apply inverse normalisers if present on the outputs.
///
Expand All @@ -129,6 +211,8 @@ impl <'a>ModelComputation<'a> {
}
}
let tensor = self.input_tensor_from_key_bindings(input_values.clone())?;
// dims are None because dims is deprecated as we are now getting the dims from the ONNX file but
// we will keep it here for now to keep the function signature the same and to see if we need it later
let output = self.raw_compute(tensor, None)?;

// if no normaliser is present, return the output
Expand All @@ -152,6 +236,54 @@ impl <'a>ModelComputation<'a> {
return Ok(buffer)
}

/// Checks the header applying normalisers if present and then performs a raw computation on the loaded model. Will
/// also apply inverse normalisers if present on the outputs.
///
/// # Notes
/// This function is fairly coupled and will consider breaking out the functions later on if needed.
///
/// # Arguments
/// * `input_values` - A hashmap of keys and values that will be used to create the input tensor.
///
/// # Returns
/// The computed output tensor from the loaded model.
pub fn buffered_compute_v2(&self, input_values: &mut HashMap<String, f32>) -> Result<Vec<f32>, SurrealError> {
// applying normalisers if present
for (key, value) in &mut *input_values {
let value_ref = value.clone();
match self.surml_file.header.get_normaliser(&key.to_string())? {
Some(normaliser) => {
*value = normaliser.normalise(value_ref);
},
None => {}
}
}
let tensor = self.input_tensor_from_key_bindings(input_values.clone())?;
// dims are None because dims is depcrecated as we are now getting the dims from the ONNX file but
// we will keep it here for now to keep the function signature the same and to see if we need it later
let output = self.raw_compute_v2(tensor, None)?;

// if no normaliser is present, return the output
if self.surml_file.header.output.normaliser == None {
return Ok(output)
}

// apply the normaliser to the output
let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() {
Some(normaliser) => normaliser,
None => return Err(SurrealError::new(
String::from("No normaliser present for output which shouldn't happen as passed initial check for").to_string(),
SurrealErrorStatus::Unknown
))
};
let mut buffer = Vec::with_capacity(output.len());

for value in output {
buffer.push(output_normaliser.inverse_normalise(value));
}
return Ok(buffer)
}

}


Expand All @@ -173,10 +305,16 @@ mod tests {
input_values.insert(String::from("num_floors"), 2.0);

let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap();
let raw_input_two = raw_input.clone();

let output = model_computation.raw_compute(raw_input, Some((1, 2))).unwrap();
assert_eq!(output.len(), 1);
assert_eq!(output[0], 985.57745);

// testing the v2
let output = model_computation.raw_compute_v2(raw_input_two, Some((1, 2))).unwrap();
assert_eq!(output.len(), 1);
assert_eq!(output[0], 985.57745);
}

#[cfg(feature = "sklearn-tests")]
Expand All @@ -193,6 +331,10 @@ mod tests {

let output = model_computation.buffered_compute(&mut input_values).unwrap();
assert_eq!(output.len(), 1);

// testing the v2
let output = model_computation.buffered_compute_v2(&mut input_values).unwrap();
assert_eq!(output.len(), 1);
}

#[cfg(feature = "onnx-tests")]
Expand All @@ -208,10 +350,16 @@ mod tests {
input_values.insert(String::from("num_floors"), 2.0);

let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap();
let raw_input_two = raw_input.clone();

let output = model_computation.raw_compute(raw_input, Some((1, 2))).unwrap();
assert_eq!(output.len(), 1);
assert_eq!(output[0], 985.57745);

// testing the v2
let output = model_computation.raw_compute_v2(raw_input_two, Some((1, 2))).unwrap();
assert_eq!(output.len(), 1);
assert_eq!(output[0], 985.57745);
}

#[cfg(feature = "onnx-tests")]
Expand All @@ -228,6 +376,10 @@ mod tests {

let output = model_computation.buffered_compute(&mut input_values).unwrap();
assert_eq!(output.len(), 1);

// testing the v2
let output = model_computation.buffered_compute_v2(&mut input_values).unwrap();
assert_eq!(output.len(), 1);
}

#[cfg(feature = "torch-tests")]
Expand All @@ -243,9 +395,14 @@ mod tests {
input_values.insert(String::from("num_floors"), 2.0);

let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap();
let raw_input_two = raw_input.clone();

let output = model_computation.raw_compute(raw_input, None).unwrap();
assert_eq!(output.len(), 1);

// testing V2
let output = model_computation.raw_compute_v2(raw_input_two, None).unwrap();
assert_eq!(output.len(), 1);
}

#[cfg(feature = "torch-tests")]
Expand All @@ -262,6 +419,10 @@ mod tests {

let output = model_computation.buffered_compute(&mut input_values).unwrap();
assert_eq!(output.len(), 1);

// testing V2
let output = model_computation.buffered_compute_v2(&mut input_values).unwrap();
assert_eq!(output.len(), 1);
}

#[cfg(feature = "tensorflow-tests")]
Expand All @@ -277,9 +438,14 @@ mod tests {
input_values.insert(String::from("num_floors"), 2.0);

let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap();
let raw_input_two = raw_input.clone();

let output = model_computation.raw_compute(raw_input, None).unwrap();
assert_eq!(output.len(), 1);

// testing v2
let output = model_computation.raw_compute_v2(raw_input_two, None).unwrap();
assert_eq!(output.len(), 1);
}

#[cfg(feature = "tensorflow-tests")]
Expand All @@ -296,5 +462,9 @@ mod tests {

let output = model_computation.buffered_compute(&mut input_values).unwrap();
assert_eq!(output.len(), 1);

// testing v2
let output = model_computation.buffered_compute_v2(&mut input_values).unwrap();
assert_eq!(output.len(), 1);
}
}
32 changes: 32 additions & 0 deletions modules/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,38 @@
//! // None input can be a tuple of dimensions of the input data
//! let output = compute_unit.raw_compute(data, None).unwrap();
//! ```
//!
//! ### V2 Ort support
//! Version 2 of Ort is now supported as well, but this support is in beta because V2 of Ort
//! itself is still in beta. Everything is the same apart from adding `_v2` to the end of your
//! execution calls like so:
//! ```rust
//! use surrealml_core::storage::surml_file::SurMlFile;
//! use surrealml_core::execution::compute::ModelComputation;
//! use ndarray::ArrayD;
//! use std::collections::HashMap;
//!
//!
//! let mut file = SurMlFile::from_file("./stash/test.surml").unwrap();
//!
//! let compute_unit = ModelComputation {
//! surml_file: &mut file,
//! };
//!
//! // automatically map inputs and apply normalisers to the compute if this data was put in the header
//! let mut input_values = HashMap::new();
//! input_values.insert(String::from("squarefoot"), 1000.0);
//! input_values.insert(String::from("num_floors"), 2.0);
//!
//! let output = compute_unit.buffered_compute(&mut input_values).unwrap();
//!
//! // feed a raw ndarray into the model if no header was provided or if you want to bypass the header
//! let x = vec![1000.0, 2.0];
//! let data: ArrayD<f32> = ndarray::arr1(&x).into_dyn();
//!
//! let output = compute_unit.buffered_compute_v2(&mut input_values).unwrap();
//! let output = compute_unit.raw_compute_v2(data, None).unwrap();
//! ```
pub mod storage;
pub mod execution;
pub mod errors;
Expand Down
Binary file modified modules/core/stash/test.surml
Binary file not shown.
Loading