From e1d77b499eb35656194880d6b27168b3e7a7c3e0 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 22:19:40 -0500 Subject: [PATCH 1/5] refactor: simplify `run_async` --- src/session/async.rs | 64 ++++++++------------------------------------ 1 file changed, 11 insertions(+), 53 deletions(-) diff --git a/src/session/async.rs b/src/session/async.rs index 4bc338f..a63ea48 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -2,14 +2,10 @@ use std::{ cell::UnsafeCell, ffi::{c_char, CString}, future::Future, - mem::MaybeUninit, ops::Deref, pin::Pin, ptr::NonNull, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex - }, + sync::{Arc, Mutex}, task::{Context, Poll, Waker} }; @@ -17,47 +13,26 @@ use ort_sys::{c_void, OrtStatus}; use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; -pub(crate) enum InnerValue { - Present(T), - Pending, - Closed -} - -const VALUE_PRESENT: usize = 1 << 0; -const CHANNEL_CLOSED: usize = 1 << 1; - #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { - presence: AtomicUsize, - value: UnsafeCell>>>, + value: UnsafeCell>>>, waker: Mutex> } impl<'r, 's> InferenceFutInner<'r, 's> { pub(crate) fn new() -> Self { InferenceFutInner { - presence: AtomicUsize::new(0), waker: Mutex::new(None), - value: UnsafeCell::new(MaybeUninit::uninit()) + value: UnsafeCell::new(None) } } - pub(crate) fn try_take(&self) -> InnerValue>> { - let state_snapshot = self.presence.fetch_and(!VALUE_PRESENT, Ordering::Acquire); - if state_snapshot & VALUE_PRESENT == 0 { - if self.presence.load(Ordering::Acquire) & CHANNEL_CLOSED != 0 { - InnerValue::Closed - } else { - InnerValue::Pending - } - } else { - InnerValue::Present(unsafe { (*self.value.get()).assume_init_read() }) - } + pub(crate) fn try_take(&self) -> Option>> { + unsafe { &mut *self.value.get() }.take() } pub(crate) fn emplace_value(&self, value: Result>) { - unsafe { (*self.value.get()).write(value) }; - self.presence.fetch_or(VALUE_PRESENT, Ordering::Release); + unsafe { &mut *self.value.get() }.replace(value); } pub(crate) fn set_waker(&self, waker: Option<&Waker>) { @@ -69,18 +44,6 @@ impl<'r, 's> InferenceFutInner<'r, 's> { waker.wake(); } } - - pub(crate) fn close(&self) -> bool { - self.presence.fetch_or(CHANNEL_CLOSED, Ordering::Acquire) & CHANNEL_CLOSED == 0 - } -} - -impl<'r, 's> Drop for InferenceFutInner<'r, 's> { - fn drop(&mut self) { - if self.presence.load(Ordering::Acquire) & VALUE_PRESENT != 0 { - unsafe { (*self.value.get()).assume_init_drop() }; - } - } } unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} @@ -136,24 +99,19 @@ impl<'s, 'r> Future for InferenceFut<'s, 'r> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = Pin::into_inner(self); - match this.inner.try_take() { - InnerValue::Present(v) => { - this.did_receive = true; - return Poll::Ready(v); - } - InnerValue::Pending => {} - InnerValue::Closed => panic!() - }; + if let Some(v) = this.inner.try_take() { + this.did_receive = true; + return Poll::Ready(v); + } this.inner.set_waker(Some(cx.waker())); - Poll::Pending } } impl<'s, 'r> Drop for InferenceFut<'s, 'r> { fn drop(&mut self) { - if !self.did_receive && self.inner.close() { + if !self.did_receive { let _ = self.run_options.terminate(); self.inner.set_waker(None); } From 920cee9427c1e3eaf00a517b43df19104b0dbedb Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 23:10:57 -0500 Subject: [PATCH 2/5] refactor: simplify logging function --- src/environment.rs | 36 ++++++------------------------------ 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index 1a7860e..6bc6038 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -268,45 +268,21 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder { EnvironmentBuilder::new() } -/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct. -#[derive(Debug)] -struct CodeLocation<'a> { - file: &'a str, - line: &'a str, - function: &'a str -} - -impl<'a> From<&'a str> for CodeLocation<'a> { - fn from(code_location: &'a str) -> Self { - let mut splitter = code_location.split(' '); - let file_and_line = splitter.next().unwrap_or(":"); - let function = splitter.next().unwrap_or(""); - let mut file_and_line_splitter = file_and_line.split(':'); - let file = file_and_line_splitter.next().unwrap_or(""); - let line = file_and_line_splitter.next().unwrap_or(""); - - CodeLocation { file, line, function } - } -} - extern_system_fn! { /// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate. - pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, _: *const c_char, code_location: *const c_char, message: *const c_char) { - assert_ne!(category, ptr::null()); - let category = unsafe { CStr::from_ptr(category) }.to_str().unwrap_or(""); + pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) { assert_ne!(code_location, ptr::null()); - let code_location_str = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); + let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or(""); assert_ne!(message, ptr::null()); let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or(""); + assert_ne!(id, ptr::null()); + let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or(""); - let code_location = CodeLocation::from(code_location_str); let span = tracing::span!( Level::TRACE, "ort", - category = category, - file = code_location.file, - line = code_location.line, - function = code_location.function + id = id, + location = code_location ); match severity { From 3b93e73b278b96ec765ae131520eca76d68eadf9 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 23:11:58 -0500 Subject: [PATCH 3/5] chore: remove unused imports --- src/execution_providers/migraphx.rs | 12 ++++++++---- src/session/builder.rs | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs index ebcb50d..2eb02ff 100644 --- a/src/execution_providers/migraphx.rs +++ b/src/execution_providers/migraphx.rs @@ -1,7 +1,7 @@ -use std::{ffi::CString, ptr}; +use std::ffi::CString; use super::ExecutionProvider; -use crate::{ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; #[derive(Debug, Default, Clone)] pub struct MIGraphXExecutionProvider { @@ -68,9 +68,13 @@ impl ExecutionProvider for MIGraphXExecutionProvider { migraphx_fp16_enable: self.enable_fp16.into(), migraphx_int8_enable: self.enable_int8.into(), migraphx_use_native_calibration_table: self.use_native_calibration_table.into(), - migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) + migraphx_int8_calibration_table_name: self + .int8_calibration_table_name + .as_ref() + .map(|c| c.as_ptr()) + .unwrap_or_else(std::ptr::null) }; - ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; + crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider]; return Ok(()); } diff --git a/src/session/builder.rs b/src/session/builder.rs index 7d654c2..e105fd7 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -1,7 +1,11 @@ +#[cfg(any(feature = "operator-libraries", not(windows)))] +use std::ffi::CString; #[cfg(unix)] use std::os::unix::ffi::OsStrExt; #[cfg(target_family = "windows")] use std::os::windows::ffi::OsStrExt; +#[cfg(not(target_arch = "wasm32"))] +use std::path::Path; #[cfg(feature = "fetch-models")] use std::path::PathBuf; use std::{ @@ -11,8 +15,6 @@ use std::{ rc::Rc, sync::{atomic::Ordering, Arc} }; -#[cfg(not(target_arch = "wasm32"))] -use std::{ffi::CString, path::Path}; use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}; #[cfg(feature = "fetch-models")] From a127d0f372262357650d82065ff4af1ea4bc80b3 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Sun, 30 Jun 2024 23:50:20 -0500 Subject: [PATCH 4/5] refactor: typestate for `RunOptions` that have selected outputs --- src/io_binding.rs | 6 +- src/lib.rs | 8 +- src/session/async.rs | 32 ++++---- src/session/mod.rs | 45 ++++++------ src/session/run_options.rs | 146 ++++++++++++++++++++++++++++++++++--- src/value/mod.rs | 2 +- 6 files changed, 182 insertions(+), 57 deletions(-) diff --git a/src/io_binding.rs b/src/io_binding.rs index 0467cbb..f0a704a 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -12,7 +12,7 @@ use crate::{ ortsys, session::{output::SessionOutputs, RunOptions}, value::{Value, ValueInner}, - DynValue, Error, Result, Session, ValueTypeMarker + DynValue, Error, NoSelectedOutputs, Result, Session, ValueTypeMarker }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. @@ -177,11 +177,11 @@ impl<'s> IoBinding<'s> { } /// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`]. - pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { + pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result> { self.run_inner(Some(run_options)) } - fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { + fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result> { let run_options_ptr = if let Some(run_options) = run_options { run_options.run_options_ptr.as_ptr() } else { diff --git a/src/lib.rs b/src/lib.rs index b8c9176..1dd8204 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,8 +59,8 @@ pub use self::operator::{ InferShapeFn, Operator, OperatorDomain }; pub use self::session::{ - GraphOptimizationLevel, InMemorySession, Input, Output, OutputSelector, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs, - SessionOutputs, SharedSessionInner + GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions, + SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner }; #[cfg(feature = "ndarray")] #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] @@ -69,8 +69,8 @@ pub use self::tensor::{IntoTensorElementType, PrimitiveTensorElementType, Tensor pub use self::value::{ DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence, - SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef, - ValueRefMut, ValueType, ValueTypeMarker + SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, Value, + ValueRef, ValueRefMut, ValueType, ValueTypeMarker }; #[cfg(not(all(target_arch = "x86", target_os = "windows")))] diff --git a/src/session/async.rs b/src/session/async.rs index a63ea48..db13a84 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -11,7 +11,7 @@ use std::{ use ort_sys::{c_void, OrtStatus}; -use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; +use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { @@ -49,25 +49,25 @@ impl<'r, 's> InferenceFutInner<'r, 's> { unsafe impl<'r, 's> Send for InferenceFutInner<'r, 's> {} unsafe impl<'r, 's> Sync for InferenceFutInner<'r, 's> {} -pub enum RunOptionsRef<'r> { - Arc(Arc), - Ref(&'r RunOptions) +pub enum RunOptionsRef<'r, O: SelectedOutputMarker> { + Arc(Arc>), + Ref(&'r RunOptions) } -impl<'r> From<&Arc> for RunOptionsRef<'r> { - fn from(value: &Arc) -> Self { +impl<'r, O: SelectedOutputMarker> From<&Arc>> for RunOptionsRef<'r, O> { + fn from(value: &Arc>) -> Self { Self::Arc(Arc::clone(value)) } } -impl<'r> From<&'r RunOptions> for RunOptionsRef<'r> { - fn from(value: &'r RunOptions) -> Self { +impl<'r, O: SelectedOutputMarker> From<&'r RunOptions> for RunOptionsRef<'r, O> { + fn from(value: &'r RunOptions) -> Self { Self::Ref(value) } } -impl<'r> Deref for RunOptionsRef<'r> { - type Target = RunOptions; +impl<'r, O: SelectedOutputMarker> Deref for RunOptionsRef<'r, O> { + type Target = RunOptions; fn deref(&self) -> &Self::Target { match self { @@ -77,14 +77,14 @@ impl<'r> Deref for RunOptionsRef<'r> { } } -pub struct InferenceFut<'s, 'r> { +pub struct InferenceFut<'s, 'r, O: SelectedOutputMarker> { inner: Arc>, - run_options: RunOptionsRef<'r>, + run_options: RunOptionsRef<'r, O>, did_receive: bool } -impl<'s, 'r> InferenceFut<'s, 'r> { - pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r>) -> Self { +impl<'s, 'r, O: SelectedOutputMarker> InferenceFut<'s, 'r, O> { + pub(crate) fn new(inner: Arc>, run_options: RunOptionsRef<'r, O>) -> Self { Self { inner, run_options, @@ -93,7 +93,7 @@ impl<'s, 'r> InferenceFut<'s, 'r> { } } -impl<'s, 'r> Future for InferenceFut<'s, 'r> { +impl<'s, 'r, O: SelectedOutputMarker> Future for InferenceFut<'s, 'r, O> { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -109,7 +109,7 @@ impl<'s, 'r> Future for InferenceFut<'s, 'r> { } } -impl<'s, 'r> Drop for InferenceFut<'s, 'r> { +impl<'s, 'r, O: SelectedOutputMarker> Drop for InferenceFut<'s, 'r, O> { fn drop(&mut self) { if !self.did_receive { let _ = self.run_options.terminate(); diff --git a/src/session/mod.rs b/src/session/mod.rs index 865e09b..fd78a11 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,8 +2,6 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; -use r#async::RunOptionsRef; - use super::{ char_p_to_string, environment::Environment, @@ -21,13 +19,13 @@ pub(crate) mod builder; pub(crate) mod input; pub(crate) mod output; mod run_options; -use self::r#async::{AsyncInferenceContext, InferenceFutInner}; +use self::r#async::{AsyncInferenceContext, InferenceFutInner, RunOptionsRef}; pub use self::{ r#async::InferenceFut, builder::{GraphOptimizationLevel, SessionBuilder}, input::{SessionInputValue, SessionInputs}, output::SessionOutputs, - run_options::{OutputSelector, RunOptions} + run_options::{HasSelectedOutputs, NoSelectedOutputs, OutputSelector, RunOptions, SelectedOutputMarker} }; /// Holds onto an [`ort_sys::OrtSession`] pointer and its associated allocator. @@ -164,14 +162,16 @@ impl Session { pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s self, input_values: impl Into>) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } SessionInputs::ValueArray(input_values) => { - self.run_inner(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) - } - SessionInputs::ValueMap(input_values) => { - self.run_inner(&input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), input_values.iter().map(|(_, v)| v), None) + self.run_inner::(&self.inputs.iter().map(|input| input.name.as_str()).collect::>(), input_values.iter(), None) } + SessionInputs::ValueMap(input_values) => self.run_inner::( + &input_values.iter().map(|(k, _)| k.as_ref()).collect::>(), + input_values.iter().map(|(_, v)| v), + None + ) } } @@ -201,10 +201,10 @@ impl Session { /// # Ok(()) /// # } /// ``` - pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, const N: usize>( + pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>( &'s self, input_values: impl Into>, - run_options: &'r RunOptions + run_options: &'r RunOptions ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(input_values) => { @@ -219,11 +219,11 @@ impl Session { } } - fn run_inner<'i, 'r, 's: 'r, 'v: 'i>( + fn run_inner<'i, 'r, 's: 'r, 'v: 'i, O: SelectedOutputMarker>( &'s self, input_names: &[&str], input_values: impl Iterator>, - run_options: Option<&'r RunOptions> + run_options: Option<&'r RunOptions> ) -> Result> { let input_names_ptr: Vec<*const c_char> = input_names .iter() @@ -321,7 +321,7 @@ impl Session { pub fn run_async<'s, 'i, 'v: 'i + 's, const N: usize>( &'s self, input_values: impl Into> + 'static - ) -> Result> { + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { @@ -335,11 +335,11 @@ impl Session { /// Asynchronously run input data through the ONNX graph, performing inference, with the given [`RunOptions`]. /// See [`Session::run_with_options`] and [`Session::run_async`] for more details. - pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, const N: usize>( + pub fn run_async_with_options<'s, 'i, 'v: 'i + 's, 'r, O: SelectedOutputMarker, const N: usize>( &'s self, input_values: impl Into> + 'static, - run_options: &'r RunOptions - ) -> Result> { + run_options: &'r RunOptions + ) -> Result> { match input_values.into() { SessionInputs::ValueSlice(_) => unimplemented!("slices cannot be used in `run_async`"), SessionInputs::ValueArray(input_values) => { @@ -353,17 +353,20 @@ impl Session { } } - fn run_inner_async<'s, 'v: 's, 'r>( + fn run_inner_async<'s, 'v: 's, 'r, O: SelectedOutputMarker>( &'s self, input_names: &[String], input_values: impl Iterator>, - run_options: Option<&'r RunOptions> - ) -> Result> { + run_options: Option<&'r RunOptions> + ) -> Result> { let run_options = match run_options { Some(r) => RunOptionsRef::Ref(r), // create a `RunOptions` to pass to the future so that when it drops, it terminates inference - crucial // (performance-wise) for routines involving `tokio::select!` or timeouts - None => RunOptionsRef::Arc(Arc::new(RunOptions::new()?)) + None => RunOptionsRef::Arc(Arc::new(unsafe { + // SAFETY: transmuting from `RunOptions` to `RunOptions`; safe because its just a marker + std::mem::transmute(RunOptions::new()?) + })) }; let input_name_ptrs: Vec<*const c_char> = input_names diff --git a/src/session/run_options.rs b/src/session/run_options.rs index fa5ef21..ae222cb 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -2,6 +2,30 @@ use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, use crate::{ortsys, DynValue, Error, Output, Result, Value, ValueTypeMarker}; +/// Allows selecting/deselecting/preallocating the outputs of a [`crate::Session`] inference call. +/// +/// ``` +/// # use std::sync::Arc; +/// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; +/// # fn main() -> ort::Result<()> { +/// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; +/// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; +/// +/// let output0 = session.outputs[0].name.as_str(); +/// let options = RunOptions::new()?.with_outputs( +/// // Disable all outputs... +/// OutputSelector::no_default() +/// // except for the first one... +/// .with(output0) +/// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. +/// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) +/// ); +/// +/// // `outputs[0]` will be the tensor we just pre-allocated. +/// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; +/// # Ok(()) +/// # } +/// ``` #[derive(Debug)] pub struct OutputSelector { use_defaults: bool, @@ -11,6 +35,8 @@ pub struct OutputSelector { } impl Default for OutputSelector { + /// Creates an [`OutputSelector`] that enables all outputs by default. Use [`OutputSelector::without`] to disable a + /// specific output. fn default() -> Self { Self { use_defaults: true, @@ -22,6 +48,8 @@ impl Default for OutputSelector { } impl OutputSelector { + /// Creates an [`OutputSelector`] that does not enable any outputs. Use [`OutputSelector::with`] to enable a + /// specific output. pub fn no_default() -> Self { Self { use_defaults: false, @@ -29,16 +57,46 @@ impl OutputSelector { } } + /// Mark the output specified by the `name` for inclusion. pub fn with(mut self, name: impl Into) -> Self { self.allowlist.push(name.into()); self } + /// Mark the output specified by `name` to be **excluded**. ONNX Runtime may prune some of the output node's + /// ancestor nodes. pub fn without(mut self, name: impl Into) -> Self { self.default_blocklist.push(name.into()); self } + /// Pre-allocates an output. Assuming the type & shape of the value matches what is expected by the model, the + /// output value corresponding to `name` returned by the inference call will be the exact same value as the + /// pre-allocated value. + /// + /// **The same value will be reused as long as this [`OutputSelector`] and its parent [`RunOptions`] is used**, so + /// if you use the same `RunOptions` across multiple runs with a preallocated value, the preallocated value will be + /// overwritten upon each run. + /// + /// This can improve performance if the size and type of the output is known, and does not change between runs, i.e. + /// for an ODE or embeddings model. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// OutputSelector::default().preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) + /// ); + /// + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` pub fn preallocate(mut self, name: impl Into, value: Value) -> Self { self.preallocated_outputs.insert(name.into(), value.into_dyn()); self @@ -62,32 +120,96 @@ impl OutputSelector { } } -/// A structure which can be passed to [`crate::Session::run_with_options`] to allow terminating/unterminating a session -/// inference run from a different thread. +/// Types that specify whether a [`RunOptions`] was configured with an [`OutputSelector`]. +pub trait SelectedOutputMarker {} +/// Marks that a [`RunOptions`] was not configured with an [`OutputSelector`]. +pub struct NoSelectedOutputs; +impl SelectedOutputMarker for NoSelectedOutputs {} +/// Marks that a [`RunOptions`] was configured with an [`OutputSelector`]. +pub struct HasSelectedOutputs; +impl SelectedOutputMarker for HasSelectedOutputs {} + +/// Allows for finer control over session inference. +/// +/// [`RunOptions`] provides three main features: +/// - **Run tagging**: Each individual session run can have a uniquely identifiable tag attached with +/// [`RunOptions::set_tag`], which will show up in logs. This can be especially useful for debugging +/// performance/errors in inference servers. +/// - **Termination**: Allows for terminating an inference call from another thread; when [`RunOptions::terminate`] is +/// called, any sessions currently running under that [`RunOptions`] instance will halt graph execution as soon as the +/// termination signal is received. This allows for [`crate::Session::run_async`]'s cancel-safety. +/// - **Output specification**: Certain session outputs can be [disabled](`OutputSelector::without`) or +/// [pre-allocated](`OutputSelector::preallocate`). Disabling an output might mean ONNX Runtime will not execute parts +/// of the graph that are only used by that output. Pre-allocation can reduce expensive re-allocations by allowing you +/// to use the same memory across runs. +/// +/// [`RunOptions`] can be passed to most places where a session can be inferred, e.g. +/// [`crate::Session::run_with_options`], [`crate::Session::run_async_with_options`], +/// [`crate::IoBinding::run_with_options`]. Some of these patterns (notably `IoBinding`) do not accept +/// [`OutputSelector`], hence [`RunOptions`] contains an additional type parameter that marks whether or not outputs +/// have been selected. #[derive(Debug)] -pub struct RunOptions { +pub struct RunOptions { pub(crate) run_options_ptr: NonNull, - pub(crate) outputs: OutputSelector + pub(crate) outputs: OutputSelector, + _marker: PhantomData } // https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ac2a08cac0a657604bd5899e0d1a13675 -unsafe impl Send for RunOptions {} -unsafe impl Sync for RunOptions {} +unsafe impl Send for RunOptions {} +// Only allow `Sync` if we don't have (potentially pre-allocated) outputs selected. +// Allowing `Sync` here would mean a single pre-allocated `Value` could be mutated simultaneously in different threads - +// a brazen crime against crabkind. +unsafe impl Sync for RunOptions {} impl RunOptions { /// Creates a new [`RunOptions`] struct. - pub fn new() -> Result { + pub fn new() -> Result> { let mut run_options_ptr: *mut ort_sys::OrtRunOptions = std::ptr::null_mut(); ortsys![unsafe CreateRunOptions(&mut run_options_ptr) -> Error::CreateRunOptions; nonNull(run_options_ptr)]; - Ok(Self { + Ok(RunOptions { run_options_ptr: unsafe { NonNull::new_unchecked(run_options_ptr) }, - outputs: OutputSelector::default() + outputs: OutputSelector::default(), + _marker: PhantomData }) } +} - pub fn with_outputs(mut self, outputs: OutputSelector) -> Self { +impl RunOptions { + /// Select/deselect/preallocate outputs for this run. + /// + /// See [`OutputSelector`] for more details. + /// + /// ``` + /// # use std::sync::Arc; + /// # use ort::{Session, Allocator, RunOptions, OutputSelector, Tensor}; + /// # fn main() -> ort::Result<()> { + /// let session = Session::builder()?.commit_from_file("tests/data/upsample.onnx")?; + /// let input = Tensor::::new(&Allocator::default(), [1, 64, 64, 3])?; + /// + /// let output0 = session.outputs[0].name.as_str(); + /// let options = RunOptions::new()?.with_outputs( + /// // Disable all outputs... + /// OutputSelector::no_default() + /// // except for the first one... + /// .with(output0) + /// // and since this is a 2x upsampler model, pre-allocate the output to be twice as large. + /// .preallocate(output0, Tensor::::new(&Allocator::default(), [1, 128, 128, 3])?) + /// ); + /// + /// // `outputs[0]` will be the tensor we just pre-allocated. + /// let outputs = session.run_with_options(ort::inputs![input]?, &options)?; + /// # Ok(()) + /// # } + /// ``` + pub fn with_outputs(mut self, outputs: OutputSelector) -> RunOptions { self.outputs = outputs; - self + unsafe { std::mem::transmute(self) } + } + + /// Sets a tag to identify this run in logs. + pub fn with_tag(mut self, tag: impl AsRef) -> Result { + self.set_tag(tag).map(|_| self) } /// Sets a tag to identify this run in logs. @@ -158,7 +280,7 @@ impl RunOptions { } } -impl Drop for RunOptions { +impl Drop for RunOptions { fn drop(&mut self) { ortsys![unsafe ReleaseRunOptions(self.run_options_ptr.as_ptr())]; } diff --git a/src/value/mod.rs b/src/value/mod.rs index cf50600..b68b301 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -16,7 +16,7 @@ pub use self::{ impl_sequence::{ DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker }, - impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker} + impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} }; use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; From dc79ade39d3221238b0251319f8f9b8988a5c121 Mon Sep 17 00:00:00 2001 From: "Carson M." Date: Mon, 1 Jul 2024 00:05:35 -0500 Subject: [PATCH 5/5] chore: more readable imports when auto-importing, rust-analyzer seems to only do absolute imports on odd-numbered days and Tuesdays, falling back to crate:: imports otherwise. this commit adds some consistency. --- src/environment.rs | 3 ++- src/error.rs | 2 +- src/execution_providers/acl.rs | 7 +++++-- src/execution_providers/armnn.rs | 7 +++++-- src/execution_providers/cann.rs | 7 +++++-- src/execution_providers/coreml.rs | 7 +++++-- src/execution_providers/cpu.rs | 8 ++++++-- src/execution_providers/cuda.rs | 7 +++++-- src/execution_providers/directml.rs | 7 +++++-- src/execution_providers/migraphx.rs | 7 +++++-- src/execution_providers/mod.rs | 7 ++++++- src/execution_providers/nnapi.rs | 7 +++++-- src/execution_providers/onednn.rs | 7 +++++-- src/execution_providers/openvino.rs | 7 +++++-- src/execution_providers/qnn.rs | 7 +++++-- src/execution_providers/rocm.rs | 7 +++++-- src/execution_providers/tensorrt.rs | 7 +++++-- src/execution_providers/tvm.rs | 7 +++++-- src/execution_providers/xnnpack.rs | 7 +++++-- src/io_binding.rs | 6 +++--- src/memory.rs | 9 +++++---- src/metadata.rs | 8 ++++++-- src/operator/io.rs | 2 +- src/operator/kernel.rs | 7 ++++++- src/operator/mod.rs | 7 +++++-- src/session/async.rs | 6 +++++- src/session/builder.rs | 5 +++-- src/session/input.rs | 8 +++----- src/session/mod.rs | 2 +- src/session/output.rs | 2 +- src/session/run_options.rs | 7 ++++++- src/tensor/types.rs | 5 ++++- src/value/impl_map.rs | 9 ++++++--- src/value/impl_sequence.rs | 8 ++++++-- src/value/impl_tensor/create.rs | 11 +++++------ src/value/impl_tensor/extract.rs | 7 +++++-- src/value/impl_tensor/mod.rs | 9 +++++++-- src/value/mod.rs | 8 +++++++- 38 files changed, 175 insertions(+), 76 deletions(-) diff --git a/src/environment.rs b/src/environment.rs index 6bc6038..9b94532 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -15,7 +15,8 @@ use tracing::{debug, Level}; use crate::G_ORT_DYLIB_PATH; use crate::{ error::{Error, Result}, - extern_system_fn, ortsys, ExecutionProviderDispatch + execution_providers::ExecutionProviderDispatch, + extern_system_fn, ortsys }; struct EnvironmentSingleton { diff --git a/src/error.rs b/src/error.rs index fb25f20..c7ef7d2 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,7 @@ use std::{convert::Infallible, ffi::CString, io, path::PathBuf, ptr, string}; use thiserror::Error; -use super::{char_p_to_string, ortsys, tensor::TensorElementType, ValueType}; +use crate::{char_p_to_string, ortsys, tensor::TensorElementType, value::ValueType}; /// Type alias for the Result type returned by ORT functions. pub type Result = std::result::Result; diff --git a/src/execution_providers/acl.rs b/src/execution_providers/acl.rs index 1f15ac7..ddce629 100644 --- a/src/execution_providers/acl.rs +++ b/src/execution_providers/acl.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "acl"))] extern "C" { diff --git a/src/execution_providers/armnn.rs b/src/execution_providers/armnn.rs index 86332f0..c428feb 100644 --- a/src/execution_providers/armnn.rs +++ b/src/execution_providers/armnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "armnn"))] extern "C" { diff --git a/src/execution_providers/cann.rs b/src/execution_providers/cann.rs index f37a2f1..9189568 100644 --- a/src/execution_providers/cann.rs +++ b/src/execution_providers/cann.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone, PartialEq, Eq)] #[non_exhaustive] diff --git a/src/execution_providers/coreml.rs b/src/execution_providers/coreml.rs index 256de1e..2fa4aa7 100644 --- a/src/execution_providers/coreml.rs +++ b/src/execution_providers/coreml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "coreml"))] extern "C" { diff --git a/src/execution_providers/cpu.rs b/src/execution_providers/cpu.rs index eb4be91..06e031b 100644 --- a/src/execution_providers/cpu.rs +++ b/src/execution_providers/cpu.rs @@ -1,5 +1,9 @@ -use super::ExecutionProvider; -use crate::{error::status_to_result, ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{status_to_result, Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + ortsys, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct CPUExecutionProvider { diff --git a/src/execution_providers/cuda.rs b/src/execution_providers/cuda.rs index 17fbe82..67cad84 100644 --- a/src/execution_providers/cuda.rs +++ b/src/execution_providers/cuda.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; /// The type of search done for cuDNN convolution algorithms. #[derive(Debug, Clone)] diff --git a/src/execution_providers/directml.rs b/src/execution_providers/directml.rs index 38556f1..085e68f 100644 --- a/src/execution_providers/directml.rs +++ b/src/execution_providers/directml.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "directml"))] extern "C" { diff --git a/src/execution_providers/migraphx.rs b/src/execution_providers/migraphx.rs index 2eb02ff..d3cc62a 100644 --- a/src/execution_providers/migraphx.rs +++ b/src/execution_providers/migraphx.rs @@ -1,7 +1,10 @@ use std::ffi::CString; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct MIGraphXExecutionProvider { diff --git a/src/execution_providers/mod.rs b/src/execution_providers/mod.rs index c7f4937..47fb855 100644 --- a/src/execution_providers/mod.rs +++ b/src/execution_providers/mod.rs @@ -1,6 +1,11 @@ use std::{fmt::Debug, os::raw::c_char, sync::Arc}; -use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder}; +use crate::{ + char_p_to_string, + error::{Error, Result}, + ortsys, + session::SessionBuilder +}; mod cpu; pub use self::cpu::CPUExecutionProvider; diff --git a/src/execution_providers/nnapi.rs b/src/execution_providers/nnapi.rs index 9f1951e..68d275a 100644 --- a/src/execution_providers/nnapi.rs +++ b/src/execution_providers/nnapi.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "nnapi"))] extern "C" { diff --git a/src/execution_providers/onednn.rs b/src/execution_providers/onednn.rs index 795d0e6..45dec27 100644 --- a/src/execution_providers/onednn.rs +++ b/src/execution_providers/onednn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "onednn"))] extern "C" { diff --git a/src/execution_providers/openvino.rs b/src/execution_providers/openvino.rs index 95dc8e2..61924c5 100644 --- a/src/execution_providers/openvino.rs +++ b/src/execution_providers/openvino.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct OpenVINOExecutionProvider { diff --git a/src/execution_providers/qnn.rs b/src/execution_providers/qnn.rs index eb7075d..54ee71b 100644 --- a/src/execution_providers/qnn.rs +++ b/src/execution_providers/qnn.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub enum QNNExecutionProviderPerformanceMode { diff --git a/src/execution_providers/rocm.rs b/src/execution_providers/rocm.rs index be4cfde..3c3553b 100644 --- a/src/execution_providers/rocm.rs +++ b/src/execution_providers/rocm.rs @@ -1,7 +1,10 @@ use std::os::raw::c_void; -use super::ExecutionProvider; -use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Clone)] pub struct ROCmExecutionProvider { diff --git a/src/execution_providers/tensorrt.rs b/src/execution_providers/tensorrt.rs index e60e16f..1ea8dd8 100644 --- a/src/execution_providers/tensorrt.rs +++ b/src/execution_providers/tensorrt.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct TensorRTExecutionProvider { diff --git a/src/execution_providers/tvm.rs b/src/execution_providers/tvm.rs index 19c8ea7..6a04d94 100644 --- a/src/execution_providers/tvm.rs +++ b/src/execution_providers/tvm.rs @@ -1,5 +1,8 @@ -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))] extern "C" { diff --git a/src/execution_providers/xnnpack.rs b/src/execution_providers/xnnpack.rs index 8793326..bd3763e 100644 --- a/src/execution_providers/xnnpack.rs +++ b/src/execution_providers/xnnpack.rs @@ -1,7 +1,10 @@ use std::num::NonZeroUsize; -use super::ExecutionProvider; -use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder}; +use crate::{ + error::{Error, Result}, + execution_providers::{ExecutionProvider, ExecutionProviderDispatch}, + session::SessionBuilder +}; #[derive(Debug, Default, Clone)] pub struct XNNPACKExecutionProvider { diff --git a/src/io_binding.rs b/src/io_binding.rs index f0a704a..e3b9b76 100644 --- a/src/io_binding.rs +++ b/src/io_binding.rs @@ -8,11 +8,11 @@ use std::{ }; use crate::{ + error::{Error, Result}, memory::MemoryInfo, ortsys, - session::{output::SessionOutputs, RunOptions}, - value::{Value, ValueInner}, - DynValue, Error, NoSelectedOutputs, Result, Session, ValueTypeMarker + session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session}, + value::{DynValue, Value, ValueInner, ValueTypeMarker} }; /// Enables binding of session inputs and/or outputs to pre-allocated memory. diff --git a/src/memory.rs b/src/memory.rs index 5ffc85b..74eb770 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -4,11 +4,12 @@ use std::{ sync::Arc }; -use super::{ - error::{Error, Result}, - ortsys +use crate::{ + char_p_to_string, + error::{status_to_result, Error, Result}, + ortsys, + session::{Session, SharedSessionInner} }; -use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner}; /// A device allocator used to manage the allocation of [`crate::Value`]s. /// diff --git a/src/metadata.rs b/src/metadata.rs index 5464e5f..84fc69e 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -1,7 +1,11 @@ use std::{ffi::CString, os::raw::c_char, ptr::NonNull}; -use super::{char_p_to_string, error::Result, ortsys, Error}; -use crate::Allocator; +use crate::{ + char_p_to_string, + error::{Error, Result}, + memory::Allocator, + ortsys +}; /// Container for model metadata, including name & producer information. pub struct ModelMetadata<'s> { diff --git a/src/operator/io.rs b/src/operator/io.rs index 5a7507a..16d0e93 100644 --- a/src/operator/io.rs +++ b/src/operator/io.rs @@ -1,4 +1,4 @@ -use crate::{MemoryType, TensorElementType}; +use crate::{memory::MemoryType, tensor::TensorElementType}; #[repr(i32)] #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] diff --git a/src/operator/kernel.rs b/src/operator/kernel.rs index db09aa2..8c9280c 100644 --- a/src/operator/kernel.rs +++ b/src/operator/kernel.rs @@ -3,7 +3,12 @@ use std::{ ptr::{self, NonNull} }; -use crate::{error::status_to_result, ortsys, value::ValueRefMut, Allocator, DowncastableTarget, DynValue, Error, Result, Value, ValueRef}; +use crate::{ + error::{status_to_result, Error, Result}, + memory::Allocator, + ortsys, + value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut} +}; pub trait Kernel { fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>; diff --git a/src/operator/mod.rs b/src/operator/mod.rs index ad361f2..207a74d 100644 --- a/src/operator/mod.rs +++ b/src/operator/mod.rs @@ -8,11 +8,14 @@ pub(crate) mod io; pub(crate) mod kernel; use self::{ - bound::ErasedBoundOperator, + bound::{BoundOperator, ErasedBoundOperator}, io::{OperatorInput, OperatorOutput}, kernel::{DummyKernel, Kernel, KernelAttributes} }; -use crate::{operator::bound::BoundOperator, ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; pub type InferShapeFn = dyn FnMut(*mut ort_sys::OrtShapeInferContext) -> crate::Result<()>; diff --git a/src/session/async.rs b/src/session/async.rs index db13a84..c02a8eb 100644 --- a/src/session/async.rs +++ b/src/session/async.rs @@ -11,7 +11,11 @@ use std::{ use ort_sys::{c_void, OrtStatus}; -use crate::{error::assert_non_null_pointer, Error, Result, RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner, Value}; +use crate::{ + error::{assert_non_null_pointer, Error, Result}, + session::{RunOptions, SelectedOutputMarker, SessionInputValue, SessionOutputs, SharedSessionInner}, + value::Value +}; #[derive(Debug)] pub(crate) struct InferenceFutInner<'r, 's> { diff --git a/src/session/builder.rs b/src/session/builder.rs index e105fd7..8632c57 100644 --- a/src/session/builder.rs +++ b/src/session/builder.rs @@ -23,8 +23,9 @@ use crate::{ environment::get_environment, error::{assert_non_null_pointer, status_to_result, Error, Result}, execution_providers::{apply_execution_providers, ExecutionProviderDispatch}, - memory::Allocator, - ortsys, MemoryInfo, OperatorDomain + memory::{Allocator, MemoryInfo}, + operator::OperatorDomain, + ortsys }; /// Creates a session using the builder pattern. diff --git a/src/session/input.rs b/src/session/input.rs index 61d55e5..31a1433 100644 --- a/src/session/input.rs +++ b/src/session/input.rs @@ -1,9 +1,6 @@ use std::{borrow::Cow, collections::HashMap, ops::Deref}; -use crate::{ - value::{DynValueTypeMarker, ValueTypeMarker}, - Value, ValueRef, ValueRefMut -}; +use crate::value::{DynValueTypeMarker, Value, ValueRef, ValueRefMut, ValueTypeMarker}; pub enum SessionInputValue<'v> { ViewMut(ValueRefMut<'v, DynValueTypeMarker>), @@ -140,7 +137,8 @@ macro_rules! inputs { mod tests { use std::{collections::HashMap, sync::Arc}; - use crate::{DynTensor, SessionInputs}; + use super::SessionInputs; + use crate::value::DynTensor; #[test] fn test_hashmap_static_keys() -> crate::Result<()> { diff --git a/src/session/mod.rs b/src/session/mod.rs index fd78a11..53d9518 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -2,7 +2,7 @@ use std::{any::Any, ffi::CString, marker::PhantomData, ops::Deref, os::raw::c_char, ptr::NonNull, sync::Arc}; -use super::{ +use crate::{ char_p_to_string, environment::Environment, error::{assert_non_null_pointer, assert_null_pointer, status_to_result, Error, ErrorInternal, Result}, diff --git a/src/session/output.rs b/src/session/output.rs index 74e1332..2409d6b 100644 --- a/src/session/output.rs +++ b/src/session/output.rs @@ -4,7 +4,7 @@ use std::{ ops::{Deref, DerefMut, Index} }; -use crate::{Allocator, DynValue}; +use crate::{memory::Allocator, value::DynValue}; /// The outputs returned by a [`crate::Session`] inference call. /// diff --git a/src/session/run_options.rs b/src/session/run_options.rs index ae222cb..4236892 100644 --- a/src/session/run_options.rs +++ b/src/session/run_options.rs @@ -1,6 +1,11 @@ use std::{collections::HashMap, ffi::CString, marker::PhantomData, ptr::NonNull, sync::Arc}; -use crate::{ortsys, DynValue, Error, Output, Result, Value, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + ortsys, + session::Output, + value::{DynValue, Value, ValueTypeMarker} +}; /// Allows selecting/deselecting/preallocating the outputs of a [`crate::Session`] inference call. /// diff --git a/src/tensor/types.rs b/src/tensor/types.rs index aabe683..f5f0a1a 100644 --- a/src/tensor/types.rs +++ b/src/tensor/types.rs @@ -2,7 +2,10 @@ use std::ptr; #[cfg(feature = "ndarray")] -use crate::{ortsys, Error, Result}; +use crate::{ + error::{Error, Result}, + ortsys +}; /// Enum mapping ONNX Runtime's supported tensor data types. #[derive(Debug, PartialEq, Eq, Clone, Copy)] diff --git a/src/value/impl_map.rs b/src/value/impl_map.rs index f638787..87d653f 100644 --- a/src/value/impl_map.rs +++ b/src/value/impl_map.rs @@ -7,12 +7,15 @@ use std::{ sync::Arc }; -use super::{ValueInner, ValueTypeMarker}; +use super::{ + impl_tensor::{calculate_tensor_size, DynTensor, Tensor}, + DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker +}; use crate::{ + error::{Error, Result}, memory::Allocator, ortsys, - value::impl_tensor::{calculate_tensor_size, DynTensor}, - DynValue, Error, IntoTensorElementType, PrimitiveTensorElementType, Result, Tensor, TensorElementType, Value, ValueRef, ValueRefMut, ValueType + tensor::{IntoTensorElementType, PrimitiveTensorElementType, TensorElementType} }; pub trait MapValueTypeMarker: ValueTypeMarker { diff --git a/src/value/impl_sequence.rs b/src/value/impl_sequence.rs index 2923e27..8b209e1 100644 --- a/src/value/impl_sequence.rs +++ b/src/value/impl_sequence.rs @@ -5,8 +5,12 @@ use std::{ sync::Arc }; -use super::{DowncastableTarget, ValueInner, ValueTypeMarker}; -use crate::{memory::Allocator, ortsys, Error, Result, Value, ValueRef, ValueRefMut, ValueType}; +use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + memory::Allocator, + ortsys +}; pub trait SequenceValueTypeMarker: ValueTypeMarker { crate::private_trait!(); diff --git a/src/value/impl_tensor/create.rs b/src/value/impl_tensor/create.rs index 8a2e950..391f448 100644 --- a/src/value/impl_tensor/create.rs +++ b/src/value/impl_tensor/create.rs @@ -10,14 +10,13 @@ use std::{ #[cfg(feature = "ndarray")] use ndarray::{ArcArray, Array, ArrayView, CowArray, Dimension}; -use super::{DynTensor, Tensor}; +use super::{calculate_tensor_size, DynTensor, Tensor, TensorRefMut}; use crate::{ - error::assert_non_null_pointer, - memory::{Allocator, MemoryInfo}, + error::{assert_non_null_pointer, Error, Result}, + memory::{AllocationDevice, Allocator, AllocatorType, MemoryInfo, MemoryType}, ortsys, - tensor::{TensorElementType, Utf8Data}, - value::{impl_tensor::calculate_tensor_size, ValueInner}, - AllocationDevice, AllocatorType, DynValue, Error, MemoryType, PrimitiveTensorElementType, Result, TensorRefMut, Value + tensor::{PrimitiveTensorElementType, TensorElementType, Utf8Data}, + value::{DynValue, Value, ValueInner} }; impl Tensor { diff --git a/src/value/impl_tensor/extract.rs b/src/value/impl_tensor/extract.rs index bec573f..1b27d8a 100644 --- a/src/value/impl_tensor/extract.rs +++ b/src/value/impl_tensor/extract.rs @@ -3,11 +3,14 @@ use std::{fmt::Debug, ptr, string::FromUtf8Error}; #[cfg(feature = "ndarray")] use ndarray::IxDyn; -use super::TensorValueTypeMarker; +use super::{calculate_tensor_size, Tensor, TensorValueTypeMarker}; #[cfg(feature = "ndarray")] use crate::tensor::{extract_primitive_array, extract_primitive_array_mut}; use crate::{ - ortsys, tensor::TensorElementType, value::impl_tensor::calculate_tensor_size, Error, PrimitiveTensorElementType, Result, Tensor, Value, ValueType + error::{Error, Result}, + ortsys, + tensor::{PrimitiveTensorElementType, TensorElementType}, + value::{Value, ValueType} }; impl Value { diff --git a/src/value/impl_tensor/mod.rs b/src/value/impl_tensor/mod.rs index a4c7ba6..92a08c9 100644 --- a/src/value/impl_tensor/mod.rs +++ b/src/value/impl_tensor/mod.rs @@ -8,8 +8,13 @@ use std::{ ptr::NonNull }; -use super::{DowncastableTarget, Value, ValueInner, ValueTypeMarker}; -use crate::{ortsys, DynValue, Error, IntoTensorElementType, MemoryInfo, Result, ValueRef, ValueRefMut, ValueType}; +use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker}; +use crate::{ + error::{Error, Result}, + memory::MemoryInfo, + ortsys, + tensor::IntoTensorElementType +}; pub trait TensorValueTypeMarker: ValueTypeMarker { crate::private_trait!(); diff --git a/src/value/mod.rs b/src/value/mod.rs index b68b301..3aa57ef 100644 --- a/src/value/mod.rs +++ b/src/value/mod.rs @@ -18,7 +18,13 @@ pub use self::{ }, impl_tensor::{DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker} }; -use crate::{error::status_to_result, memory::MemoryInfo, ortsys, session::SharedSessionInner, tensor::TensorElementType, Error, Result}; +use crate::{ + error::{status_to_result, Error, Result}, + memory::MemoryInfo, + ortsys, + session::SharedSessionInner, + tensor::TensorElementType +}; /// The type of a [`Value`], or a session input/output. ///