Skip to content

Commit

Permalink
fix: query device type to determine CPU accessibility
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Oct 23, 2024
1 parent 636a133 commit 2628378
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,34 @@ impl From<ort_sys::OrtMemType> for MemoryType {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(clippy::upper_case_acronyms)]
pub enum DeviceType {
CPU,
GPU,
FPGA
}

impl From<DeviceType> for ort_sys::OrtMemoryInfoDeviceType {
fn from(value: DeviceType) -> Self {
match value {
DeviceType::CPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU,
DeviceType::GPU => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
DeviceType::FPGA => ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA
}
}
}

impl From<ort_sys::OrtMemoryInfoDeviceType> for DeviceType {
fn from(value: ort_sys::OrtMemoryInfoDeviceType) -> Self {
match value {
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU => DeviceType::CPU,
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU => DeviceType::GPU,
ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA => DeviceType::FPGA
}
}
}

/// Describes allocation properties for value memory.
///
/// `MemoryInfo` is used in the creation of [`Session`]s, [`Allocator`]s, and [`crate::Value`]s to describe on which
Expand Down Expand Up @@ -445,10 +473,17 @@ impl MemoryInfo {
raw as _
}

/// Returns the type of device (CPU/GPU) this memory is allocated on.
pub fn device_type(&self) -> DeviceType {
let mut raw: ort_sys::OrtMemoryInfoDeviceType = ort_sys::OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU;
ortsys![unsafe MemoryInfoGetDeviceType(self.ptr.as_ptr(), &mut raw)];
raw.into()
}

/// Returns `true` if this memory is accessible by the CPU; meaning that, if a value were allocated on this device,
/// it could be extracted to an `ndarray` or slice.
pub fn is_cpu_accessible(&self) -> bool {
self.allocation_device() == AllocationDevice::CPU || matches!(self.memory_type(), MemoryType::CPUInput | MemoryType::CPUOutput)
self.device_type() == DeviceType::CPU
}

pub fn ptr(&self) -> *mut ort_sys::OrtMemoryInfo {
Expand Down

0 comments on commit 2628378

Please sign in to comment.