Skip to content

Commit

Permalink
Merge branch 'main' into training
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jul 1, 2024
2 parents f99f86b + dc79ade commit 8a0646c
Show file tree
Hide file tree
Showing 39 changed files with 383 additions and 219 deletions.
39 changes: 8 additions & 31 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use tracing::{debug, Level};
use crate::G_ORT_DYLIB_PATH;
use crate::{
error::{Error, Result},
extern_system_fn, ortsys, ExecutionProviderDispatch
execution_providers::ExecutionProviderDispatch,
extern_system_fn, ortsys
};

struct EnvironmentSingleton {
Expand Down Expand Up @@ -268,45 +269,21 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
EnvironmentBuilder::new()
}

/// ONNX's logger sends the code location where the log occurred, which will be parsed into this struct.
#[derive(Debug)]
struct CodeLocation<'a> {
file: &'a str,
line: &'a str,
function: &'a str
}

impl<'a> From<&'a str> for CodeLocation<'a> {
fn from(code_location: &'a str) -> Self {
let mut splitter = code_location.split(' ');
let file_and_line = splitter.next().unwrap_or("<unknown file>:<unknown line>");
let function = splitter.next().unwrap_or("<unknown function>");
let mut file_and_line_splitter = file_and_line.split(':');
let file = file_and_line_splitter.next().unwrap_or("<unknown file>");
let line = file_and_line_splitter.next().unwrap_or("<unknown line>");

CodeLocation { file, line, function }
}
}

extern_system_fn! {
/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, category: *const c_char, _: *const c_char, code_location: *const c_char, message: *const c_char) {
assert_ne!(category, ptr::null());
let category = unsafe { CStr::from_ptr(category) }.to_str().unwrap_or("<decode error>");
pub(crate) fn custom_logger(_params: *mut ffi::c_void, severity: ort_sys::OrtLoggingLevel, _: *const c_char, id: *const c_char, code_location: *const c_char, message: *const c_char) {
assert_ne!(code_location, ptr::null());
let code_location_str = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
assert_ne!(id, ptr::null());
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");

let code_location = CodeLocation::from(code_location_str);
let span = tracing::span!(
Level::TRACE,
"ort",
category = category,
file = code_location.file,
line = code_location.line,
function = code_location.function
id = id,
location = code_location
);

match severity {
Expand Down
2 changes: 1 addition & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{convert::Infallible, ffi::CString, io, path::PathBuf, ptr, string};

use thiserror::Error;

use super::{char_p_to_string, ortsys, tensor::TensorElementType, ValueType};
use crate::{char_p_to_string, ortsys, tensor::TensorElementType, value::ValueType};

/// Type alias for the Result type returned by ORT functions.
pub type Result<T, E = Error> = std::result::Result<T, E>;
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/acl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "acl"))]
extern "C" {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/armnn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "armnn"))]
extern "C" {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/cann.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/coreml.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "coreml"))]
extern "C" {
Expand Down
8 changes: 6 additions & 2 deletions src/execution_providers/cpu.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use super::ExecutionProvider;
use crate::{error::status_to_result, ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{status_to_result, Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
ortsys,
session::SessionBuilder
};

#[derive(Debug, Default, Clone)]
pub struct CPUExecutionProvider {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/cuda.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

/// The type of search done for cuDNN convolution algorithms.
#[derive(Debug, Clone)]
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/directml.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "directml"))]
extern "C" {
Expand Down
17 changes: 12 additions & 5 deletions src/execution_providers/migraphx.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::{ffi::CString, ptr};
use std::ffi::CString;

use super::ExecutionProvider;
use crate::{ortsys, Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Default, Clone)]
pub struct MIGraphXExecutionProvider {
Expand Down Expand Up @@ -68,9 +71,13 @@ impl ExecutionProvider for MIGraphXExecutionProvider {
migraphx_fp16_enable: self.enable_fp16.into(),
migraphx_int8_enable: self.enable_int8.into(),
migraphx_use_native_calibration_table: self.use_native_calibration_table.into(),
migraphx_int8_calibration_table_name: self.int8_calibration_table_name.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null)
migraphx_int8_calibration_table_name: self
.int8_calibration_table_name
.as_ref()
.map(|c| c.as_ptr())
.unwrap_or_else(std::ptr::null)
};
ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider];
crate::ortsys![unsafe SessionOptionsAppendExecutionProvider_MIGraphX(session_builder.session_options_ptr.as_ptr(), &options) -> Error::ExecutionProvider];
return Ok(());
}

Expand Down
7 changes: 6 additions & 1 deletion src/execution_providers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use std::{fmt::Debug, os::raw::c_char, sync::Arc};

use crate::{char_p_to_string, ortsys, Error, Result, SessionBuilder};
use crate::{
char_p_to_string,
error::{Error, Result},
ortsys,
session::SessionBuilder
};

