diff --git a/cub/cub/detail/nvtx.cuh b/cub/cub/detail/nvtx.cuh index 4ffbfc93b92..3b0c039407c 100644 --- a/cub/cub/detail/nvtx.cuh +++ b/cub/cub/detail/nvtx.cuh @@ -37,11 +37,11 @@ # pragma system_header #endif // no system header -// Enable the functionality of this header if +// Enable the functionality of this header if: // * The NVTX3 C API is available in CTK // * NVTX is not explicitly disabled // * C++14 is availabl for cuda::std::optional -#if __has_include() && !defined(NVTX_DISABLE) && _CCCL_STD_VER >= 2014 +#if __has_include( ) && !defined(NVTX_DISABLE) && _CCCL_STD_VER >= 2014 // Include our NVTX3 C++ wrapper if not available from the CTK # if __has_include() // TODO(bgruber): replace by a check for the first CTK version shipping the header # include @@ -49,7 +49,11 @@ # include "nvtx3.hpp" # endif // __has_include() -# include +// We expect the NVTX3 V1 C++ API to be available when nvtx3.hpp is available. This should work, because newer versions +// of NVTX3 will continue to declare previous API versions. See also: +// https://github.com/NVIDIA/NVTX/blob/release-v3/c/include/nvtx3/nvtx3.hpp#L2835-L2841. +# ifdef NVTX3_CPP_DEFINITIONS_V1_0 +# include CUB_NAMESPACE_BEGIN namespace detail @@ -62,26 +66,38 @@ struct NVTXCCCLDomain CUB_NAMESPACE_END // Hook for the NestedNVTXRangeGuard from the unit tests -# ifndef CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE -# define CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE(name) -# endif // !CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE +# ifndef CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE +# define CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE(name) +# endif // !CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE // Conditionally inserts a NVTX range starting here until the end of the current function scope in host code. Does // nothing in device code. // The optional is needed to defer the construction of an NVTX range (host-only code) and message string registration // into a dispatch region running only on the host, while preserving the semantic scope where the range is declared. -# define CUB_DETAIL_NVTX_RANGE_SCOPE_IF(condition, name) \ - CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE(name) \ - ::cuda::std::optional<::nvtx3::scoped_range_in> __cub_nvtx3_range; \ - NV_IF_TARGET( \ - NV_IS_HOST, \ - static const ::nvtx3::registered_string_in __cub_nvtx3_func_name{name}; \ - static const ::nvtx3::event_attributes __cub_nvtx3_func_attr{__cub_nvtx3_func_name}; \ - if (condition) __cub_nvtx3_range.emplace(__cub_nvtx3_func_attr); \ - (void) __cub_nvtx3_range;) +# define CUB_DETAIL_NVTX_RANGE_SCOPE_IF(condition, name) \ + CUB_DETAIL_BEFORE_NVTX_RANGE_SCOPE(name) \ + ::cuda::std::optional<::nvtx3::v1::scoped_range_in> __cub_nvtx3_range; \ + NV_IF_TARGET( \ + NV_IS_HOST, \ + static const ::nvtx3::v1::registered_string_in __cub_nvtx3_func_name{ \ + name}; \ + static const ::nvtx3::v1::event_attributes __cub_nvtx3_func_attr{__cub_nvtx3_func_name}; \ + if (condition) __cub_nvtx3_range.emplace(__cub_nvtx3_func_attr); \ + (void) __cub_nvtx3_range;) -# define CUB_DETAIL_NVTX_RANGE_SCOPE(name) CUB_DETAIL_NVTX_RANGE_SCOPE_IF(true, name) -#else // __has_include() && !defined(NVTX_DISABLE) && _CCCL_STD_VER > 2011 +# define CUB_DETAIL_NVTX_RANGE_SCOPE(name) CUB_DETAIL_NVTX_RANGE_SCOPE_IF(true, name) +# else // NVTX3_CPP_DEFINITIONS_V1_0 +// Tell the user we don't support their NVTX3 version. +# if defined(_CCCL_COMPILER_MSVC) +# pragma message( \ + "warning: nvtx3.hpp is available but does not define the V1 API. This is odd. Please open a GitHub issue at: https://github.com/NVIDIA/cccl/issues.") +# else +# warning nvtx3.hpp is available but does not define the V1 API. This is odd. Please open a GitHub issue at: https://github.com/NVIDIA/cccl/issues. +# endif +# define CUB_DETAIL_NVTX_RANGE_SCOPE_IF(condition, name) +# define CUB_DETAIL_NVTX_RANGE_SCOPE(name) +# endif // NVTX3_CPP_DEFINITIONS_V1_0 +#else // __has_include( ) && !defined(NVTX_DISABLE) && _CCCL_STD_VER >= 2014 # define CUB_DETAIL_NVTX_RANGE_SCOPE_IF(condition, name) # define CUB_DETAIL_NVTX_RANGE_SCOPE(name) -#endif // __has_include() && !defined(NVTX_DISABLE) && _CCCL_STD_VER > 2011 +#endif // __has_include( ) && !defined(NVTX_DISABLE) && _CCCL_STD_VER >= 2014