diff --git a/src/memory.rs b/src/memory.rs index b1d4b68..d43a944 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -318,6 +318,34 @@ impl From for MemoryType { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +pub enum DeviceType { + CPU, + GPU, + FPGA +} + +impl From for ort_sys::OrtMemoryInfoDeviceType { + fn from(value: DeviceType) -> Self { + match value { + DeviceType::CPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU, + DeviceType::GPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU, + DeviceType::FPGA => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA + } + } +} + +impl From for DeviceType { + fn from(value: ort_sys::OrtMemoryInfoDeviceType) -> Self { + match value { + ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU => DeviceType::CPU, + ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU => DeviceType::GPU, + ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA => DeviceType::FPGA + } + } +} + /// Describes allocation properties for value memory. /// /// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which @@ -445,10 +473,17 @@ impl MemoryInfo { raw as _ } + /// Returns the type of device (CPU/GPU) this memory is allocated on. + pub fn device_type(&self) -> DeviceType { + let mut raw: ort_sys::OrtMemoryInfoDeviceType = ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU; + ortsys![unsafe MemoryInfoGetDeviceType(self.ptr.as_ptr(), &mut raw)]; + raw.into() + } + /// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device, /// it could be extracted to an `ndarray` or slice. pub fn is_cpu_accessible(&self) -> bool { - self.allocation_device() == AllocationDevice::CPU || matches!(self.memory_type(), MemoryType::CPUInput | MemoryType::CPUOutput) + self.device_type() == DeviceType::CPU } pub fn ptr(&self) -> *mut ort_sys::OrtMemoryInfo {