diff --git a/src/operator/bound.rs b/src/operator/bound.rs index 58219c4..452736f 100644 --- a/src/operator/bound.rs +++ b/src/operator/bound.rs @@ -9,7 +9,7 @@ use super::{ kernel::{Kernel, KernelAttributes, KernelContext}, DummyOperator, Operator }; -use crate::error::IntoStatus; +use crate::{error::IntoStatus, extern_system_fn}; #[repr(C)] // <- important! a defined layout allows us to store extra data after the `OrtCustomOp` that we can retrieve later pub(crate) struct BoundOperator { @@ -65,115 +65,147 @@ impl BoundOperator { &*op.cast() } - pub(crate) unsafe extern "C" fn CreateKernelV2( - _: *const ort_sys::OrtCustomOp, - _: *const ort_sys::OrtApi, - info: *const ort_sys::OrtKernelInfo, - kernel_ptr: *mut *mut ort_sys::c_void - ) -> *mut ort_sys::OrtStatus { - let kernel = match O::create_kernel(&KernelAttributes::new(info)) { - Ok(kernel) => kernel, - e => return e.into_status() - }; - *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast(); - Ok(()).into_status() - } - - pub(crate) unsafe extern "C" fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { - let context = KernelContext::new(context); - O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::() }, &context).into_status() - } - - pub(crate) unsafe extern "C" fn KernelDestroy(op_kernel: *mut ort_sys::c_void) { - drop(Box::from_raw(op_kernel.cast::())); - } - - pub(crate) unsafe extern "C" fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { - let safe = Self::safe(op); - safe.name.as_ptr() - } - pub(crate) unsafe extern "C" fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { - let safe = Self::safe(op); - safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) - } - - pub(crate) unsafe extern "C" fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::min_version() - } - pub(crate) unsafe extern "C" fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::max_version() - } - - pub(crate) unsafe extern "C" fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType { - O::inputs()[index as usize].memory_type.into() - } - pub(crate) unsafe extern "C" fn GetInputCharacteristic( - _: *const ort_sys::OrtCustomOp, - index: ort_sys::size_t - ) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::inputs()[index as usize].characteristic.into() - } - pub(crate) unsafe extern "C" fn GetOutputCharacteristic( - _: *const ort_sys::OrtCustomOp, - index: ort_sys::size_t - ) -> ort_sys::OrtCustomOpInputOutputCharacteristic { - O::outputs()[index as usize].characteristic.into() - } - pub(crate) unsafe extern "C" fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { - O::inputs().len() as _ - } - pub(crate) unsafe extern "C" fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { - O::outputs().len() as _ - } - pub(crate) unsafe extern "C" fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { - O::inputs()[index as usize] - .r#type - .map(|c| c.into()) - .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) - } - pub(crate) unsafe extern "C" fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { - O::outputs()[index as usize] - .r#type - .map(|c| c.into()) - .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) - } - pub(crate) unsafe extern "C" fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_min_arity) - .unwrap_or(1) - .try_into() - .expect("input minimum arity overflows i32") - } - pub(crate) unsafe extern "C" fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::inputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_homogeneity) - .unwrap_or(false) - .into() - } - pub(crate) unsafe extern "C" fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_min_arity) - .unwrap_or(1) - .try_into() - .expect("output minimum arity overflows i32") - } - pub(crate) unsafe extern "C" fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { - O::outputs() - .into_iter() - .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) - .and_then(|c| c.variadic_homogeneity) - .unwrap_or(false) - .into() - } - - pub(crate) unsafe extern "C" fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { - O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status() + extern_system_fn! { + pub(crate) unsafe fn CreateKernelV2( + _: *const ort_sys::OrtCustomOp, + _: *const ort_sys::OrtApi, + info: *const ort_sys::OrtKernelInfo, + kernel_ptr: *mut *mut ort_sys::c_void + ) -> *mut ort_sys::OrtStatus { + let kernel = match O::create_kernel(&KernelAttributes::new(info)) { + Ok(kernel) => kernel, + e => return e.into_status() + }; + *kernel_ptr = (Box::leak(Box::new(kernel)) as *mut O::Kernel).cast(); + Ok(()).into_status() + } + } + + extern_system_fn! { + pub(crate) unsafe fn ComputeKernelV2(kernel_ptr: *mut ort_sys::c_void, context: *mut ort_sys::OrtKernelContext) -> *mut ort_sys::OrtStatus { + let context = KernelContext::new(context); + O::Kernel::compute(unsafe { &mut *kernel_ptr.cast::() }, &context).into_status() + } + } + + extern_system_fn! { + pub(crate) unsafe fn KernelDestroy(op_kernel: *mut ort_sys::c_void) { + drop(Box::from_raw(op_kernel.cast::())); + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetName(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + let safe = Self::safe(op); + safe.name.as_ptr() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetExecutionProviderType(op: *const ort_sys::OrtCustomOp) -> *const ort_sys::c_char { + let safe = Self::safe(op); + safe.execution_provider_type.as_ref().map(|c| c.as_ptr()).unwrap_or_else(ptr::null) + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetStartVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::min_version() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetEndVersion(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::max_version() + } + } + + extern_system_fn! { + pub(crate) unsafe fn GetInputMemoryType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtMemType { + O::inputs()[index as usize].memory_type.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + O::inputs()[index as usize].characteristic.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetOutputCharacteristic(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::OrtCustomOpInputOutputCharacteristic { + O::outputs()[index as usize].characteristic.into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { + O::inputs().len() as _ + } + } + extern_system_fn! { + pub(crate) unsafe fn GetOutputTypeCount(_: *const ort_sys::OrtCustomOp) -> ort_sys::size_t { + O::outputs().len() as _ + } + } + extern_system_fn! { + pub(crate) unsafe fn GetInputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { + O::inputs()[index as usize] + .r#type + .map(|c| c.into()) + .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) + } + } + extern_system_fn! { + pub(crate) unsafe fn GetOutputType(_: *const ort_sys::OrtCustomOp, index: ort_sys::size_t) -> ort_sys::ONNXTensorElementDataType { + O::outputs()[index as usize] + .r#type + .map(|c| c.into()) + .unwrap_or(ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicInputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::inputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_min_arity) + .unwrap_or(1) + .try_into() + .expect("input minimum arity overflows i32") + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicInputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::inputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_homogeneity) + .unwrap_or(false) + .into() + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicOutputMinArity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::outputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_min_arity) + .unwrap_or(1) + .try_into() + .expect("output minimum arity overflows i32") + } + } + extern_system_fn! { + pub(crate) unsafe fn GetVariadicOutputHomogeneity(_: *const ort_sys::OrtCustomOp) -> ort_sys::c_int { + O::outputs() + .into_iter() + .find(|c| c.characteristic == InputOutputCharacteristic::Variadic) + .and_then(|c| c.variadic_homogeneity) + .unwrap_or(false) + .into() + } + } + + extern_system_fn! { + pub(crate) unsafe fn InferOutputShapeFn(_: *const ort_sys::OrtCustomOp, arg1: *mut ort_sys::OrtShapeInferContext) -> *mut ort_sys::OrtStatus { + O::get_infer_shape_function().expect("missing infer shape function")(arg1).into_status() + } } }