mod cpu;
pub use self::cpu::CPUExecutionProvider;
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/nnapi.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "nnapi"))]
extern "C" {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/onednn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "onednn"))]
extern "C" {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/openvino.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::os::raw::c_void;

use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Clone)]
pub struct OpenVINOExecutionProvider {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/qnn.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Clone)]
pub enum QNNExecutionProviderPerformanceMode {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/rocm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::os::raw::c_void;

use super::ExecutionProvider;
use crate::{ArenaExtendStrategy, Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ArenaExtendStrategy, ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Clone)]
pub struct ROCmExecutionProvider {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/tensorrt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Default, Clone)]
pub struct TensorRTExecutionProvider {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/tvm.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[cfg(all(not(feature = "load-dynamic"), feature = "tvm"))]
extern "C" {
Expand Down
7 changes: 5 additions & 2 deletions src/execution_providers/xnnpack.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::num::NonZeroUsize;

use super::ExecutionProvider;
use crate::{Error, ExecutionProviderDispatch, Result, SessionBuilder};
use crate::{
error::{Error, Result},
execution_providers::{ExecutionProvider, ExecutionProviderDispatch},
session::SessionBuilder
};

#[derive(Debug, Default, Clone)]
pub struct XNNPACKExecutionProvider {
Expand Down
10 changes: 5 additions & 5 deletions src/io_binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ use std::{
};

use crate::{
error::{Error, Result},
memory::MemoryInfo,
ortsys,
session::{output::SessionOutputs, RunOptions},
value::{Value, ValueInner},
DynValue, Error, Result, Session, ValueTypeMarker
session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session},
value::{DynValue, Value, ValueInner, ValueTypeMarker}
};

/// Enables binding of session inputs and/or outputs to pre-allocated memory.
Expand Down Expand Up @@ -177,11 +177,11 @@ impl<'s> IoBinding<'s> {
}

/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
pub fn run_with_options(&mut self, run_options: &RunOptions) -> Result<SessionOutputs<'_, 's>> {
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, 's>> {
self.run_inner(Some(run_options))
}

fn run_inner(&mut self, run_options: Option<&RunOptions>) -> Result<SessionOutputs<'_, 's>> {
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, 's>> {
let run_options_ptr = if let Some(run_options) = run_options {
run_options.run_options_ptr.as_ptr()
} else {
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ pub use self::operator::{
InferShapeFn, Operator, OperatorDomain
};
pub use self::session::{
GraphOptimizationLevel, InMemorySession, Input, Output, OutputSelector, RunOptions, Session, SessionBuilder, SessionInputValue, SessionInputs,
SessionOutputs, SharedSessionInner
GraphOptimizationLevel, HasSelectedOutputs, InMemorySession, InferenceFut, Input, NoSelectedOutputs, Output, OutputSelector, RunOptions,
SelectedOutputMarker, Session, SessionBuilder, SessionInputValue, SessionInputs, SessionOutputs, SharedSessionInner
};
#[cfg(feature = "ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))]
Expand All @@ -75,8 +75,8 @@ pub use self::training::*;
pub use self::value::{
DowncastableTarget, DynMap, DynMapRef, DynMapRefMut, DynMapValueType, DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, DynTensor,
DynTensorRef, DynTensorRefMut, DynTensorValueType, DynValue, DynValueTypeMarker, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker, Sequence,
SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueTypeMarker, Value, ValueRef,
ValueRefMut, ValueType, ValueTypeMarker
SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker, Tensor, TensorRef, TensorRefMut, TensorValueType, TensorValueTypeMarker, Value,
ValueRef, ValueRefMut, ValueType, ValueTypeMarker
};

#[cfg(not(all(target_arch = "x86", target_os = "windows")))]
Expand Down
9 changes: 5 additions & 4 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use std::{
sync::Arc
};

use super::{
error::{Error, Result},
ortsys
use crate::{
char_p_to_string,
error::{status_to_result, Error, Result},
ortsys,
session::{Session, SharedSessionInner}
};
use crate::{char_p_to_string, error::status_to_result, Session, SharedSessionInner};

/// A device allocator used to manage the allocation of [`crate::Value`]s.
///
Expand Down
8 changes: 6 additions & 2 deletions src/metadata.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::{ffi::CString, os::raw::c_char, ptr::NonNull};

use super::{char_p_to_string, error::Result, ortsys, Error};
use crate::Allocator;
use crate::{
char_p_to_string,
error::{Error, Result},
memory::Allocator,
ortsys
};

/// Container for model metadata, including name & producer information.
pub struct ModelMetadata<'s> {
Expand Down
2 changes: 1 addition & 1 deletion src/operator/io.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{MemoryType, TensorElementType};
use crate::{memory::MemoryType, tensor::TensorElementType};

#[repr(i32)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
Expand Down
7 changes: 6 additions & 1 deletion src/operator/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@ use std::{
ptr::{self, NonNull}
};

use crate::{error::status_to_result, ortsys, value::ValueRefMut, Allocator, DowncastableTarget, DynValue, Error, Result, Value, ValueRef};
use crate::{
error::{status_to_result, Error, Result},
memory::Allocator,
ortsys,
value::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut}
};

pub trait Kernel {
fn compute(&mut self, ctx: &KernelContext) -> crate::Result<()>;
Expand Down
Loading

0 comments on commit 8a0646c

Please sign in to comment.