Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NVTX wrapper #301

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ cublaslt = ["driver"]
cudnn = ["driver"]
curand = ["driver"]
nccl = ["driver"]
nvtx = []

std = []
no-std = ["no-std-compat/std", "dep:spin"]
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ Safe abstractions over:
4. [cuBLAS API](https://docs.nvidia.com/cuda/cublas/index.html)
5. [cuBLASLt API](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
6. [NCCL API](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/)

7. [cuDNN](https://docs.nvidia.com/cudnn/index.html)
8. [NVTX](https://github.com/NVIDIA/NVTX)

**Pre-alpha state**, expect breaking changes and not all cuda functions
contain a safe wrapper. **Contributions welcome for any that aren't included!**

Expand Down
52 changes: 52 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,58 @@ fn main() {

#[cfg(feature = "dynamic-linking")]
dynamic_linking(major, minor);

#[cfg(feature = "nvtx")]
nvtx();
}

/// Compiles the bindgen-generated C wrappers in `src/nvtx/sys/extern.c` and
/// links them into the crate as a static library named `extern`.
///
/// The steps are:
/// 1. compile `extern.c` into `$OUT_DIR/extern.o` with `clang`,
/// 2. archive the object file into a static library (`ar` on non-Windows
///    hosts, `LIB` on Windows hosts),
/// 3. emit the cargo directives that add `$OUT_DIR` to the native search path
///    and link the library statically.
///
/// # Panics
///
/// Panics (failing the build) if `OUT_DIR` is unset or not valid UTF-8, if any
/// of the external tools cannot be started, or if a tool exits unsuccessfully.
// Only called when the `nvtx` feature is enabled, so it is otherwise unused.
#[allow(unused)]
fn nvtx() {
    const WRAPPER_SRC: &str = "src/nvtx/sys/extern.c";
    const WRAPPER_HEADER: &str = "src/nvtx/sys/wrapper.h";

    // Rebuild the wrapper library when its C sources are regenerated.
    println!("cargo:rerun-if-changed={WRAPPER_SRC}");
    println!("cargo:rerun-if-changed={WRAPPER_HEADER}");

    let output_path = PathBuf::from(
        std::env::var("OUT_DIR").expect("OUT_DIR is set by cargo when running build scripts"),
    );
    let obj_path = output_path.join("extern.o");
    let clang_output = std::process::Command::new("clang")
        .arg("-O")
        .arg("-c")
        .arg("-o")
        .arg(&obj_path)
        .arg("-I/usr/local/cuda/include")
        .arg(WRAPPER_SRC)
        .output()
        .expect("failed to run `clang`; it is required to build the `nvtx` feature");

    if !clang_output.status.success() {
        panic!(
            "Could not compile object file:\n{}",
            String::from_utf8_lossy(&clang_output.stderr)
        );
    }

    // Turn the object file into a static library
    #[cfg(not(target_os = "windows"))]
    let lib_output = std::process::Command::new("ar")
        .arg("rcs")
        .arg(output_path.join("libextern.a"))
        .arg(&obj_path)
        .output()
        .expect("failed to run `ar`; it is required to build the `nvtx` feature");
    // With the MSVC toolchain `rustc-link-lib=static=extern` resolves to
    // `extern.lib` (no `lib` prefix), so the archive must use that name.
    #[cfg(target_os = "windows")]
    let lib_output = std::process::Command::new("LIB")
        .arg(&obj_path)
        .arg(format!("/OUT:{}", output_path.join("extern.lib").display()))
        .output()
        .expect("failed to run `LIB`; it is required to build the `nvtx` feature");
    if !lib_output.status.success() {
        panic!(
            "Could not emit library file:\n{}",
            String::from_utf8_lossy(&lib_output.stderr)
        );
    }

    // Tell cargo to statically link against the `extern` static library.
    let search_dir = output_path
        .to_str()
        .expect("OUT_DIR must be valid UTF-8 to be passed to cargo");
    println!("cargo:rustc-link-search=native={search_dir}");
    println!("cargo:rustc-link-lib=static=extern");
}

#[allow(unused)]
Expand Down
21 changes: 21 additions & 0 deletions examples/nvtx-range.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use cudarc::nvtx::safe::{mark, range_end, range_pop, range_push, range_start};


/// Emits a handful of NVTX annotations, one second apart, so they can be
/// inspected in a profiler timeline.
fn main() {
    // Every annotation is separated by the same one-second pause.
    let pause = || std::thread::sleep(std::time::Duration::from_secs(1));

    // A single push/pop pair.
    range_push("Test Range");
    pause();
    range_pop();

    // Two nested push/pop pairs.
    range_push("Test Range2");
    pause();
    range_push("Test Range3");
    pause();
    range_pop();
    range_pop();

    // A start/end range identified by the id returned from `range_start`,
    // with a mark emitted while it is open.
    let range_id = range_start("Test Range4");
    pause();
    mark("Test Mark");
    pause();
    range_end(range_id);
}
2 changes: 1 addition & 1 deletion run-bindgen.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
declare -a modules=("cublas" "cublaslt" "cudnn" "curand" "driver" "nccl" "nvrtc")
declare -a modules=("cublas" "cublaslt" "cudnn" "curand" "driver" "nccl" "nvrtc" "nvtx")
for path in "${modules[@]}"
do
cd src/${path}/sys
Expand Down
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
//! 2. [NVRTC API](https://docs.nvidia.com/cuda/nvrtc/index.html)
//! 3. [cuRAND API](https://docs.nvidia.com/cuda/curand/index.html)
//! 4. [cuBLAS API](https://docs.nvidia.com/cuda/cublas/index.html)
//! 5. [cuBLASLt API](https://docs.nvidia.com/cuda/cublas/#using-the-cublaslt-api)
//! 6. [NCCL API](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/)
//! 7. [cuDNN](https://docs.nvidia.com/cudnn/index.html)
//! 8. [NVTX](https://github.com/NVIDIA/NVTX)

//!
//! # crate organization
//!
Expand All @@ -18,7 +23,7 @@
//! | cublaslt | [cublaslt::safe] | [cublaslt::result] | [cublaslt::sys] |
//! | nvrtc | [nvrtc::safe] | [nvrtc::result] | [nvrtc::sys] |
//! | curand | [curand::safe] | [curand::result] | [curand::sys] |
//! | cudnn | - | [cudnn::result] | [cudnn::sys] |
//! | cudnn | [cudnn::safe] | [cudnn::result] | [cudnn::sys] |
//!
//! # Core Concepts
//!
Expand Down Expand Up @@ -91,6 +96,8 @@ pub mod driver;
pub mod nccl;
#[cfg(feature = "nvrtc")]
pub mod nvrtc;
#[cfg(feature = "nvtx")]
pub mod nvtx;

pub mod types;

Expand Down
8 changes: 8 additions & 0 deletions src/nvtx/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//! Wrappers around the [NVTX API](https://nvidia.github.io/NVTX/doxygen/index.html)
//! in two levels. See crate documentation for description of each.

// Wrappers that accept Rust `&str` names and perform the raw calls internally.
pub mod safe;
// Raw bindings selected per CUDA version. The `sys_*` files are produced by
// `sys/bindgen.sh`, so warnings in them are silenced instead of hand-fixed.
#[allow(warnings)]
pub mod sys;

// Re-export the safe wrappers so they are reachable directly from `nvtx::`.
pub use safe::*;
119 changes: 119 additions & 0 deletions src/nvtx/safe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use super::sys::{
nvtxInitialize, nvtxMarkA, nvtxNameCategoryA, nvtxNameCuContextA, nvtxNameCuDeviceA,
nvtxNameCuEventA, nvtxNameCuStreamA, nvtxNameOsThreadA, nvtxRangeEnd, nvtxRangePop,
nvtxRangePushA, nvtxRangeStartA, CUcontext, CUevent, CUstream,
};
use std::ffi::CString;

/// Pushes a range labelled `message` by calling `nvtxRangePushA`.
///
/// Returns the value reported by `nvtxRangePushA` unchanged.
/// NOTE(review): per the NVTX documentation this is presumably the zero-based
/// nesting level, negative on error — confirm.
///
/// # Panics
///
/// Panics if `message` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn range_push(message: &str) -> i32 {
    let c_message = CString::new(message).unwrap();
    // SAFETY: `c_message` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxRangePushA(c_message.as_ptr()) }
}

/// Pops a range by calling `nvtxRangePop`.
///
/// Returns the value reported by `nvtxRangePop` unchanged.
/// NOTE(review): per the NVTX documentation this is presumably the level of
/// the range that was ended, negative on error — confirm.
pub fn range_pop() -> i32 {
    // SAFETY: the call takes no arguments, so there are no pointer or
    // lifetime requirements to uphold on this side.
    let level = unsafe { nvtxRangePop() };
    level
}

/// Starts a range labelled `message` by calling `nvtxRangeStartA`.
///
/// Returns the range id produced by `nvtxRangeStartA`; pass it to
/// [`range_end`] to close the range.
///
/// # Panics
///
/// Panics if `message` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn range_start(message: &str) -> u64 {
    let c_message = CString::new(message).unwrap();
    // SAFETY: `c_message` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxRangeStartA(c_message.as_ptr()) }
}

/// Ends the range identified by `range_id` by calling `nvtxRangeEnd`.
///
/// `range_id` is forwarded as-is; it is expected to be a value previously
/// returned by [`range_start`].
pub fn range_end(range_id: u64) {
    // SAFETY: only a plain integer id is passed; no pointers are involved.
    unsafe { nvtxRangeEnd(range_id) }
}

/// Emits a marker labelled `message` by calling `nvtxMarkA`.
///
/// # Panics
///
/// Panics if `message` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn mark(message: &str) {
    let c_message = CString::new(message).unwrap();
    // SAFETY: `c_message` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxMarkA(c_message.as_ptr()) }
}

/// Assigns `name` to the category id `category` by calling
/// `nvtxNameCategoryA`.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_category(category: u32, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxNameCategoryA(category, c_name.as_ptr()) }
}

/// Assigns `name` to the thread id `os_thread_id` by calling
/// `nvtxNameOsThreadA`.
///
/// `os_thread_id` is forwarded unchanged.
/// NOTE(review): presumably this must be the OS-level thread id rather than a
/// Rust `ThreadId` — confirm against the NVTX documentation.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_os_thread(os_thread_id: u32, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxNameOsThreadA(os_thread_id, c_name.as_ptr()) }
}

/// Assigns `name` to the CUDA device `cu_device` by calling
/// `nvtxNameCuDeviceA`.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_cu_device(cu_device: i32, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the pointer is valid for the whole call.
    unsafe { nvtxNameCuDeviceA(cu_device, c_name.as_ptr()) }
}

/// Assigns `name` to the CUDA context `cu_context` by calling
/// `nvtxNameCuContextA`.
///
/// `cu_context` is a raw handle that is forwarded without any validation.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_cu_context(cu_context: CUcontext, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the string pointer is valid for the whole call.
    // NOTE(review): the validity of `cu_context` is left to the caller.
    unsafe { nvtxNameCuContextA(cu_context, c_name.as_ptr()) }
}

/// Assigns `name` to the CUDA stream `cu_stream` by calling
/// `nvtxNameCuStreamA`.
///
/// `cu_stream` is a raw handle that is forwarded without any validation.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_cu_stream(cu_stream: CUstream, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the string pointer is valid for the whole call.
    // NOTE(review): the validity of `cu_stream` is left to the caller.
    unsafe { nvtxNameCuStreamA(cu_stream, c_name.as_ptr()) }
}

/// Assigns `name` to the CUDA event `cu_event` by calling
/// `nvtxNameCuEventA`.
///
/// `cu_event` is a raw handle that is forwarded without any validation.
///
/// # Panics
///
/// Panics if `name` contains an interior NUL byte, because it cannot be
/// converted into a C string.
pub fn name_cu_event(cu_event: CUevent, name: &str) {
    let c_name = CString::new(name).unwrap();
    // SAFETY: `c_name` is a NUL-terminated buffer that lives until this
    // function returns, so the string pointer is valid for the whole call.
    // NOTE(review): the validity of `cu_event` is left to the caller.
    unsafe { nvtxNameCuEventA(cu_event, c_name.as_ptr()) }
}

/// Calls `nvtxInitialize` with a null `reserved` argument.
pub fn initialize() {
    // SAFETY: the only argument is the reserved pointer, and according to the
    // doc, reserved "must be zero or NULL."
    unsafe { nvtxInitialize(std::ptr::null_mut()) }
}

// Smoke tests: each one only checks that the wrapper can be called without
// crashing; none of them inspects a result yet.
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushing a range must not crash.
    #[test]
    fn test_range_push() {
        let _level = range_push("Test Range");
        // TODO: assert on the returned level
    }

    /// Popping a range must not crash.
    #[test]
    fn test_range_pop() {
        let _level = range_pop();
        // TODO: assert on the returned level
    }

    /// Naming a null event handle must not crash.
    #[test]
    fn test_name_cuda_event() {
        let null_event: CUevent = std::ptr::null_mut();
        name_cu_event(null_event, "Test Event");
        // TODO: add assertions
    }

    /// Initializing NVTX must not crash.
    #[test]
    fn test_initialize() {
        initialize();
        // TODO: add assertions
    }

    // TODO: cover the remaining wrappers
}
24 changes: 24 additions & 0 deletions src/nvtx/sys/bindgen.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# Regenerates the NVTX bindings for the CUDA toolkit installed under
# /usr/local/cuda. Run from this directory (it reads ./wrapper.h).
#
# Outputs:
#   sys_<CUDA_VERSION>.rs  Rust bindings, named after the detected version
#   extern.c               C wrappers emitted by `--wrap-static-fns`
set -exu

# Follow https://github.com/rust-lang/rust-bindgen/discussions/2405
bindgen \
  --allowlist-var="^CUDA_VERSION.*" \
  --allowlist-type="^nvtx.*" \
  --allowlist-function="^nvtx.*" \
  --default-enum-style=rust \
  --no-doc-comments \
  --with-derive-default \
  --with-derive-eq \
  --with-derive-hash \
  --with-derive-ord \
  --use-core \
  --wrap-static-fns \
  --experimental \
  wrapper.h -- -I/usr/local/cuda/include \
  > tmp.rs

# Pick up the wrapper source that bindgen wrote next to the bindings run.
mv /tmp/bindgen/extern.c .

# Extract the numeric version from the first generated `CUDA_VERSION` line:
# the 6th whitespace-separated field is e.g. `12020;`, and sed drops the `;`.
CUDA_VERSION=$(grep -m1 "CUDA_VERSION" tmp.rs | awk '{ print $6 }' | sed 's/.$//')

# Without this guard an undetected version would silently produce `sys_.rs`.
if [[ -z "${CUDA_VERSION}" ]]; then
  echo "error: could not determine CUDA_VERSION from tmp.rs" >&2
  exit 1
fi

mv tmp.rs "sys_${CUDA_VERSION}.rs"
42 changes: 42 additions & 0 deletions src/nvtx/sys/extern.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include "wrapper.h"

// Static wrappers
//
// Each `<name>__extern` function below simply forwards its arguments to the
// NVTX function `<name>` declared via wrapper.h and returns its result, giving
// every such function an externally linkable symbol.
//
// NOTE: this file is produced by `bindgen --wrap-static-fns` (see bindgen.sh,
// which moves it here); regenerate it rather than editing it by hand.

void nvtxInitialize__extern(const void *reserved) { nvtxInitialize(reserved); }
void nvtxDomainMarkEx__extern(nvtxDomainHandle_t domain, const nvtxEventAttributes_t *eventAttrib) { nvtxDomainMarkEx(domain, eventAttrib); }
void nvtxMarkEx__extern(const nvtxEventAttributes_t *eventAttrib) { nvtxMarkEx(eventAttrib); }
void nvtxMarkA__extern(const char *message) { nvtxMarkA(message); }
void nvtxMarkW__extern(const wchar_t *message) { nvtxMarkW(message); }
nvtxRangeId_t nvtxDomainRangeStartEx__extern(nvtxDomainHandle_t domain, const nvtxEventAttributes_t *eventAttrib) { return nvtxDomainRangeStartEx(domain, eventAttrib); }
nvtxRangeId_t nvtxRangeStartEx__extern(const nvtxEventAttributes_t *eventAttrib) { return nvtxRangeStartEx(eventAttrib); }
nvtxRangeId_t nvtxRangeStartA__extern(const char *message) { return nvtxRangeStartA(message); }
nvtxRangeId_t nvtxRangeStartW__extern(const wchar_t *message) { return nvtxRangeStartW(message); }
void nvtxDomainRangeEnd__extern(nvtxDomainHandle_t domain, nvtxRangeId_t id) { nvtxDomainRangeEnd(domain, id); }
void nvtxRangeEnd__extern(nvtxRangeId_t id) { nvtxRangeEnd(id); }
int nvtxDomainRangePushEx__extern(nvtxDomainHandle_t domain, const nvtxEventAttributes_t *eventAttrib) { return nvtxDomainRangePushEx(domain, eventAttrib); }
int nvtxRangePushEx__extern(const nvtxEventAttributes_t *eventAttrib) { return nvtxRangePushEx(eventAttrib); }
int nvtxRangePushA__extern(const char *message) { return nvtxRangePushA(message); }
int nvtxRangePushW__extern(const wchar_t *message) { return nvtxRangePushW(message); }
int nvtxDomainRangePop__extern(nvtxDomainHandle_t domain) { return nvtxDomainRangePop(domain); }
int nvtxRangePop__extern(void) { return nvtxRangePop(); }
nvtxResourceHandle_t nvtxDomainResourceCreate__extern(nvtxDomainHandle_t domain, nvtxResourceAttributes_t *attribs) { return nvtxDomainResourceCreate(domain, attribs); }
void nvtxDomainResourceDestroy__extern(nvtxResourceHandle_t resource) { nvtxDomainResourceDestroy(resource); }
void nvtxDomainNameCategoryA__extern(nvtxDomainHandle_t domain, uint32_t category, const char *name) { nvtxDomainNameCategoryA(domain, category, name); }
void nvtxDomainNameCategoryW__extern(nvtxDomainHandle_t domain, uint32_t category, const wchar_t *name) { nvtxDomainNameCategoryW(domain, category, name); }
void nvtxNameCategoryA__extern(uint32_t category, const char *name) { nvtxNameCategoryA(category, name); }
void nvtxNameCategoryW__extern(uint32_t category, const wchar_t *name) { nvtxNameCategoryW(category, name); }
void nvtxNameOsThreadA__extern(uint32_t threadId, const char *name) { nvtxNameOsThreadA(threadId, name); }
void nvtxNameOsThreadW__extern(uint32_t threadId, const wchar_t *name) { nvtxNameOsThreadW(threadId, name); }
nvtxStringHandle_t nvtxDomainRegisterStringA__extern(nvtxDomainHandle_t domain, const char *string) { return nvtxDomainRegisterStringA(domain, string); }
nvtxStringHandle_t nvtxDomainRegisterStringW__extern(nvtxDomainHandle_t domain, const wchar_t *string) { return nvtxDomainRegisterStringW(domain, string); }
nvtxDomainHandle_t nvtxDomainCreateA__extern(const char *name) { return nvtxDomainCreateA(name); }
nvtxDomainHandle_t nvtxDomainCreateW__extern(const wchar_t *name) { return nvtxDomainCreateW(name); }
void nvtxDomainDestroy__extern(nvtxDomainHandle_t domain) { nvtxDomainDestroy(domain); }
void nvtxNameCuDeviceA__extern(CUdevice device, const char *name) { nvtxNameCuDeviceA(device, name); }
void nvtxNameCuDeviceW__extern(CUdevice device, const wchar_t *name) { nvtxNameCuDeviceW(device, name); }
void nvtxNameCuContextA__extern(CUcontext context, const char *name) { nvtxNameCuContextA(context, name); }
void nvtxNameCuContextW__extern(CUcontext context, const wchar_t *name) { nvtxNameCuContextW(context, name); }
void nvtxNameCuStreamA__extern(CUstream stream, const char *name) { nvtxNameCuStreamA(stream, name); }
void nvtxNameCuStreamW__extern(CUstream stream, const wchar_t *name) { nvtxNameCuStreamW(stream, name); }
void nvtxNameCuEventA__extern(CUevent event, const char *name) { nvtxNameCuEventA(event, name); }
void nvtxNameCuEventW__extern(CUevent event, const wchar_t *name) { nvtxNameCuEventW(event, name); }
59 changes: 59 additions & 0 deletions src/nvtx/sys/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Selects the bindings that match the enabled CUDA version feature.
//
// Each `cuda-XXXXX` feature pulls in the corresponding `sys_XXXXX` module and
// glob re-exports its items, so callers use `sys::*` regardless of version.
//
// NOTE(review): enabling more than one `cuda-*` feature would re-export
// several binding sets at once — presumably the features are meant to be
// mutually exclusive; confirm against Cargo.toml.

// CUDA 11.4
#[cfg(feature = "cuda-11040")]
mod sys_11040;
#[cfg(feature = "cuda-11040")]
pub use sys_11040::*;

// CUDA 11.5
#[cfg(feature = "cuda-11050")]
mod sys_11050;
#[cfg(feature = "cuda-11050")]
pub use sys_11050::*;

// CUDA 11.6
#[cfg(feature = "cuda-11060")]
mod sys_11060;
#[cfg(feature = "cuda-11060")]
pub use sys_11060::*;

// CUDA 11.7
#[cfg(feature = "cuda-11070")]
mod sys_11070;
#[cfg(feature = "cuda-11070")]
pub use sys_11070::*;

// CUDA 11.8
#[cfg(feature = "cuda-11080")]
mod sys_11080;
#[cfg(feature = "cuda-11080")]
pub use sys_11080::*;

// CUDA 12.0
#[cfg(feature = "cuda-12000")]
mod sys_12000;
#[cfg(feature = "cuda-12000")]
pub use sys_12000::*;

// CUDA 12.1
#[cfg(feature = "cuda-12010")]
mod sys_12010;
#[cfg(feature = "cuda-12010")]
pub use sys_12010::*;

// CUDA 12.2
#[cfg(feature = "cuda-12020")]
mod sys_12020;
#[cfg(feature = "cuda-12020")]
pub use sys_12020::*;

// CUDA 12.3
#[cfg(feature = "cuda-12030")]
mod sys_12030;
#[cfg(feature = "cuda-12030")]
pub use sys_12030::*;

// CUDA 12.4
#[cfg(feature = "cuda-12040")]
mod sys_12040;
#[cfg(feature = "cuda-12040")]
pub use sys_12040::*;

// CUDA 12.5
#[cfg(feature = "cuda-12050")]
mod sys_12050;
#[cfg(feature = "cuda-12050")]
pub use sys_12050::*;

// CUDA 12.6
#[cfg(feature = "cuda-12060")]
mod sys_12060;
#[cfg(feature = "cuda-12060")]
pub use sys_12060::*;
Loading