diff --git a/source/loader/CMakeLists.txt b/source/loader/CMakeLists.txt index 07dab17943..1401379341 100644 --- a/source/loader/CMakeLists.txt +++ b/source/loader/CMakeLists.txt @@ -131,45 +131,42 @@ if(UR_ENABLE_SANITIZER) target_sources(ur_loader PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../ur/ur.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_allocator.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_allocator.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_buffer.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_buffer.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_interceptor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_interceptor.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_libdevice.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_options.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_options.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_quarantine.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_quarantine.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_report.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_report.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_shadow.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_shadow.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_statistics.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_statistics.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_validator.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan_validator.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/common.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/stacktrace.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/stacktrace.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_allocator.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_buffer.hpp 
+ ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_ddi.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_interceptor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_libdevice.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_options.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_options.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_quarantine.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_quarantine.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_report.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_report.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_shadow.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_shadow.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_statistics.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_statistics.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_validator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/asan/asan_validator.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/linux/backtrace.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/sanitizer_common.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/sanitizer_utils.hpp ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanddi.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_utils.cpp - 
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_utils.hpp - ) - - target_sources(ur_loader - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/linux/backtrace.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/linux/sanitizer_utils.cpp ) if(UR_ENABLE_SYMBOLIZER) target_sources(ur_loader PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/linux/symbolizer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/sanitizer_common/linux/symbolizer.cpp ) target_include_directories(ur_loader PRIVATE ${LLVM_INCLUDE_DIRS}) target_link_libraries(ur_loader PRIVATE LLVMSupport LLVMSymbolize) diff --git a/source/loader/layers/sanitizer/asan_allocator.cpp b/source/loader/layers/sanitizer/asan/asan_allocator.cpp similarity index 100% rename from source/loader/layers/sanitizer/asan_allocator.cpp rename to source/loader/layers/sanitizer/asan/asan_allocator.cpp diff --git a/source/loader/layers/sanitizer/asan_allocator.hpp b/source/loader/layers/sanitizer/asan/asan_allocator.hpp similarity index 93% rename from source/loader/layers/sanitizer/asan_allocator.hpp rename to source/loader/layers/sanitizer/asan/asan_allocator.hpp index 249ef896d0..a3b92ec868 100644 --- a/source/loader/layers/sanitizer/asan_allocator.hpp +++ b/source/loader/layers/sanitizer/asan/asan_allocator.hpp @@ -12,11 +12,8 @@ #pragma once -#include "common.hpp" -#include "stacktrace.hpp" - -#include -#include +#include "sanitizer_common/sanitizer_common.hpp" +#include "sanitizer_common/sanitizer_stacktrace.hpp" namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/asan_buffer.cpp b/source/loader/layers/sanitizer/asan/asan_buffer.cpp similarity index 96% rename from source/loader/layers/sanitizer/asan_buffer.cpp rename to source/loader/layers/sanitizer/asan/asan_buffer.cpp index 9316d68bf4..ac2947ed29 100644 --- a/source/loader/layers/sanitizer/asan_buffer.cpp +++ b/source/loader/layers/sanitizer/asan/asan_buffer.cpp @@ -12,8 +12,8 @@ #include "asan_buffer.hpp" #include 
"asan_interceptor.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" #include "ur_sanitizer_layer.hpp" -#include "ur_sanitizer_utils.hpp" namespace ur_sanitizer_layer { @@ -91,7 +91,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { ur_usm_desc_t USMDesc{}; USMDesc.align = getAlignment(); ur_usm_pool_handle_t Pool{}; - URes = getContext()->interceptor->allocateMemory( + URes = getAsanInterceptor()->allocateMemory( Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER, ur_cast(&Allocation)); if (URes != UR_RESULT_SUCCESS) { @@ -129,7 +129,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { ur_usm_desc_t USMDesc{}; USMDesc.align = getAlignment(); ur_usm_pool_handle_t Pool{}; - URes = getContext()->interceptor->allocateMemory( + URes = getAsanInterceptor()->allocateMemory( Context, nullptr, &USMDesc, Pool, Size, AllocType::HOST_USM, ur_cast(&HostAllocation)); if (URes != UR_RESULT_SUCCESS) { @@ -174,8 +174,7 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) { ur_result_t MemBuffer::free() { for (const auto &[_, Ptr] : Allocations) { - ur_result_t URes = - getContext()->interceptor->releaseMemory(Context, Ptr); + ur_result_t URes = getAsanInterceptor()->releaseMemory(Context, Ptr); if (URes != UR_RESULT_SUCCESS) { getContext()->logger.error("Failed to free buffer handle {}", Ptr); return URes; diff --git a/source/loader/layers/sanitizer/asan_buffer.hpp b/source/loader/layers/sanitizer/asan/asan_buffer.hpp similarity index 98% rename from source/loader/layers/sanitizer/asan_buffer.hpp rename to source/loader/layers/sanitizer/asan/asan_buffer.hpp index 989ef4249f..a62dfa1ea8 100644 --- a/source/loader/layers/sanitizer/asan_buffer.hpp +++ b/source/loader/layers/sanitizer/asan/asan_buffer.hpp @@ -16,7 +16,7 @@ #include #include -#include "common.hpp" +#include "ur/ur.hpp" namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/asan/asan_ddi.cpp 
b/source/loader/layers/sanitizer/asan/asan_ddi.cpp new file mode 100644 index 0000000000..94f63d7a7d --- /dev/null +++ b/source/loader/layers/sanitizer/asan/asan_ddi.cpp @@ -0,0 +1,1877 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. + * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file asan_ddi.cpp + * + */ + +#include "asan_ddi.hpp" +#include "asan_interceptor.hpp" +#include "asan_options.hpp" +#include "sanitizer_common/sanitizer_stacktrace.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" +#include "ur_sanitizer_layer.hpp" + +#include + +namespace ur_sanitizer_layer { +namespace asan { + +namespace { + +ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices, + const ur_device_handle_t *phDevices) { + std::shared_ptr CI; + UR_CALL(getAsanInterceptor()->insertContext(Context, CI)); + for (uint32_t i = 0; i < numDevices; ++i) { + auto hDevice = phDevices[i]; + std::shared_ptr DI; + UR_CALL(getAsanInterceptor()->insertDevice(hDevice, DI)); + DI->Type = GetDeviceType(Context, hDevice); + if (DI->Type == DeviceType::UNKNOWN) { + getContext()->logger.error("Unsupport device"); + return UR_RESULT_ERROR_INVALID_DEVICE; + } + getContext()->logger.info( + "DeviceInfo {} (Type={}, IsSupportSharedSystemUSM={})", + (void *)DI->Handle, ToString(DI->Type), + DI->IsSupportSharedSystemUSM); + getContext()->logger.info("Add {} into context {}", (void *)DI->Handle, + (void *)Context); + if (!DI->Shadow) { + UR_CALL(DI->allocShadowMemory(Context)); + } + CI->DeviceList.emplace_back(hDevice); + CI->AllocInfosMap[hDevice]; + } + return UR_RESULT_SUCCESS; +} + +} // namespace + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urAdapterGet +__urdlllocal ur_result_t UR_APICALL urAdapterGet( + uint32_t + NumEntries, ///< [in] the number of adapters to be added 
to phAdapters. + ///< If phAdapters is not NULL, then NumEntries should be greater than + ///< zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, + ///< will be returned. + ur_adapter_handle_t * + phAdapters, ///< [out][optional][range(0, NumEntries)] array of handle of adapters. + ///< If NumEntries is less than the number of adapters available, then + ///< ::urAdapterGet shall only retrieve that number of platforms. + uint32_t * + pNumAdapters ///< [out][optional] returns the total number of adapters available. +) { + auto pfnAdapterGet = getContext()->urDdiTable.Global.pfnAdapterGet; + + if (nullptr == pfnAdapterGet) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + ur_result_t result = pfnAdapterGet(NumEntries, phAdapters, pNumAdapters); + if (result == UR_RESULT_SUCCESS && phAdapters) { + const uint32_t NumAdapters = pNumAdapters ? *pNumAdapters : NumEntries; + for (uint32_t i = 0; i < NumAdapters; ++i) { + UR_CALL(getAsanInterceptor()->holdAdapter(phAdapters[i])); + } + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMHostAlloc +__urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + const ur_usm_desc_t + *pUSMDesc, ///< [in][optional] USM memory allocation descriptor + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM host memory object +) { + auto pfnHostAlloc = getContext()->urDdiTable.USM.pfnHostAlloc; + + if (nullptr == pfnHostAlloc) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urUSMHostAlloc"); + + return getAsanInterceptor()->allocateMemory( + hContext, nullptr, pUSMDesc, pool, size, AllocType::HOST_USM, ppMem); +} + 
+/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMDeviceAlloc +__urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t + *pUSMDesc, ///< [in][optional] USM memory allocation descriptor + ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM device memory object +) { + auto pfnDeviceAlloc = getContext()->urDdiTable.USM.pfnDeviceAlloc; + + if (nullptr == pfnDeviceAlloc) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urUSMDeviceAlloc"); + + return getAsanInterceptor()->allocateMemory( + hContext, hDevice, pUSMDesc, pool, size, AllocType::DEVICE_USM, ppMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMSharedAlloc +__urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_device_handle_t hDevice, ///< [in] handle of the device object + const ur_usm_desc_t * + pUSMDesc, ///< [in][optional] Pointer to USM memory allocation descriptor. 
+ ur_usm_pool_handle_t + pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate + size_t + size, ///< [in] size in bytes of the USM memory object to be allocated + void **ppMem ///< [out] pointer to USM shared memory object +) { + auto pfnSharedAlloc = getContext()->urDdiTable.USM.pfnSharedAlloc; + + if (nullptr == pfnSharedAlloc) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urUSMSharedAlloc"); + + return getAsanInterceptor()->allocateMemory( + hContext, hDevice, pUSMDesc, pool, size, AllocType::SHARED_USM, ppMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urUSMFree +__urdlllocal ur_result_t UR_APICALL urUSMFree( + ur_context_handle_t hContext, ///< [in] handle of the context object + void *pMem ///< [in] pointer to USM memory object +) { + auto pfnFree = getContext()->urDdiTable.USM.pfnFree; + + if (nullptr == pfnFree) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urUSMFree"); + + return getAsanInterceptor()->releaseMemory(hContext, pMem); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithIL +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( + ur_context_handle_t hContext, ///< [in] handle of the context instance + const void *pIL, ///< [in] pointer to IL binary. + size_t length, ///< [in] length of `pIL` in bytes. + const ur_program_properties_t * + pProperties, ///< [in][optional] pointer to program creation properties. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. 
+) { + auto pfnProgramCreateWithIL = + getContext()->urDdiTable.Program.pfnCreateWithIL; + + if (nullptr == pfnProgramCreateWithIL) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramCreateWithIL"); + + UR_CALL( + pfnProgramCreateWithIL(hContext, pIL, length, pProperties, phProgram)); + UR_CALL(getAsanInterceptor()->insertProgram(*phProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithBinary +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( + ur_context_handle_t hContext, ///< [in] handle of the context instance + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] a pointer to a list of device handles. The + ///< binaries are loaded for devices specified in this list. + size_t * + pLengths, ///< [in][range(0, numDevices)] array of sizes of program binaries + ///< specified by `pBinaries` (in bytes). + const uint8_t ** + ppBinaries, ///< [in][range(0, numDevices)] pointer to program binaries to be loaded + ///< for devices specified by `phDevices`. + const ur_program_properties_t * + pProperties, ///< [in][optional] pointer to program creation properties. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of Program object created. 
+) { + auto pfnProgramCreateWithBinary = + getContext()->urDdiTable.Program.pfnCreateWithBinary; + + if (nullptr == pfnProgramCreateWithBinary) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramCreateWithBinary"); + + UR_CALL(pfnProgramCreateWithBinary(hContext, numDevices, phDevices, + pLengths, ppBinaries, pProperties, + phProgram)); + UR_CALL(getAsanInterceptor()->insertProgram(*phProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( + ur_native_handle_t + hNativeProgram, ///< [in][nocheck] the native handle of the program. + ur_context_handle_t hContext, ///< [in] handle of the context instance + const ur_program_native_properties_t * + pProperties, ///< [in][optional] pointer to native program properties struct. + ur_program_handle_t * + phProgram ///< [out] pointer to the handle of the program object created. 
+) { + auto pfnProgramCreateWithNativeHandle = + getContext()->urDdiTable.Program.pfnCreateWithNativeHandle; + + if (nullptr == pfnProgramCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramCreateWithNativeHandle"); + + UR_CALL(pfnProgramCreateWithNativeHandle(hNativeProgram, hContext, + pProperties, phProgram)); + UR_CALL(getAsanInterceptor()->insertProgram(*phProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramRetain +__urdlllocal ur_result_t UR_APICALL urProgramRetain( + ur_program_handle_t + hProgram ///< [in][retain] handle for the Program to retain +) { + auto pfnRetain = getContext()->urDdiTable.Program.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramRetain"); + + UR_CALL(pfnRetain(hProgram)); + + auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); + UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + ProgramInfo->RefCount++; + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramBuild +__urdlllocal ur_result_t UR_APICALL urProgramBuild( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_program_handle_t hProgram, ///< [in] handle of the program object + const char *pOptions ///< [in] string of build options +) { + auto pfnProgramBuild = getContext()->urDdiTable.Program.pfnBuild; + + if (nullptr == pfnProgramBuild) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramBuild"); + + UR_CALL(pfnProgramBuild(hContext, hProgram, pOptions)); + + UR_CALL(getAsanInterceptor()->registerProgram(hContext, hProgram)); + + return UR_RESULT_SUCCESS; +} + 
+/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramBuildExp +__urdlllocal ur_result_t UR_APICALL urProgramBuildExp( + ur_program_handle_t hProgram, ///< [in] Handle of the program to build. + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles + const char * + pOptions ///< [in][optional] pointer to build options null-terminated string. +) { + auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; + + if (nullptr == pfnBuildExp) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramBuildExp"); + + UR_CALL(pfnBuildExp(hProgram, numDevices, phDevices, pOptions)); + UR_CALL( + getAsanInterceptor()->registerProgram(GetContext(hProgram), hProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramLink +__urdlllocal ur_result_t UR_APICALL urProgramLink( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + uint32_t count, ///< [in] number of program handles in `phPrograms`. + const ur_program_handle_t * + phPrograms, ///< [in][range(0, count)] pointer to array of program handles. + const char * + pOptions, ///< [in][optional] pointer to linker options null-terminated string. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. 
+) { + auto pfnProgramLink = getContext()->urDdiTable.Program.pfnLink; + + if (nullptr == pfnProgramLink) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramLink"); + + UR_CALL(pfnProgramLink(hContext, count, phPrograms, pOptions, phProgram)); + + UR_CALL(getAsanInterceptor()->registerProgram(hContext, *phProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramLinkExp +ur_result_t UR_APICALL urProgramLinkExp( + ur_context_handle_t hContext, ///< [in] handle of the context instance. + uint32_t numDevices, ///< [in] number of devices + ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles + uint32_t count, ///< [in] number of program handles in `phPrograms`. + const ur_program_handle_t * + phPrograms, ///< [in][range(0, count)] pointer to array of program handles. + const char * + pOptions, ///< [in][optional] pointer to linker options null-terminated string. + ur_program_handle_t + *phProgram ///< [out] pointer to handle of program object created. 
+) { + auto pfnProgramLinkExp = getContext()->urDdiTable.ProgramExp.pfnLinkExp; + + if (nullptr == pfnProgramLinkExp) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramLinkExp"); + + UR_CALL(pfnProgramLinkExp(hContext, numDevices, phDevices, count, + phPrograms, pOptions, phProgram)); + + UR_CALL(getAsanInterceptor()->registerProgram(hContext, *phProgram)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urProgramRelease +ur_result_t UR_APICALL urProgramRelease( + ur_program_handle_t + hProgram ///< [in][release] handle for the Program to release +) { + auto pfnProgramRelease = getContext()->urDdiTable.Program.pfnRelease; + + if (nullptr == pfnProgramRelease) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urProgramRelease"); + + UR_CALL(pfnProgramRelease(hProgram)); + + auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); + UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + if (--ProgramInfo->RefCount == 0) { + UR_CALL(getAsanInterceptor()->unregisterProgram(hProgram)); + UR_CALL(getAsanInterceptor()->eraseProgram(hProgram)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueKernelLaunch +__urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t + workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and + ///< work-group work-items + const size_t * + pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the + ///< offset used to calculate the global ID of a work-item + const size_t * + pGlobalWorkSize, ///< [in] pointer to an array of 
workDim unsigned values that specify the + ///< number of global work-items in workDim that will execute the kernel + ///< function + const size_t * + pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that + ///< specify the number of local work-items forming a work-group that will + ///< execute the kernel function. + ///< If nullptr, the runtime implementation will choose the work-group + ///< size. + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before the kernel execution. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait + ///< event. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< kernel execution instance. +) { + auto pfnKernelLaunch = getContext()->urDdiTable.Enqueue.pfnKernelLaunch; + + if (nullptr == pfnKernelLaunch) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueKernelLaunch"); + + USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), + pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset, + workDim); + UR_CALL(LaunchInfo.initialize()); + + UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo)); + + ur_event_handle_t hEvent{}; + ur_result_t result = + pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset, + pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(), + numEventsInWaitList, phEventWaitList, &hEvent); + + if (result == UR_RESULT_SUCCESS) { + UR_CALL(getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, + LaunchInfo)); + } + + if (phEvent) { + *phEvent = hEvent; + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextCreate +__urdlllocal ur_result_t UR_APICALL 
urContextCreate( + uint32_t numDevices, ///< [in] the number of devices given in phDevices + const ur_device_handle_t + *phDevices, ///< [in][range(0, numDevices)] array of handle of devices. + const ur_context_properties_t * + pProperties, ///< [in][optional] pointer to context creation properties. + ur_context_handle_t + *phContext ///< [out] pointer to handle of context object created +) { + auto pfnCreate = getContext()->urDdiTable.Context.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urContextCreate"); + + ur_result_t result = + pfnCreate(numDevices, phDevices, pProperties, phContext); + + if (result == UR_RESULT_SUCCESS) { + UR_CALL(setupContext(*phContext, numDevices, phDevices)); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextCreateWithNativeHandle +__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( + ur_native_handle_t + hNativeContext, ///< [in][nocheck] the native handle of the getContext()-> + ur_adapter_handle_t hAdapter, + uint32_t numDevices, ///< [in] number of devices associated with the context + const ur_device_handle_t * + phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context + const ur_context_native_properties_t * + pProperties, ///< [in][optional] pointer to native context properties struct + ur_context_handle_t * + phContext ///< [out] pointer to the handle of the context object created. 
+) { + auto pfnCreateWithNativeHandle = + getContext()->urDdiTable.Context.pfnCreateWithNativeHandle; + + if (nullptr == pfnCreateWithNativeHandle) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urContextCreateWithNativeHandle"); + + ur_result_t result = + pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices, + phDevices, pProperties, phContext); + + if (result == UR_RESULT_SUCCESS) { + UR_CALL(setupContext(*phContext, numDevices, phDevices)); + } + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextRetain +__urdlllocal ur_result_t UR_APICALL urContextRetain( + ur_context_handle_t + hContext ///< [in] handle of the context to get a reference of. +) { + auto pfnRetain = getContext()->urDdiTable.Context.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urContextRetain"); + + UR_CALL(pfnRetain(hContext)); + + auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext); + UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + ContextInfo->RefCount++; + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urContextRelease +__urdlllocal ur_result_t UR_APICALL urContextRelease( + ur_context_handle_t hContext ///< [in] handle of the context to release. 
+) { + auto pfnRelease = getContext()->urDdiTable.Context.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urContextRelease"); + + UR_CALL(pfnRelease(hContext)); + + auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext); + UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + if (--ContextInfo->RefCount == 0) { + UR_CALL(getAsanInterceptor()->eraseContext(hContext)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemBufferCreate +__urdlllocal ur_result_t UR_APICALL urMemBufferCreate( + ur_context_handle_t hContext, ///< [in] handle of the context object + ur_mem_flags_t flags, ///< [in] allocation and usage information flags + size_t size, ///< [in] size in bytes of the memory object to be allocated + const ur_buffer_properties_t + *pProperties, ///< [in][optional] pointer to buffer creation properties + ur_mem_handle_t + *phBuffer ///< [out] pointer to handle of the memory buffer created +) { + auto pfnBufferCreate = getContext()->urDdiTable.Mem.pfnBufferCreate; + + if (nullptr == pfnBufferCreate) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + if (nullptr == phBuffer) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + getContext()->logger.debug("==== urMemBufferCreate"); + + void *Host = nullptr; + if (pProperties) { + Host = pProperties->pHost; + } + + char *hostPtrOrNull = (flags & UR_MEM_FLAG_USE_HOST_POINTER) + ? 
ur_cast(Host) + : nullptr; + + std::shared_ptr pMemBuffer = + std::make_shared(hContext, size, hostPtrOrNull); + + if (Host && (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER)) { + std::shared_ptr CtxInfo = + getAsanInterceptor()->getContextInfo(hContext); + for (const auto &hDevice : CtxInfo->DeviceList) { + ManagedQueue InternalQueue(hContext, hDevice); + char *Handle = nullptr; + UR_CALL(pMemBuffer->getHandle(hDevice, Handle)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr)); + } + } + + ur_result_t result = getAsanInterceptor()->insertMemBuffer(pMemBuffer); + *phBuffer = ur_cast(pMemBuffer.get()); + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemGetInfo +__urdlllocal ur_result_t UR_APICALL urMemGetInfo( + ur_mem_handle_t + hMemory, ///< [in] handle to the memory object being queried. + ur_mem_info_t propName, ///< [in] type of the info to retrieve. + size_t + propSize, ///< [in] the number of bytes of memory pointed to by pPropValue. + void * + pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding + ///< the info. + ///< If propSize is less than the real number of bytes needed to return + ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and + ///< pPropValue is not used. + size_t * + pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. 
+) { + auto pfnGetInfo = getContext()->urDdiTable.Mem.pfnGetInfo; + + if (nullptr == pfnGetInfo) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urMemGetInfo"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMemory)) { + UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); + switch (propName) { + case UR_MEM_INFO_CONTEXT: { + return ReturnValue(MemBuffer->Context); + } + case UR_MEM_INFO_SIZE: { + return ReturnValue(size_t{MemBuffer->Size}); + } + default: { + return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + } + } + } else { + UR_CALL( + pfnGetInfo(hMemory, propName, propSize, pPropValue, pPropSizeRet)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemRetain +__urdlllocal ur_result_t UR_APICALL urMemRetain( + ur_mem_handle_t hMem ///< [in] handle of the memory object to get access +) { + auto pfnRetain = getContext()->urDdiTable.Mem.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urMemRetain"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { + MemBuffer->RefCount++; + } else { + UR_CALL(pfnRetain(hMem)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemRelease +__urdlllocal ur_result_t UR_APICALL urMemRelease( + ur_mem_handle_t hMem ///< [in] handle of the memory object to release +) { + auto pfnRelease = getContext()->urDdiTable.Mem.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urMemRelease"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { + if (--MemBuffer->RefCount != 0) { + return UR_RESULT_SUCCESS; + } + UR_CALL(MemBuffer->free()); + 
UR_CALL(getAsanInterceptor()->eraseMemBuffer(hMem)); + } else { + UR_CALL(pfnRelease(hMem)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemBufferPartition +__urdlllocal ur_result_t UR_APICALL urMemBufferPartition( + ur_mem_handle_t + hBuffer, ///< [in] handle of the buffer object to allocate from + ur_mem_flags_t flags, ///< [in] allocation and usage information flags + ur_buffer_create_type_t bufferCreateType, ///< [in] buffer creation type + const ur_buffer_region_t + *pRegion, ///< [in] pointer to buffer create region information + ur_mem_handle_t + *phMem ///< [out] pointer to the handle of sub buffer created +) { + auto pfnBufferPartition = getContext()->urDdiTable.Mem.pfnBufferPartition; + + if (nullptr == pfnBufferPartition) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urMemBufferPartition"); + + if (auto ParentBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + if (ParentBuffer->Size < (pRegion->origin + pRegion->size)) { + return UR_RESULT_ERROR_INVALID_BUFFER_SIZE; + } + std::shared_ptr SubBuffer = std::make_shared( + ParentBuffer, pRegion->origin, pRegion->size); + UR_CALL(getAsanInterceptor()->insertMemBuffer(SubBuffer)); + *phMem = reinterpret_cast(SubBuffer.get()); + } else { + UR_CALL(pfnBufferPartition(hBuffer, flags, bufferCreateType, pRegion, + phMem)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urMemGetNativeHandle +__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( + ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t hDevice, + ur_native_handle_t + *phNativeMem ///< [out] a pointer to the native handle of the mem. 
+) { + auto pfnGetNativeHandle = getContext()->urDdiTable.Mem.pfnGetNativeHandle; + + if (nullptr == pfnGetNativeHandle) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urMemGetNativeHandle"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { + char *Handle = nullptr; + UR_CALL(MemBuffer->getHandle(hDevice, Handle)); + *phNativeMem = ur_cast(Handle); + } else { + UR_CALL(pfnGetNativeHandle(hMem, hDevice, phNativeMem)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferRead +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) + size_t offset, ///< [in] offset in bytes in the buffer object + size_t size, ///< [in] size in bytes of data being read + void *pDst, ///< [in] pointer to host memory where data is to be read into + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. 
+) { + auto pfnMemBufferRead = getContext()->urDdiTable.Enqueue.pfnMemBufferRead; + + if (nullptr == pfnMemBufferRead) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferRead"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + ur_device_handle_t Device = GetDevice(hQueue); + char *pSrc = nullptr; + UR_CALL(MemBuffer->getHandle(Device, pSrc)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, blockingRead, pDst, pSrc + offset, size, + numEventsInWaitList, phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferRead(hQueue, hBuffer, blockingRead, offset, size, + pDst, numEventsInWaitList, phEventWaitList, + phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferWrite +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool + blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) + size_t offset, ///< [in] offset in bytes in the buffer object + size_t size, ///< [in] size in bytes of data being written + const void + *pSrc, ///< [in] pointer to host memory where data is to be written from + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. 
+) { + auto pfnMemBufferWrite = getContext()->urDdiTable.Enqueue.pfnMemBufferWrite; + + if (nullptr == pfnMemBufferWrite) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferWrite"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + ur_device_handle_t Device = GetDevice(hQueue); + char *pDst = nullptr; + UR_CALL(MemBuffer->getHandle(Device, pDst)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, blockingWrite, pDst + offset, pSrc, size, + numEventsInWaitList, phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size, + pSrc, numEventsInWaitList, phEventWaitList, + phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferReadRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object + bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer + ur_rect_offset_t hostOrigin, ///< [in] 3D offset in the host region + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth + size_t + bufferRowPitch, ///< [in] length of each row in bytes in the buffer object + size_t + bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being read + size_t + hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by + ///< dst + size_t + hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region + ///< pointed by dst + void *pDst, ///< [in] pointer to host memory where data is to be read into + uint32_t numEventsInWaitList, ///< [in] size of 
the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferReadRect = + getContext()->urDdiTable.Enqueue.pfnMemBufferReadRect; + + if (nullptr == pfnMemBufferReadRect) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferReadRect"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + char *SrcHandle = nullptr; + ur_device_handle_t Device = GetDevice(hQueue); + UR_CALL(MemBuffer->getHandle(Device, SrcHandle)); + + UR_CALL(EnqueueMemCopyRectHelper( + hQueue, SrcHandle, ur_cast(pDst), bufferOrigin, hostOrigin, + region, bufferRowPitch, bufferSlicePitch, hostRowPitch, + hostSlicePitch, blockingRead, numEventsInWaitList, phEventWaitList, + phEvent)); + } else { + UR_CALL(pfnMemBufferReadRect( + hQueue, hBuffer, blockingRead, bufferOrigin, hostOrigin, region, + bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, + pDst, numEventsInWaitList, phEventWaitList, phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferWriteRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object + bool + blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) + ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer + ur_rect_offset_t hostOrigin, ///< 
[in] 3D offset in the host region + ur_rect_region_t + region, ///< [in] 3D rectangular region descriptor: width, height, depth + size_t + bufferRowPitch, ///< [in] length of each row in bytes in the buffer object + size_t + bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being + ///< written + size_t + hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by + ///< src + size_t + hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region + ///< pointed by src + void + *pSrc, ///< [in] pointer to host memory where data is to be written from + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] points to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. 
+) { + auto pfnMemBufferWriteRect = + getContext()->urDdiTable.Enqueue.pfnMemBufferWriteRect; + + if (nullptr == pfnMemBufferWriteRect) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferWriteRect"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + char *DstHandle = nullptr; + ur_device_handle_t Device = GetDevice(hQueue); + UR_CALL(MemBuffer->getHandle(Device, DstHandle)); + + UR_CALL(EnqueueMemCopyRectHelper( + hQueue, ur_cast(pSrc), DstHandle, hostOrigin, bufferOrigin, + region, hostRowPitch, hostSlicePitch, bufferRowPitch, + bufferSlicePitch, blockingWrite, numEventsInWaitList, + phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferWriteRect( + hQueue, hBuffer, blockingWrite, bufferOrigin, hostOrigin, region, + bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, + pSrc, numEventsInWaitList, phEventWaitList, phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferCopy +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBufferSrc, ///< [in][bounds(srcOffset, size)] handle of the src buffer object + ur_mem_handle_t + hBufferDst, ///< [in][bounds(dstOffset, size)] handle of the dest buffer object + size_t srcOffset, ///< [in] offset into hBufferSrc to begin copying from + size_t dstOffset, ///< [in] offset into hBufferDst to begin copying into + size_t size, ///< [in] size in bytes of data being copied + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. 
+ ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferCopy = getContext()->urDdiTable.Enqueue.pfnMemBufferCopy; + + if (nullptr == pfnMemBufferCopy) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferCopy"); + + auto SrcBuffer = getAsanInterceptor()->getMemBuffer(hBufferSrc); + auto DstBuffer = getAsanInterceptor()->getMemBuffer(hBufferDst); + + UR_ASSERT((SrcBuffer && DstBuffer) || (!SrcBuffer && !DstBuffer), + UR_RESULT_ERROR_INVALID_MEM_OBJECT); + + if (SrcBuffer && DstBuffer) { + ur_device_handle_t Device = GetDevice(hQueue); + char *SrcHandle = nullptr; + UR_CALL(SrcBuffer->getHandle(Device, SrcHandle)); + + char *DstHandle = nullptr; + UR_CALL(DstBuffer->getHandle(Device, DstHandle)); + + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size, + numEventsInWaitList, phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset, + dstOffset, size, numEventsInWaitList, + phEventWaitList, phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferCopyRect +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBufferSrc, ///< [in][bounds(srcOrigin, region)] handle of the source buffer object + ur_mem_handle_t + hBufferDst, ///< [in][bounds(dstOrigin, region)] handle of the dest buffer object + ur_rect_offset_t srcOrigin, ///< [in] 3D offset in the source buffer + ur_rect_offset_t dstOrigin, ///< [in] 3D offset in the destination buffer + 
ur_rect_region_t + region, ///< [in] source 3D rectangular region descriptor: width, height, depth + size_t + srcRowPitch, ///< [in] length of each row in bytes in the source buffer object + size_t + srcSlicePitch, ///< [in] length of each 2D slice in bytes in the source buffer object + size_t + dstRowPitch, ///< [in] length of each row in bytes in the destination buffer object + size_t + dstSlicePitch, ///< [in] length of each 2D slice in bytes in the destination buffer object + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. 
+) { + auto pfnMemBufferCopyRect = + getContext()->urDdiTable.Enqueue.pfnMemBufferCopyRect; + + if (nullptr == pfnMemBufferCopyRect) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferCopyRect"); + + auto SrcBuffer = getAsanInterceptor()->getMemBuffer(hBufferSrc); + auto DstBuffer = getAsanInterceptor()->getMemBuffer(hBufferDst); + + UR_ASSERT((SrcBuffer && DstBuffer) || (!SrcBuffer && !DstBuffer), + UR_RESULT_ERROR_INVALID_MEM_OBJECT); + + if (SrcBuffer && DstBuffer) { + ur_device_handle_t Device = GetDevice(hQueue); + char *SrcHandle = nullptr; + UR_CALL(SrcBuffer->getHandle(Device, SrcHandle)); + + char *DstHandle = nullptr; + UR_CALL(DstBuffer->getHandle(Device, DstHandle)); + + UR_CALL(EnqueueMemCopyRectHelper( + hQueue, SrcHandle, DstHandle, srcOrigin, dstOrigin, region, + srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, false, + numEventsInWaitList, phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferCopyRect( + hQueue, hBufferSrc, hBufferDst, srcOrigin, dstOrigin, region, + srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, + numEventsInWaitList, phEventWaitList, phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferFill +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + const void *pPattern, ///< [in] pointer to the fill pattern + size_t patternSize, ///< [in] size in bytes of the pattern + size_t offset, ///< [in] offset into the buffer + size_t size, ///< [in] fill size in bytes, must be a multiple of patternSize + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to 
a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. +) { + auto pfnMemBufferFill = getContext()->urDdiTable.Enqueue.pfnMemBufferFill; + + if (nullptr == pfnMemBufferFill) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferFill"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + char *Handle = nullptr; + ur_device_handle_t Device = GetDevice(hQueue); + UR_CALL(MemBuffer->getHandle(Device, Handle)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( + hQueue, Handle + offset, patternSize, pPattern, size, + numEventsInWaitList, phEventWaitList, phEvent)); + } else { + UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset, + size, numEventsInWaitList, phEventWaitList, + phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemBufferMap +__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object + bool blockingMap, ///< [in] indicates blocking (true), non-blocking (false) + ur_map_flags_t mapFlags, ///< [in] flags for read, write, readwrite mapping + size_t offset, ///< [in] offset in bytes of the buffer region being mapped + size_t size, ///< [in] size in bytes of the buffer region being mapped + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be 
complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent, ///< [out][optional] return an event object that identifies this particular + ///< command instance. + void **ppRetMap ///< [out] return mapped pointer. TODO: move it before + ///< numEventsInWaitList? +) { + auto pfnMemBufferMap = getContext()->urDdiTable.Enqueue.pfnMemBufferMap; + + if (nullptr == pfnMemBufferMap) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemBufferMap"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hBuffer)) { + + // Translate the host access mode info. + MemBuffer::AccessMode AccessMode = MemBuffer::UNKNOWN; + if (mapFlags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION) { + AccessMode = MemBuffer::WRITE_ONLY; + } else { + if (mapFlags & UR_MAP_FLAG_READ) { + AccessMode = MemBuffer::READ_ONLY; + if (mapFlags & UR_MAP_FLAG_WRITE) { + AccessMode = MemBuffer::READ_WRITE; + } + } else if (mapFlags & UR_MAP_FLAG_WRITE) { + AccessMode = MemBuffer::WRITE_ONLY; + } + } + + UR_ASSERT(AccessMode != MemBuffer::UNKNOWN, + UR_RESULT_ERROR_INVALID_ARGUMENT); + + ur_device_handle_t Device = GetDevice(hQueue); + // If the buffer used host pointer, then we just reuse it. If not, we + // need to manually allocate a new host USM. + if (MemBuffer->HostPtr) { + *ppRetMap = MemBuffer->HostPtr + offset; + } else { + ur_context_handle_t Context = GetContext(hQueue); + ur_usm_desc_t USMDesc{}; + USMDesc.align = MemBuffer->getAlignment(); + ur_usm_pool_handle_t Pool{}; + UR_CALL(getAsanInterceptor()->allocateMemory( + Context, nullptr, &USMDesc, Pool, size, AllocType::HOST_USM, + ppRetMap)); + } + + // Actually, if the access mode is write only, we don't need to do this + // copy. However, in that case, we cannot generate an event for the user. So, + // we'll always do the copy here. 
+ char *SrcHandle = nullptr; + UR_CALL(MemBuffer->getHandle(Device, SrcHandle)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, blockingMap, *ppRetMap, SrcHandle + offset, size, + numEventsInWaitList, phEventWaitList, phEvent)); + + { + std::scoped_lock Guard(MemBuffer->Mutex); + UR_ASSERT(MemBuffer->Mappings.find(*ppRetMap) == + MemBuffer->Mappings.end(), + UR_RESULT_ERROR_INVALID_VALUE); + MemBuffer->Mappings[*ppRetMap] = {offset, size}; + } + } else { + UR_CALL(pfnMemBufferMap(hQueue, hBuffer, blockingMap, mapFlags, offset, + size, numEventsInWaitList, phEventWaitList, + phEvent, ppRetMap)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urEnqueueMemUnmap +__urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( + ur_queue_handle_t hQueue, ///< [in] handle of the queue object + ur_mem_handle_t + hMem, ///< [in] handle of the memory (buffer or image) object + void *pMappedPtr, ///< [in] mapped host address + uint32_t numEventsInWaitList, ///< [in] size of the event wait list + const ur_event_handle_t * + phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of + ///< events that must be complete before this command can be executed. + ///< If nullptr, the numEventsInWaitList must be 0, indicating that this + ///< command does not wait on any event to complete. + ur_event_handle_t * + phEvent ///< [out][optional] return an event object that identifies this particular + ///< command instance. 
+) { + auto pfnMemUnmap = getContext()->urDdiTable.Enqueue.pfnMemUnmap; + + if (nullptr == pfnMemUnmap) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urEnqueueMemUnmap"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hMem)) { + MemBuffer::Mapping Mapping{}; + { + std::scoped_lock Guard(MemBuffer->Mutex); + auto It = MemBuffer->Mappings.find(pMappedPtr); + UR_ASSERT(It != MemBuffer->Mappings.end(), + UR_RESULT_ERROR_INVALID_VALUE); + Mapping = It->second; + MemBuffer->Mappings.erase(It); + } + + // Write back mapping memory data to device and release mapping memory + // if we allocated a host USM. But for now, UR doesn't support event + // callbacks, so we can only do a blocking copy here. + char *DstHandle = nullptr; + ur_context_handle_t Context = GetContext(hQueue); + ur_device_handle_t Device = GetDevice(hQueue); + UR_CALL(MemBuffer->getHandle(Device, DstHandle)); + UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( + hQueue, true, DstHandle + Mapping.Offset, pMappedPtr, Mapping.Size, + numEventsInWaitList, phEventWaitList, phEvent)); + + if (!MemBuffer->HostPtr) { + UR_CALL(getAsanInterceptor()->releaseMemory(Context, pMappedPtr)); + } + } else { + UR_CALL(pfnMemUnmap(hQueue, hMem, pMappedPtr, numEventsInWaitList, + phEventWaitList, phEvent)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelCreate +__urdlllocal ur_result_t UR_APICALL urKernelCreate( + ur_program_handle_t hProgram, ///< [in] handle of the program instance + const char *pKernelName, ///< [in] pointer to null-terminated string. + ur_kernel_handle_t + *phKernel ///< [out] pointer to handle of kernel object created. 
+) { + auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate; + + if (nullptr == pfnCreate) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urKernelCreate"); + + UR_CALL(pfnCreate(hProgram, pKernelName, phKernel)); + UR_CALL(getAsanInterceptor()->insertKernel(*phKernel)); + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelRetain +__urdlllocal ur_result_t UR_APICALL urKernelRetain( + ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain +) { + auto pfnRetain = getContext()->urDdiTable.Kernel.pfnRetain; + + if (nullptr == pfnRetain) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urKernelRetain"); + + UR_CALL(pfnRetain(hKernel)); + + auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); + UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + KernelInfo->RefCount++; + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelRelease +__urdlllocal ur_result_t urKernelRelease( + ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release +) { + auto pfnRelease = getContext()->urDdiTable.Kernel.pfnRelease; + + if (nullptr == pfnRelease) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urKernelRelease"); + UR_CALL(pfnRelease(hKernel)); + + auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); + UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); + if (--KernelInfo->RefCount == 0) { + UR_CALL(getAsanInterceptor()->eraseKernel(hKernel)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgValue +__urdlllocal ur_result_t UR_APICALL urKernelSetArgValue( + 
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + size_t argSize, ///< [in] size of argument type + const ur_kernel_arg_value_properties_t + *pProperties, ///< [in][optional] pointer to value properties. + const void + *pArgValue ///< [in] argument value represented as matching arg type. +) { + auto pfnSetArgValue = getContext()->urDdiTable.Kernel.pfnSetArgValue; + + if (nullptr == pfnSetArgValue) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urKernelSetArgValue"); + + std::shared_ptr MemBuffer; + if (argSize == sizeof(ur_mem_handle_t) && + (MemBuffer = getAsanInterceptor()->getMemBuffer( + *ur_cast(pArgValue)))) { + auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo->Mutex); + KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + } else { + UR_CALL( + pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgMemObj +__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + const ur_kernel_arg_mem_obj_properties_t + *pProperties, ///< [in][optional] pointer to Memory object properties. + ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object. 
+) { + auto pfnSetArgMemObj = getContext()->urDdiTable.Kernel.pfnSetArgMemObj; + + if (nullptr == pfnSetArgMemObj) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug("==== urKernelSetArgMemObj"); + + if (auto MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue)) { + auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel); + std::scoped_lock Guard(KernelInfo->Mutex); + KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); + } else { + UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue)); + } + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgLocal +__urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + size_t + argSize, ///< [in] size of the local buffer to be allocated by the runtime + const ur_kernel_arg_local_properties_t + *pProperties ///< [in][optional] pointer to local buffer properties. 
+) { + auto pfnSetArgLocal = getContext()->urDdiTable.Kernel.pfnSetArgLocal; + + if (nullptr == pfnSetArgLocal) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug( + "==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex, + argSize); + + { + auto KI = getAsanInterceptor()->getKernelInfo(hKernel); + std::scoped_lock Guard(KI->Mutex); + // TODO: get local variable alignment + auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal( + argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY); + KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ}; + argSize = argSizeWithRZ; + } + + ur_result_t result = + pfnSetArgLocal(hKernel, argIndex, argSize, pProperties); + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Intercept function for urKernelSetArgPointer +__urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( + ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object + uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] + const ur_kernel_arg_pointer_properties_t + *pProperties, ///< [in][optional] pointer to USM pointer properties. + const void * + pArgValue ///< [in][optional] Pointer obtained by USM allocation or virtual memory + ///< mapping operation. If null then argument value is considered null. 
+) { + auto pfnSetArgPointer = getContext()->urDdiTable.Kernel.pfnSetArgPointer; + + if (nullptr == pfnSetArgPointer) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + getContext()->logger.debug( + "==== urKernelSetArgPointer (argIndex={}, pArgValue={})", argIndex, + pArgValue); + + if (getAsanInterceptor()->getOptions().DetectKernelArguments) { + auto KI = getAsanInterceptor()->getKernelInfo(hKernel); + std::scoped_lock Guard(KI->Mutex); + KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()}; + } + + ur_result_t result = + pfnSetArgPointer(hKernel, argIndex, pProperties, pArgValue); + + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Global table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetGlobalProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_global_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnAdapterGet = ur_sanitizer_layer::asan::urAdapterGet; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Context table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t 
UR_APICALL urGetContextProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_context_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urContextCreate; + pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urContextRetain; + pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urContextRelease; + + pDdiTable->pfnCreateWithNativeHandle = + ur_sanitizer_layer::asan::urContextCreateWithNativeHandle; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Program table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetProgramProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_program_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + pDdiTable->pfnCreateWithIL = + ur_sanitizer_layer::asan::urProgramCreateWithIL; + pDdiTable->pfnCreateWithBinary = + ur_sanitizer_layer::asan::urProgramCreateWithBinary; + pDdiTable->pfnCreateWithNativeHandle = + 
ur_sanitizer_layer::asan::urProgramCreateWithNativeHandle; + pDdiTable->pfnBuild = ur_sanitizer_layer::asan::urProgramBuild; + pDdiTable->pfnLink = ur_sanitizer_layer::asan::urProgramLink; + pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urProgramRetain; + pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urProgramRelease; + + return UR_RESULT_SUCCESS; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Kernel table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate; + pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain; + pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease; + pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue; + pDdiTable->pfnSetArgMemObj = ur_sanitizer_layer::asan::urKernelSetArgMemObj; + pDdiTable->pfnSetArgLocal = ur_sanitizer_layer::asan::urKernelSetArgLocal; + pDdiTable->pfnSetArgPointer = + ur_sanitizer_layer::asan::urKernelSetArgPointer; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Mem table +/// with current process' 
addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetMemProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_mem_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnBufferCreate = ur_sanitizer_layer::asan::urMemBufferCreate; + pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urMemRetain; + pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urMemRelease; + pDdiTable->pfnBufferPartition = + ur_sanitizer_layer::asan::urMemBufferPartition; + pDdiTable->pfnGetNativeHandle = + ur_sanitizer_layer::asan::urMemGetNativeHandle; + pDdiTable->pfnGetInfo = ur_sanitizer_layer::asan::urMemGetInfo; + + return result; +} +/// @brief Exported function for filling application's ProgramExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_program_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + 
} + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnBuildExp = ur_sanitizer_layer::asan::urProgramBuildExp; + pDdiTable->pfnLinkExp = ur_sanitizer_layer::asan::urProgramLinkExp; + + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's Enqueue table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetEnqueueProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnMemBufferRead = + ur_sanitizer_layer::asan::urEnqueueMemBufferRead; + pDdiTable->pfnMemBufferWrite = + ur_sanitizer_layer::asan::urEnqueueMemBufferWrite; + pDdiTable->pfnMemBufferReadRect = + ur_sanitizer_layer::asan::urEnqueueMemBufferReadRect; + pDdiTable->pfnMemBufferWriteRect = + ur_sanitizer_layer::asan::urEnqueueMemBufferWriteRect; + pDdiTable->pfnMemBufferCopy = + ur_sanitizer_layer::asan::urEnqueueMemBufferCopy; + pDdiTable->pfnMemBufferCopyRect = + ur_sanitizer_layer::asan::urEnqueueMemBufferCopyRect; + pDdiTable->pfnMemBufferFill = + ur_sanitizer_layer::asan::urEnqueueMemBufferFill; + pDdiTable->pfnMemBufferMap = + ur_sanitizer_layer::asan::urEnqueueMemBufferMap; + pDdiTable->pfnMemUnmap = ur_sanitizer_layer::asan::urEnqueueMemUnmap; + pDdiTable->pfnKernelLaunch = + ur_sanitizer_layer::asan::urEnqueueKernelLaunch; + + return 
result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's USM table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetUSMProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_usm_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::asan::urUSMDeviceAlloc; + pDdiTable->pfnHostAlloc = ur_sanitizer_layer::asan::urUSMHostAlloc; + pDdiTable->pfnSharedAlloc = ur_sanitizer_layer::asan::urUSMSharedAlloc; + pDdiTable->pfnFree = ur_sanitizer_layer::asan::urUSMFree; + + return result; +} + +} // namespace asan + +ur_result_t asan_ddi_init(ur_dditable_t *dditable) { + ur_result_t result = UR_RESULT_SUCCESS; + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetGlobalProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Global); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetContextProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Context); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetKernelProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Kernel); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetProgramProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Program); + } + + if (UR_RESULT_SUCCESS == result) { + 
result = ur_sanitizer_layer::asan::urGetKernelProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Kernel); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetMemProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Mem); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetProgramExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->ProgramExp); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetEnqueueProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->Enqueue); + } + + if (UR_RESULT_SUCCESS == result) { + result = ur_sanitizer_layer::asan::urGetUSMProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->USM); + } + + return result; +} + +} // namespace ur_sanitizer_layer diff --git a/source/loader/layers/sanitizer/asan/asan_ddi.hpp b/source/loader/layers/sanitizer/asan/asan_ddi.hpp new file mode 100644 index 0000000000..9b1ef66f42 --- /dev/null +++ b/source/loader/layers/sanitizer/asan/asan_ddi.hpp @@ -0,0 +1,22 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
+ * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file asan_ddi.hpp + * + */ + +#include "ur_ddi.h" + +namespace ur_sanitizer_layer { + +void initAsanInterceptor(); +void destroyAsanInterceptor(); + +ur_result_t asan_ddi_init(ur_dditable_t *dditable); + +} // namespace ur_sanitizer_layer diff --git a/source/loader/layers/sanitizer/asan_interceptor.cpp b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp similarity index 90% rename from source/loader/layers/sanitizer/asan_interceptor.cpp rename to source/loader/layers/sanitizer/asan/asan_interceptor.cpp index 4a315588fd..ff36f4a0d4 100644 --- a/source/loader/layers/sanitizer/asan_interceptor.cpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp @@ -12,17 +12,18 @@ */ #include "asan_interceptor.hpp" +#include "asan_ddi.hpp" #include "asan_options.hpp" #include "asan_quarantine.hpp" #include "asan_report.hpp" #include "asan_shadow.hpp" #include "asan_validator.hpp" -#include "stacktrace.hpp" -#include "ur_sanitizer_utils.hpp" +#include "sanitizer_common/sanitizer_stacktrace.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" namespace ur_sanitizer_layer { -SanitizerInterceptor::SanitizerInterceptor() { +AsanInterceptor::AsanInterceptor() { if (getOptions().MaxQuarantineSizeMB) { m_Quarantine = std::make_unique( static_cast(getOptions().MaxQuarantineSizeMB) * 1024 * @@ -30,7 +31,7 @@ SanitizerInterceptor::SanitizerInterceptor() { } } -SanitizerInterceptor::~SanitizerInterceptor() { +AsanInterceptor::~AsanInterceptor() { // We must release these objects before releasing adapters, since // they may use the adapter in their destructor for (const auto &[_, DeviceInfo] : m_DeviceMap) { @@ -55,10 +56,12 @@ SanitizerInterceptor::~SanitizerInterceptor() { /// R -- right redzone (0 or more bytes) /// /// ref: "compiler-rt/lib/asan/asan_allocator.cpp" Allocator::Allocate -ur_result_t SanitizerInterceptor::allocateMemory( - ur_context_handle_t Context, ur_device_handle_t 
Device, - const ur_usm_desc_t *Properties, ur_usm_pool_handle_t Pool, size_t Size, - AllocType Type, void **ResultPtr) { +ur_result_t AsanInterceptor::allocateMemory(ur_context_handle_t Context, + ur_device_handle_t Device, + const ur_usm_desc_t *Properties, + ur_usm_pool_handle_t Pool, + size_t Size, AllocType Type, + void **ResultPtr) { auto ContextInfo = getContextInfo(Context); std::shared_ptr DeviceInfo = @@ -150,8 +153,8 @@ ur_result_t SanitizerInterceptor::allocateMemory( return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::releaseMemory(ur_context_handle_t Context, - void *Ptr) { +ur_result_t AsanInterceptor::releaseMemory(ur_context_handle_t Context, + void *Ptr) { auto ContextInfo = getContextInfo(Context); auto Addr = reinterpret_cast(Ptr); @@ -241,9 +244,9 @@ ur_result_t SanitizerInterceptor::releaseMemory(ur_context_handle_t Context, return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel, - ur_queue_handle_t Queue, - USMLaunchInfo &LaunchInfo) { +ur_result_t AsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel, + ur_queue_handle_t Queue, + USMLaunchInfo &LaunchInfo) { auto Context = GetContext(Queue); auto Device = GetDevice(Queue); auto ContextInfo = getContextInfo(Context); @@ -266,9 +269,9 @@ ur_result_t SanitizerInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel, return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::postLaunchKernel(ur_kernel_handle_t Kernel, - ur_queue_handle_t Queue, - USMLaunchInfo &LaunchInfo) { +ur_result_t AsanInterceptor::postLaunchKernel(ur_kernel_handle_t Kernel, + ur_queue_handle_t Queue, + USMLaunchInfo &LaunchInfo) { // FIXME: We must use block operation here, until we support urEventSetCallback auto Result = getContext()->urDdiTable.Queue.pfnFinish(Queue); @@ -316,9 +319,9 @@ ur_result_t DeviceInfo::allocShadowMemory(ur_context_handle_t Context) { /// /// ref: https://github.com/google/sanitizers/wiki/AddressSanitizerAlgorithm#mapping 
ur_result_t -SanitizerInterceptor::enqueueAllocInfo(std::shared_ptr &DeviceInfo, - ur_queue_handle_t Queue, - std::shared_ptr &AI) { +AsanInterceptor::enqueueAllocInfo(std::shared_ptr &DeviceInfo, + ur_queue_handle_t Queue, + std::shared_ptr &AI) { if (AI->IsReleased) { int ShadowByte; switch (AI->Type) { @@ -391,9 +394,10 @@ SanitizerInterceptor::enqueueAllocInfo(std::shared_ptr &DeviceInfo, return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::updateShadowMemory( - std::shared_ptr &ContextInfo, - std::shared_ptr &DeviceInfo, ur_queue_handle_t Queue) { +ur_result_t +AsanInterceptor::updateShadowMemory(std::shared_ptr &ContextInfo, + std::shared_ptr &DeviceInfo, + ur_queue_handle_t Queue) { auto &AllocInfos = ContextInfo->AllocInfosMap[DeviceInfo->Handle]; std::scoped_lock Guard(AllocInfos.Mutex); @@ -405,8 +409,8 @@ ur_result_t SanitizerInterceptor::updateShadowMemory( return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::registerProgram(ur_context_handle_t Context, - ur_program_handle_t Program) { +ur_result_t AsanInterceptor::registerProgram(ur_context_handle_t Context, + ur_program_handle_t Program) { std::vector Devices = GetDevices(Program); auto ContextInfo = getContextInfo(Context); @@ -464,8 +468,7 @@ ur_result_t SanitizerInterceptor::registerProgram(ur_context_handle_t Context, return UR_RESULT_SUCCESS; } -ur_result_t -SanitizerInterceptor::unregisterProgram(ur_program_handle_t Program) { +ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) { auto ProgramInfo = getProgramInfo(Program); std::scoped_lock Guard( @@ -479,9 +482,8 @@ SanitizerInterceptor::unregisterProgram(ur_program_handle_t Program) { return UR_RESULT_SUCCESS; } -ur_result_t -SanitizerInterceptor::insertContext(ur_context_handle_t Context, - std::shared_ptr &CI) { +ur_result_t AsanInterceptor::insertContext(ur_context_handle_t Context, + std::shared_ptr &CI) { std::scoped_lock Guard(m_ContextMapMutex); if (m_ContextMap.find(Context) != 
m_ContextMap.end()) { @@ -497,7 +499,7 @@ SanitizerInterceptor::insertContext(ur_context_handle_t Context, return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::eraseContext(ur_context_handle_t Context) { +ur_result_t AsanInterceptor::eraseContext(ur_context_handle_t Context) { std::scoped_lock Guard(m_ContextMapMutex); assert(m_ContextMap.find(Context) != m_ContextMap.end()); m_ContextMap.erase(Context); @@ -505,9 +507,8 @@ ur_result_t SanitizerInterceptor::eraseContext(ur_context_handle_t Context) { return UR_RESULT_SUCCESS; } -ur_result_t -SanitizerInterceptor::insertDevice(ur_device_handle_t Device, - std::shared_ptr &DI) { +ur_result_t AsanInterceptor::insertDevice(ur_device_handle_t Device, + std::shared_ptr &DI) { std::scoped_lock Guard(m_DeviceMapMutex); if (m_DeviceMap.find(Device) != m_DeviceMap.end()) { @@ -531,7 +532,7 @@ SanitizerInterceptor::insertDevice(ur_device_handle_t Device, return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::eraseDevice(ur_device_handle_t Device) { +ur_result_t AsanInterceptor::eraseDevice(ur_device_handle_t Device) { std::scoped_lock Guard(m_DeviceMapMutex); assert(m_DeviceMap.find(Device) != m_DeviceMap.end()); m_DeviceMap.erase(Device); @@ -539,7 +540,7 @@ ur_result_t SanitizerInterceptor::eraseDevice(ur_device_handle_t Device) { return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::insertProgram(ur_program_handle_t Program) { +ur_result_t AsanInterceptor::insertProgram(ur_program_handle_t Program) { std::scoped_lock Guard(m_ProgramMapMutex); if (m_ProgramMap.find(Program) != m_ProgramMap.end()) { return UR_RESULT_SUCCESS; @@ -548,14 +549,14 @@ ur_result_t SanitizerInterceptor::insertProgram(ur_program_handle_t Program) { return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::eraseProgram(ur_program_handle_t Program) { +ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) { std::scoped_lock Guard(m_ProgramMapMutex); assert(m_ProgramMap.find(Program) != 
m_ProgramMap.end()); m_ProgramMap.erase(Program); return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::insertKernel(ur_kernel_handle_t Kernel) { +ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) { std::scoped_lock Guard(m_KernelMapMutex); if (m_KernelMap.find(Kernel) != m_KernelMap.end()) { return UR_RESULT_SUCCESS; @@ -564,7 +565,7 @@ ur_result_t SanitizerInterceptor::insertKernel(ur_kernel_handle_t Kernel) { return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::eraseKernel(ur_kernel_handle_t Kernel) { +ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) { std::scoped_lock Guard(m_KernelMapMutex); assert(m_KernelMap.find(Kernel) != m_KernelMap.end()); m_KernelMap.erase(Kernel); @@ -572,7 +573,7 @@ ur_result_t SanitizerInterceptor::eraseKernel(ur_kernel_handle_t Kernel) { } ur_result_t -SanitizerInterceptor::insertMemBuffer(std::shared_ptr MemBuffer) { +AsanInterceptor::insertMemBuffer(std::shared_ptr MemBuffer) { std::scoped_lock Guard(m_MemBufferMapMutex); assert(m_MemBufferMap.find(ur_cast(MemBuffer.get())) == m_MemBufferMap.end()); @@ -581,7 +582,7 @@ SanitizerInterceptor::insertMemBuffer(std::shared_ptr MemBuffer) { return UR_RESULT_SUCCESS; } -ur_result_t SanitizerInterceptor::eraseMemBuffer(ur_mem_handle_t MemHandle) { +ur_result_t AsanInterceptor::eraseMemBuffer(ur_mem_handle_t MemHandle) { std::scoped_lock Guard(m_MemBufferMapMutex); assert(m_MemBufferMap.find(MemHandle) != m_MemBufferMap.end()); m_MemBufferMap.erase(MemHandle); @@ -589,7 +590,7 @@ ur_result_t SanitizerInterceptor::eraseMemBuffer(ur_mem_handle_t MemHandle) { } std::shared_ptr -SanitizerInterceptor::getMemBuffer(ur_mem_handle_t MemHandle) { +AsanInterceptor::getMemBuffer(ur_mem_handle_t MemHandle) { std::shared_lock Guard(m_MemBufferMapMutex); if (m_MemBufferMap.find(MemHandle) != m_MemBufferMap.end()) { return m_MemBufferMap[MemHandle]; @@ -597,7 +598,7 @@ SanitizerInterceptor::getMemBuffer(ur_mem_handle_t MemHandle) { return nullptr; 
} -ur_result_t SanitizerInterceptor::prepareLaunch( +ur_result_t AsanInterceptor::prepareLaunch( std::shared_ptr &ContextInfo, std::shared_ptr &DeviceInfo, ur_queue_handle_t Queue, ur_kernel_handle_t Kernel, USMLaunchInfo &LaunchInfo) { @@ -803,7 +804,7 @@ ur_result_t SanitizerInterceptor::prepareLaunch( } std::optional -SanitizerInterceptor::findAllocInfoByAddress(uptr Address) { +AsanInterceptor::findAllocInfoByAddress(uptr Address) { std::shared_lock Guard(m_AllocationMapMutex); auto It = m_AllocationMap.upper_bound(Address); if (It == m_AllocationMap.begin()) { @@ -818,7 +819,7 @@ SanitizerInterceptor::findAllocInfoByAddress(uptr Address) { } std::vector -SanitizerInterceptor::findAllocInfoByContext(ur_context_handle_t Context) { +AsanInterceptor::findAllocInfoByContext(ur_context_handle_t Context) { std::shared_lock Guard(m_AllocationMapMutex); std::vector AllocInfos; for (auto It = m_AllocationMap.begin(); It != m_AllocationMap.end(); It++) { @@ -839,7 +840,7 @@ ContextInfo::~ContextInfo() { // check memory leaks std::vector AllocInfos = - getContext()->interceptor->findAllocInfoByContext(Handle); + getAsanInterceptor()->findAllocInfoByContext(Handle); for (const auto &It : AllocInfos) { const auto &[_, AI] = *It; if (!AI->IsReleased) { @@ -879,7 +880,7 @@ USMLaunchInfo::~USMLaunchInfo() { [[maybe_unused]] ur_result_t Result; if (Data) { auto Type = GetDeviceType(Context, Device); - auto ContextInfo = getContext()->interceptor->getContextInfo(Context); + auto ContextInfo = getAsanInterceptor()->getContextInfo(Context); if (Type == DeviceType::GPU_PVC || Type == DeviceType::GPU_DG2) { if (Data->PrivateShadowOffset) { ContextInfo->Stats.UpdateShadowFreed( @@ -911,4 +912,20 @@ USMLaunchInfo::~USMLaunchInfo() { assert(Result == UR_RESULT_SUCCESS); } +static AsanInterceptor *interceptor; + +AsanInterceptor *getAsanInterceptor() { return interceptor; } + +void initAsanInterceptor() { + if (interceptor) { + return; + } + interceptor = new AsanInterceptor(); +} + 
+void destroyAsanInterceptor() { + delete interceptor; + interceptor = nullptr; +} + } // namespace ur_sanitizer_layer diff --git a/source/loader/layers/sanitizer/asan_interceptor.hpp b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp similarity index 98% rename from source/loader/layers/sanitizer/asan_interceptor.hpp rename to source/loader/layers/sanitizer/asan/asan_interceptor.hpp index e5429acd56..6ce88a06fe 100644 --- a/source/loader/layers/sanitizer/asan_interceptor.hpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp @@ -18,7 +18,7 @@ #include "asan_options.hpp" #include "asan_shadow.hpp" #include "asan_statistics.hpp" -#include "common.hpp" +#include "sanitizer_common/sanitizer_common.hpp" #include "ur_sanitizer_layer.hpp" #include @@ -186,11 +186,11 @@ struct DeviceGlobalInfo { uptr Addr; }; -class SanitizerInterceptor { +class AsanInterceptor { public: - explicit SanitizerInterceptor(); + explicit AsanInterceptor(); - ~SanitizerInterceptor(); + ~AsanInterceptor(); ur_result_t allocateMemory(ur_context_handle_t Context, ur_device_handle_t Device, @@ -322,4 +322,6 @@ class SanitizerInterceptor { ur_shared_mutex m_AdaptersMutex; }; +AsanInterceptor *getAsanInterceptor(); + } // namespace ur_sanitizer_layer diff --git a/source/loader/layers/sanitizer/asan_libdevice.hpp b/source/loader/layers/sanitizer/asan/asan_libdevice.hpp similarity index 96% rename from source/loader/layers/sanitizer/asan_libdevice.hpp rename to source/loader/layers/sanitizer/asan/asan_libdevice.hpp index 8eba929f34..2290263142 100644 --- a/source/loader/layers/sanitizer/asan_libdevice.hpp +++ b/source/loader/layers/sanitizer/asan/asan_libdevice.hpp @@ -1,23 +1,21 @@ /* * - * Copyright (C) 2023 Intel Corporation + * Copyright (C) 2024 Intel Corporation * * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
* See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file device_sanitizer_report.hpp + * @file asan_libdevice.hpp * */ #pragma once -#include +#include "sanitizer_common/sanitizer_libdevice.hpp" namespace ur_sanitizer_layer { -enum class DeviceType : uint32_t { UNKNOWN = 0, CPU, GPU_PVC, GPU_DG2 }; - enum class DeviceSanitizerErrorType : int32_t { UNKNOWN, OUT_OF_BOUNDS, diff --git a/source/loader/layers/sanitizer/asan_options.cpp b/source/loader/layers/sanitizer/asan/asan_options.cpp similarity index 99% rename from source/loader/layers/sanitizer/asan_options.cpp rename to source/loader/layers/sanitizer/asan/asan_options.cpp index 5c42ab8fca..df443b5100 100644 --- a/source/loader/layers/sanitizer/asan_options.cpp +++ b/source/loader/layers/sanitizer/asan/asan_options.cpp @@ -11,6 +11,8 @@ */ #include "asan_options.hpp" + +#include "ur/ur.hpp" #include "ur_sanitizer_layer.hpp" #include diff --git a/source/loader/layers/sanitizer/asan_options.hpp b/source/loader/layers/sanitizer/asan/asan_options.hpp similarity index 96% rename from source/loader/layers/sanitizer/asan_options.hpp rename to source/loader/layers/sanitizer/asan/asan_options.hpp index 4c515e28fe..a084148e9a 100644 --- a/source/loader/layers/sanitizer/asan_options.hpp +++ b/source/loader/layers/sanitizer/asan/asan_options.hpp @@ -12,7 +12,7 @@ #pragma once -#include "common.hpp" +#include namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/asan_quarantine.cpp b/source/loader/layers/sanitizer/asan/asan_quarantine.cpp similarity index 100% rename from source/loader/layers/sanitizer/asan_quarantine.cpp rename to source/loader/layers/sanitizer/asan/asan_quarantine.cpp diff --git a/source/loader/layers/sanitizer/asan_quarantine.hpp b/source/loader/layers/sanitizer/asan/asan_quarantine.hpp similarity index 100% rename from source/loader/layers/sanitizer/asan_quarantine.hpp rename to source/loader/layers/sanitizer/asan/asan_quarantine.hpp diff --git 
a/source/loader/layers/sanitizer/asan_report.cpp b/source/loader/layers/sanitizer/asan/asan_report.cpp similarity index 97% rename from source/loader/layers/sanitizer/asan_report.cpp rename to source/loader/layers/sanitizer/asan/asan_report.cpp index c1a1230e78..1016982c2b 100644 --- a/source/loader/layers/sanitizer/asan_report.cpp +++ b/source/loader/layers/sanitizer/asan/asan_report.cpp @@ -16,8 +16,8 @@ #include "asan_libdevice.hpp" #include "asan_options.hpp" #include "asan_validator.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" #include "ur_sanitizer_layer.hpp" -#include "ur_sanitizer_utils.hpp" namespace ur_sanitizer_layer { @@ -137,9 +137,9 @@ void ReportUseAfterFree(const DeviceSanitizerReport &Report, getContext()->logger.always(" #0 {} {}:{}", Func, File, Report.Line); getContext()->logger.always(""); - if (getContext()->interceptor->getOptions().MaxQuarantineSizeMB > 0) { + if (getAsanInterceptor()->getOptions().MaxQuarantineSizeMB > 0) { auto AllocInfoItOp = - getContext()->interceptor->findAllocInfoByAddress(Report.Address); + getAsanInterceptor()->findAllocInfoByAddress(Report.Address); if (!AllocInfoItOp) { getContext()->logger.always( diff --git a/source/loader/layers/sanitizer/asan_report.hpp b/source/loader/layers/sanitizer/asan/asan_report.hpp similarity index 96% rename from source/loader/layers/sanitizer/asan_report.hpp rename to source/loader/layers/sanitizer/asan/asan_report.hpp index e679b30c5d..a34f211d8f 100644 --- a/source/loader/layers/sanitizer/asan_report.hpp +++ b/source/loader/layers/sanitizer/asan/asan_report.hpp @@ -12,7 +12,7 @@ #pragma once -#include "common.hpp" +#include "sanitizer_common/sanitizer_common.hpp" #include diff --git a/source/loader/layers/sanitizer/asan_shadow.cpp b/source/loader/layers/sanitizer/asan/asan_shadow.cpp similarity index 98% rename from source/loader/layers/sanitizer/asan_shadow.cpp rename to source/loader/layers/sanitizer/asan/asan_shadow.cpp index 1f3ae18986..c5f402b58c 100644 --- 
a/source/loader/layers/sanitizer/asan_shadow.cpp +++ b/source/loader/layers/sanitizer/asan/asan_shadow.cpp @@ -13,8 +13,8 @@ #include "asan_shadow.hpp" #include "asan_interceptor.hpp" #include "asan_libdevice.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" #include "ur_sanitizer_layer.hpp" -#include "ur_sanitizer_utils.hpp" namespace ur_sanitizer_layer { @@ -201,7 +201,7 @@ ur_result_t ShadowMemoryGPU::EnqueuePoisonShadow(ur_queue_handle_t Queue, } auto AllocInfoIt = - getContext()->interceptor->findAllocInfoByAddress(Ptr); + getAsanInterceptor()->findAllocInfoByAddress(Ptr); assert(AllocInfoIt); VirtualMemMaps[MappedPtr].second.insert((*AllocInfoIt)->second); } diff --git a/source/loader/layers/sanitizer/asan_shadow.hpp b/source/loader/layers/sanitizer/asan/asan_shadow.hpp similarity index 97% rename from source/loader/layers/sanitizer/asan_shadow.hpp rename to source/loader/layers/sanitizer/asan/asan_shadow.hpp index 7ae095062a..b88e757299 100644 --- a/source/loader/layers/sanitizer/asan_shadow.hpp +++ b/source/loader/layers/sanitizer/asan/asan_shadow.hpp @@ -12,8 +12,9 @@ #pragma once -#include "asan_allocator.hpp" -#include "common.hpp" +#include "asan/asan_allocator.hpp" +#include "sanitizer_common/sanitizer_libdevice.hpp" + #include namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/asan_statistics.cpp b/source/loader/layers/sanitizer/asan/asan_statistics.cpp similarity index 96% rename from source/loader/layers/sanitizer/asan_statistics.cpp rename to source/loader/layers/sanitizer/asan/asan_statistics.cpp index 82eef69c44..7f6a184ea6 100644 --- a/source/loader/layers/sanitizer/asan_statistics.cpp +++ b/source/loader/layers/sanitizer/asan/asan_statistics.cpp @@ -66,7 +66,7 @@ void AsanStats::UpdateUSMFreed(uptr FreedSize) { void AsanStats::UpdateUSMRealFreed(uptr FreedSize, uptr RedzoneSize) { UsmMalloced -= FreedSize; UsmMallocedRedzones -= RedzoneSize; - if (getContext()->interceptor->getOptions().MaxQuarantineSizeMB) { + if 
(getAsanInterceptor()->getOptions().MaxQuarantineSizeMB) { UsmFreed -= FreedSize; } getContext()->logger.debug( @@ -136,7 +136,7 @@ void AsanStatsWrapper::Print(ur_context_handle_t Context) { } AsanStatsWrapper::AsanStatsWrapper() : Stat(nullptr) { - if (getContext()->interceptor->getOptions().PrintStats) { + if (getAsanInterceptor()->getOptions().PrintStats) { Stat = new AsanStats; } } diff --git a/source/loader/layers/sanitizer/asan_statistics.hpp b/source/loader/layers/sanitizer/asan/asan_statistics.hpp similarity index 94% rename from source/loader/layers/sanitizer/asan_statistics.hpp rename to source/loader/layers/sanitizer/asan/asan_statistics.hpp index fab30e28c0..d229d6e334 100644 --- a/source/loader/layers/sanitizer/asan_statistics.hpp +++ b/source/loader/layers/sanitizer/asan/asan_statistics.hpp @@ -12,7 +12,7 @@ #pragma once -#include "common.hpp" +#include "sanitizer_common/sanitizer_common.hpp" namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/asan_validator.cpp b/source/loader/layers/sanitizer/asan/asan_validator.cpp similarity index 92% rename from source/loader/layers/sanitizer/asan_validator.cpp rename to source/loader/layers/sanitizer/asan/asan_validator.cpp index a9f2bd2b17..ec004e3eac 100644 --- a/source/loader/layers/sanitizer/asan_validator.cpp +++ b/source/loader/layers/sanitizer/asan/asan_validator.cpp @@ -12,7 +12,7 @@ #include "asan_validator.hpp" #include "asan_interceptor.hpp" -#include "ur_sanitizer_utils.hpp" +#include "sanitizer_common/sanitizer_utils.hpp" namespace ur_sanitizer_layer { @@ -38,9 +38,9 @@ ValidateUSMResult ValidateUSMPointer(ur_context_handle_t Context, ur_device_handle_t Device, uptr Ptr) { assert(Ptr != 0 && "Don't validate nullptr here"); - auto AllocInfoItOp = getContext()->interceptor->findAllocInfoByAddress(Ptr); + auto AllocInfoItOp = getAsanInterceptor()->findAllocInfoByAddress(Ptr); if (!AllocInfoItOp) { - auto DI = getContext()->interceptor->getDeviceInfo(Device); + auto DI = 
getAsanInterceptor()->getDeviceInfo(Device); bool IsSupportSharedSystemUSM = DI->IsSupportSharedSystemUSM; if (IsSupportSharedSystemUSM) { // maybe it's host pointer diff --git a/source/loader/layers/sanitizer/asan_validator.hpp b/source/loader/layers/sanitizer/asan/asan_validator.hpp similarity index 100% rename from source/loader/layers/sanitizer/asan_validator.hpp rename to source/loader/layers/sanitizer/asan/asan_validator.hpp diff --git a/source/loader/layers/sanitizer/linux/backtrace.cpp b/source/loader/layers/sanitizer/sanitizer_common/linux/backtrace.cpp similarity index 94% rename from source/loader/layers/sanitizer/linux/backtrace.cpp rename to source/loader/layers/sanitizer/sanitizer_common/linux/backtrace.cpp index b746348205..4f76d83f69 100644 --- a/source/loader/layers/sanitizer/linux/backtrace.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/linux/backtrace.cpp @@ -7,7 +7,7 @@ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * */ -#include "stacktrace.hpp" +#include "sanitizer_common/sanitizer_stacktrace.hpp" #include #include diff --git a/source/loader/layers/sanitizer/linux/sanitizer_utils.cpp b/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp similarity index 97% rename from source/loader/layers/sanitizer/linux/sanitizer_utils.cpp rename to source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp index d0bc038174..6bae7dde17 100644 --- a/source/loader/layers/sanitizer/linux/sanitizer_utils.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/linux/sanitizer_utils.cpp @@ -11,7 +11,7 @@ * */ -#include "common.hpp" +#include "sanitizer_common/sanitizer_common.hpp" #include "ur_sanitizer_layer.hpp" #include diff --git a/source/loader/layers/sanitizer/linux/symbolizer.cpp b/source/loader/layers/sanitizer/sanitizer_common/linux/symbolizer.cpp similarity index 100% rename from source/loader/layers/sanitizer/linux/symbolizer.cpp rename to 
source/loader/layers/sanitizer/sanitizer_common/linux/symbolizer.cpp diff --git a/source/loader/layers/sanitizer/common.hpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_common.hpp similarity index 93% rename from source/loader/layers/sanitizer/common.hpp rename to source/loader/layers/sanitizer/sanitizer_common/sanitizer_common.hpp index ea5e33ed4b..3bbe368f04 100644 --- a/source/loader/layers/sanitizer/common.hpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_common.hpp @@ -12,7 +12,6 @@ #pragma once -#include "asan_libdevice.hpp" #include "ur/ur.hpp" #include "ur_ddi.h" @@ -138,21 +137,6 @@ struct SourceInfo { int column = 0; }; -inline const char *ToString(DeviceType Type) { - switch (Type) { - case DeviceType::UNKNOWN: - return "UNKNOWN"; - case DeviceType::CPU: - return "CPU"; - case DeviceType::GPU_PVC: - return "PVC"; - case DeviceType::GPU_DG2: - return "DG2"; - default: - return "UNKNOWN"; - } -} - bool IsInASanContext(); uptr MmapNoReserve(uptr Addr, uptr Size); diff --git a/source/loader/layers/sanitizer/sanitizer_common/sanitizer_libdevice.hpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_libdevice.hpp new file mode 100644 index 0000000000..07698de70d --- /dev/null +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_libdevice.hpp @@ -0,0 +1,36 @@ +/* + * + * Copyright (C) 2024 Intel Corporation + * + * Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
+ * See LICENSE.TXT + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * @file sanitizer_libdevice.hpp + * + */ + +#pragma once + +#include + +namespace ur_sanitizer_layer { + +enum class DeviceType : uint32_t { UNKNOWN = 0, CPU, GPU_PVC, GPU_DG2 }; + +inline const char *ToString(DeviceType Type) { + switch (Type) { + case DeviceType::UNKNOWN: + return "UNKNOWN"; + case DeviceType::CPU: + return "CPU"; + case DeviceType::GPU_PVC: + return "PVC"; + case DeviceType::GPU_DG2: + return "DG2"; + default: + return "UNKNOWN"; + } +} + +} // namespace ur_sanitizer_layer diff --git a/source/loader/layers/sanitizer/stacktrace.cpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.cpp similarity index 98% rename from source/loader/layers/sanitizer/stacktrace.cpp rename to source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.cpp index 8adaa2cd34..e8cd4b1089 100644 --- a/source/loader/layers/sanitizer/stacktrace.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.cpp @@ -6,11 +6,11 @@ * See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file stacktrace.cpp + * @file sanitizer_stacktrace.cpp * */ -#include "stacktrace.hpp" +#include "sanitizer_stacktrace.hpp" #include "ur_sanitizer_layer.hpp" extern "C" { diff --git a/source/loader/layers/sanitizer/stacktrace.hpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.hpp similarity index 89% rename from source/loader/layers/sanitizer/stacktrace.hpp rename to source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.hpp index 57811bba01..41443ee78d 100644 --- a/source/loader/layers/sanitizer/stacktrace.hpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_stacktrace.hpp @@ -6,13 +6,13 @@ * See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file stacktrace.hpp + * @file sanitizer_stacktrace.hpp * */ #pragma once -#include "common.hpp" +#include 
"sanitizer_common.hpp" #include diff --git a/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp similarity index 98% rename from source/loader/layers/sanitizer/ur_sanitizer_utils.cpp rename to source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp index 53e4326ed4..900eae405b 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_utils.cpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.cpp @@ -6,11 +6,12 @@ * See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file ur_sanitizer_utils.cpp + * @file sanitizer_utils.cpp * */ -#include "ur_sanitizer_utils.hpp" +#include "sanitizer_utils.hpp" +#include "sanitizer_common/sanitizer_common.hpp" #include "ur_sanitizer_layer.hpp" namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.hpp similarity index 95% rename from source/loader/layers/sanitizer/ur_sanitizer_utils.hpp rename to source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.hpp index a04886e5e5..6fcb05894e 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_utils.hpp +++ b/source/loader/layers/sanitizer/sanitizer_common/sanitizer_utils.hpp @@ -6,13 +6,17 @@ * See LICENSE.TXT * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception * - * @file ur_sanitizer_utils.hpp + * @file sanitizer_utils.hpp * */ #pragma once -#include "common.hpp" +#include "sanitizer_libdevice.hpp" +#include "ur_api.h" + +#include +#include namespace ur_sanitizer_layer { diff --git a/source/loader/layers/sanitizer/ur_sanddi.cpp b/source/loader/layers/sanitizer/ur_sanddi.cpp index 95b1649691..6f0aec4cd0 100644 --- a/source/loader/layers/sanitizer/ur_sanddi.cpp +++ b/source/loader/layers/sanitizer/ur_sanddi.cpp @@ -10,1807 +10,14 @@ * */ -#include "asan_interceptor.hpp" -#include "asan_options.hpp" -#include "stacktrace.hpp" 
+#include "asan/asan_ddi.hpp" +#include "asan/asan_interceptor.hpp" #include "ur_sanitizer_layer.hpp" -#include "ur_sanitizer_utils.hpp" #include namespace ur_sanitizer_layer { -namespace { - -ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices, - const ur_device_handle_t *phDevices) { - std::shared_ptr CI; - UR_CALL(getContext()->interceptor->insertContext(Context, CI)); - for (uint32_t i = 0; i < numDevices; ++i) { - auto hDevice = phDevices[i]; - std::shared_ptr DI; - UR_CALL(getContext()->interceptor->insertDevice(hDevice, DI)); - DI->Type = GetDeviceType(Context, hDevice); - if (DI->Type == DeviceType::UNKNOWN) { - getContext()->logger.error("Unsupport device"); - return UR_RESULT_ERROR_INVALID_DEVICE; - } - getContext()->logger.info( - "DeviceInfo {} (Type={}, IsSupportSharedSystemUSM={})", - (void *)DI->Handle, ToString(DI->Type), - DI->IsSupportSharedSystemUSM); - getContext()->logger.info("Add {} into context {}", (void *)DI->Handle, - (void *)Context); - if (!DI->Shadow) { - UR_CALL(DI->allocShadowMemory(Context)); - } - CI->DeviceList.emplace_back(hDevice); - CI->AllocInfosMap[hDevice]; - } - return UR_RESULT_SUCCESS; -} - -} // namespace - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urAdapterGet -__urdlllocal ur_result_t UR_APICALL urAdapterGet( - uint32_t - NumEntries, ///< [in] the number of adapters to be added to phAdapters. - ///< If phAdapters is not NULL, then NumEntries should be greater than - ///< zero, otherwise ::UR_RESULT_ERROR_INVALID_SIZE, - ///< will be returned. - ur_adapter_handle_t * - phAdapters, ///< [out][optional][range(0, NumEntries)] array of handle of adapters. - ///< If NumEntries is less than the number of adapters available, then - ///< ::urAdapterGet shall only retrieve that number of platforms. - uint32_t * - pNumAdapters ///< [out][optional] returns the total number of adapters available. 
-) { - auto pfnAdapterGet = getContext()->urDdiTable.Global.pfnAdapterGet; - - if (nullptr == pfnAdapterGet) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - ur_result_t result = pfnAdapterGet(NumEntries, phAdapters, pNumAdapters); - if (result == UR_RESULT_SUCCESS && phAdapters) { - const uint32_t NumAdapters = pNumAdapters ? *pNumAdapters : NumEntries; - for (uint32_t i = 0; i < NumAdapters; ++i) { - UR_CALL(getContext()->interceptor->holdAdapter(phAdapters[i])); - } - } - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urUSMHostAlloc -__urdlllocal ur_result_t UR_APICALL urUSMHostAlloc( - ur_context_handle_t hContext, ///< [in] handle of the context object - const ur_usm_desc_t - *pUSMDesc, ///< [in][optional] USM memory allocation descriptor - ur_usm_pool_handle_t - pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate - size_t - size, ///< [in] size in bytes of the USM memory object to be allocated - void **ppMem ///< [out] pointer to USM host memory object -) { - auto pfnHostAlloc = getContext()->urDdiTable.USM.pfnHostAlloc; - - if (nullptr == pfnHostAlloc) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urUSMHostAlloc"); - - return getContext()->interceptor->allocateMemory( - hContext, nullptr, pUSMDesc, pool, size, AllocType::HOST_USM, ppMem); -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urUSMDeviceAlloc -__urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc( - ur_context_handle_t hContext, ///< [in] handle of the context object - ur_device_handle_t hDevice, ///< [in] handle of the device object - const ur_usm_desc_t - *pUSMDesc, ///< [in][optional] USM memory allocation descriptor - ur_usm_pool_handle_t - pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate - size_t - size, ///< [in] size in bytes of the USM 
memory object to be allocated - void **ppMem ///< [out] pointer to USM device memory object -) { - auto pfnDeviceAlloc = getContext()->urDdiTable.USM.pfnDeviceAlloc; - - if (nullptr == pfnDeviceAlloc) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urUSMDeviceAlloc"); - - return getContext()->interceptor->allocateMemory( - hContext, hDevice, pUSMDesc, pool, size, AllocType::DEVICE_USM, ppMem); -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urUSMSharedAlloc -__urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc( - ur_context_handle_t hContext, ///< [in] handle of the context object - ur_device_handle_t hDevice, ///< [in] handle of the device object - const ur_usm_desc_t * - pUSMDesc, ///< [in][optional] Pointer to USM memory allocation descriptor. - ur_usm_pool_handle_t - pool, ///< [in][optional] Pointer to a pool created using urUSMPoolCreate - size_t - size, ///< [in] size in bytes of the USM memory object to be allocated - void **ppMem ///< [out] pointer to USM shared memory object -) { - auto pfnSharedAlloc = getContext()->urDdiTable.USM.pfnSharedAlloc; - - if (nullptr == pfnSharedAlloc) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urUSMSharedAlloc"); - - return getContext()->interceptor->allocateMemory( - hContext, hDevice, pUSMDesc, pool, size, AllocType::SHARED_USM, ppMem); -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urUSMFree -__urdlllocal ur_result_t UR_APICALL urUSMFree( - ur_context_handle_t hContext, ///< [in] handle of the context object - void *pMem ///< [in] pointer to USM memory object -) { - auto pfnFree = getContext()->urDdiTable.USM.pfnFree; - - if (nullptr == pfnFree) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urUSMFree"); - - return 
getContext()->interceptor->releaseMemory(hContext, pMem); -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramCreateWithIL -__urdlllocal ur_result_t UR_APICALL urProgramCreateWithIL( - ur_context_handle_t hContext, ///< [in] handle of the context instance - const void *pIL, ///< [in] pointer to IL binary. - size_t length, ///< [in] length of `pIL` in bytes. - const ur_program_properties_t * - pProperties, ///< [in][optional] pointer to program creation properties. - ur_program_handle_t - *phProgram ///< [out] pointer to handle of program object created. -) { - auto pfnProgramCreateWithIL = - getContext()->urDdiTable.Program.pfnCreateWithIL; - - if (nullptr == pfnProgramCreateWithIL) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramCreateWithIL"); - - UR_CALL( - pfnProgramCreateWithIL(hContext, pIL, length, pProperties, phProgram)); - UR_CALL(getContext()->interceptor->insertProgram(*phProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramCreateWithBinary -__urdlllocal ur_result_t UR_APICALL urProgramCreateWithBinary( - ur_context_handle_t hContext, ///< [in] handle of the context instance - uint32_t numDevices, ///< [in] number of devices - ur_device_handle_t * - phDevices, ///< [in][range(0, numDevices)] a pointer to a list of device handles. The - ///< binaries are loaded for devices specified in this list. - size_t * - pLengths, ///< [in][range(0, numDevices)] array of sizes of program binaries - ///< specified by `pBinaries` (in bytes). - const uint8_t ** - ppBinaries, ///< [in][range(0, numDevices)] pointer to program binaries to be loaded - ///< for devices specified by `phDevices`. - const ur_program_properties_t * - pProperties, ///< [in][optional] pointer to program creation properties. 
- ur_program_handle_t - *phProgram ///< [out] pointer to handle of Program object created. -) { - auto pfnProgramCreateWithBinary = - getContext()->urDdiTable.Program.pfnCreateWithBinary; - - if (nullptr == pfnProgramCreateWithBinary) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramCreateWithBinary"); - - UR_CALL(pfnProgramCreateWithBinary(hContext, numDevices, phDevices, - pLengths, ppBinaries, pProperties, - phProgram)); - UR_CALL(getContext()->interceptor->insertProgram(*phProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramCreateWithNativeHandle -__urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( - ur_native_handle_t - hNativeProgram, ///< [in][nocheck] the native handle of the program. - ur_context_handle_t hContext, ///< [in] handle of the context instance - const ur_program_native_properties_t * - pProperties, ///< [in][optional] pointer to native program properties struct. - ur_program_handle_t * - phProgram ///< [out] pointer to the handle of the program object created. 
-) { - auto pfnProgramCreateWithNativeHandle = - getContext()->urDdiTable.Program.pfnCreateWithNativeHandle; - - if (nullptr == pfnProgramCreateWithNativeHandle) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramCreateWithNativeHandle"); - - UR_CALL(pfnProgramCreateWithNativeHandle(hNativeProgram, hContext, - pProperties, phProgram)); - UR_CALL(getContext()->interceptor->insertProgram(*phProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramRetain -__urdlllocal ur_result_t UR_APICALL urProgramRetain( - ur_program_handle_t - hProgram ///< [in][retain] handle for the Program to retain -) { - auto pfnRetain = getContext()->urDdiTable.Program.pfnRetain; - - if (nullptr == pfnRetain) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramRetain"); - - UR_CALL(pfnRetain(hProgram)); - - auto ProgramInfo = getContext()->interceptor->getProgramInfo(hProgram); - UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ProgramInfo->RefCount++; - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramBuild -__urdlllocal ur_result_t UR_APICALL urProgramBuild( - ur_context_handle_t hContext, ///< [in] handle of the context object - ur_program_handle_t hProgram, ///< [in] handle of the program object - const char *pOptions ///< [in] string of build options -) { - auto pfnProgramBuild = getContext()->urDdiTable.Program.pfnBuild; - - if (nullptr == pfnProgramBuild) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramBuild"); - - UR_CALL(pfnProgramBuild(hContext, hProgram, pOptions)); - - UR_CALL(getContext()->interceptor->registerProgram(hContext, hProgram)); - - return UR_RESULT_SUCCESS; -} - 
-/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramBuildExp -__urdlllocal ur_result_t UR_APICALL urProgramBuildExp( - ur_program_handle_t hProgram, ///< [in] Handle of the program to build. - uint32_t numDevices, ///< [in] number of devices - ur_device_handle_t * - phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles - const char * - pOptions ///< [in][optional] pointer to build options null-terminated string. -) { - auto pfnBuildExp = getContext()->urDdiTable.ProgramExp.pfnBuildExp; - - if (nullptr == pfnBuildExp) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramBuildExp"); - - UR_CALL(pfnBuildExp(hProgram, numDevices, phDevices, pOptions)); - UR_CALL(getContext()->interceptor->registerProgram(GetContext(hProgram), - hProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramLink -__urdlllocal ur_result_t UR_APICALL urProgramLink( - ur_context_handle_t hContext, ///< [in] handle of the context instance. - uint32_t count, ///< [in] number of program handles in `phPrograms`. - const ur_program_handle_t * - phPrograms, ///< [in][range(0, count)] pointer to array of program handles. - const char * - pOptions, ///< [in][optional] pointer to linker options null-terminated string. - ur_program_handle_t - *phProgram ///< [out] pointer to handle of program object created. 
-) { - auto pfnProgramLink = getContext()->urDdiTable.Program.pfnLink; - - if (nullptr == pfnProgramLink) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramLink"); - - UR_CALL(pfnProgramLink(hContext, count, phPrograms, pOptions, phProgram)); - - UR_CALL(getContext()->interceptor->registerProgram(hContext, *phProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramLinkExp -ur_result_t UR_APICALL urProgramLinkExp( - ur_context_handle_t hContext, ///< [in] handle of the context instance. - uint32_t numDevices, ///< [in] number of devices - ur_device_handle_t * - phDevices, ///< [in][range(0, numDevices)] pointer to array of device handles - uint32_t count, ///< [in] number of program handles in `phPrograms`. - const ur_program_handle_t * - phPrograms, ///< [in][range(0, count)] pointer to array of program handles. - const char * - pOptions, ///< [in][optional] pointer to linker options null-terminated string. - ur_program_handle_t - *phProgram ///< [out] pointer to handle of program object created. 
-) { - auto pfnProgramLinkExp = getContext()->urDdiTable.ProgramExp.pfnLinkExp; - - if (nullptr == pfnProgramLinkExp) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramLinkExp"); - - UR_CALL(pfnProgramLinkExp(hContext, numDevices, phDevices, count, - phPrograms, pOptions, phProgram)); - - UR_CALL(getContext()->interceptor->registerProgram(hContext, *phProgram)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urProgramRelease -ur_result_t UR_APICALL urProgramRelease( - ur_program_handle_t - hProgram ///< [in][release] handle for the Program to release -) { - auto pfnProgramRelease = getContext()->urDdiTable.Program.pfnRelease; - - if (nullptr == pfnProgramRelease) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urProgramRelease"); - - UR_CALL(pfnProgramRelease(hProgram)); - - auto ProgramInfo = getContext()->interceptor->getProgramInfo(hProgram); - UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ProgramInfo->RefCount == 0) { - UR_CALL(getContext()->interceptor->unregisterProgram(hProgram)); - UR_CALL(getContext()->interceptor->eraseProgram(hProgram)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueKernelLaunch -__urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object - uint32_t - workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and - ///< work-group work-items - const size_t * - pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the - ///< offset used to calculate the global ID of a work-item - const size_t * - pGlobalWorkSize, ///< [in] 
pointer to an array of workDim unsigned values that specify the - ///< number of global work-items in workDim that will execute the kernel - ///< function - const size_t * - pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that - ///< specify the number of local work-items forming a work-group that will - ///< execute the kernel function. - ///< If nullptr, the runtime implementation will choose the work-group - ///< size. - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before the kernel execution. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that no wait - ///< event. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< kernel execution instance. -) { - auto pfnKernelLaunch = getContext()->urDdiTable.Enqueue.pfnKernelLaunch; - - if (nullptr == pfnKernelLaunch) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueKernelLaunch"); - - USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), - pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset, - workDim); - UR_CALL(LaunchInfo.initialize()); - - UR_CALL(getContext()->interceptor->preLaunchKernel(hKernel, hQueue, - LaunchInfo)); - - ur_event_handle_t hEvent{}; - ur_result_t result = - pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset, - pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(), - numEventsInWaitList, phEventWaitList, &hEvent); - - if (result == UR_RESULT_SUCCESS) { - UR_CALL(getContext()->interceptor->postLaunchKernel(hKernel, hQueue, - LaunchInfo)); - } - - if (phEvent) { - *phEvent = hEvent; - } - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urContextCreate 
-__urdlllocal ur_result_t UR_APICALL urContextCreate( - uint32_t numDevices, ///< [in] the number of devices given in phDevices - const ur_device_handle_t - *phDevices, ///< [in][range(0, numDevices)] array of handle of devices. - const ur_context_properties_t * - pProperties, ///< [in][optional] pointer to context creation properties. - ur_context_handle_t - *phContext ///< [out] pointer to handle of context object created -) { - auto pfnCreate = getContext()->urDdiTable.Context.pfnCreate; - - if (nullptr == pfnCreate) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urContextCreate"); - - ur_result_t result = - pfnCreate(numDevices, phDevices, pProperties, phContext); - - if (result == UR_RESULT_SUCCESS) { - UR_CALL(setupContext(*phContext, numDevices, phDevices)); - } - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urContextCreateWithNativeHandle -__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( - ur_native_handle_t - hNativeContext, ///< [in][nocheck] the native handle of the getContext()-> - ur_adapter_handle_t hAdapter, - uint32_t numDevices, ///< [in] number of devices associated with the context - const ur_device_handle_t * - phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context - const ur_context_native_properties_t * - pProperties, ///< [in][optional] pointer to native context properties struct - ur_context_handle_t * - phContext ///< [out] pointer to the handle of the context object created. 
-) { - auto pfnCreateWithNativeHandle = - getContext()->urDdiTable.Context.pfnCreateWithNativeHandle; - - if (nullptr == pfnCreateWithNativeHandle) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urContextCreateWithNativeHandle"); - - ur_result_t result = - pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices, - phDevices, pProperties, phContext); - - if (result == UR_RESULT_SUCCESS) { - UR_CALL(setupContext(*phContext, numDevices, phDevices)); - } - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urContextRetain -__urdlllocal ur_result_t UR_APICALL urContextRetain( - ur_context_handle_t - hContext ///< [in] handle of the context to get a reference of. -) { - auto pfnRetain = getContext()->urDdiTable.Context.pfnRetain; - - if (nullptr == pfnRetain) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urContextRetain"); - - UR_CALL(pfnRetain(hContext)); - - auto ContextInfo = getContext()->interceptor->getContextInfo(hContext); - UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ContextInfo->RefCount++; - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urContextRelease -__urdlllocal ur_result_t UR_APICALL urContextRelease( - ur_context_handle_t hContext ///< [in] handle of the context to release. 
-) { - auto pfnRelease = getContext()->urDdiTable.Context.pfnRelease; - - if (nullptr == pfnRelease) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urContextRelease"); - - UR_CALL(pfnRelease(hContext)); - - auto ContextInfo = getContext()->interceptor->getContextInfo(hContext); - UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ContextInfo->RefCount == 0) { - UR_CALL(getContext()->interceptor->eraseContext(hContext)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemBufferCreate -__urdlllocal ur_result_t UR_APICALL urMemBufferCreate( - ur_context_handle_t hContext, ///< [in] handle of the context object - ur_mem_flags_t flags, ///< [in] allocation and usage information flags - size_t size, ///< [in] size in bytes of the memory object to be allocated - const ur_buffer_properties_t - *pProperties, ///< [in][optional] pointer to buffer creation properties - ur_mem_handle_t - *phBuffer ///< [out] pointer to handle of the memory buffer created -) { - auto pfnBufferCreate = getContext()->urDdiTable.Mem.pfnBufferCreate; - - if (nullptr == pfnBufferCreate) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - if (nullptr == phBuffer) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - getContext()->logger.debug("==== urMemBufferCreate"); - - void *Host = nullptr; - if (pProperties) { - Host = pProperties->pHost; - } - - char *hostPtrOrNull = (flags & UR_MEM_FLAG_USE_HOST_POINTER) - ? 
ur_cast(Host) - : nullptr; - - std::shared_ptr pMemBuffer = - std::make_shared(hContext, size, hostPtrOrNull); - - if (Host && (flags & UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER)) { - std::shared_ptr CtxInfo = - getContext()->interceptor->getContextInfo(hContext); - for (const auto &hDevice : CtxInfo->DeviceList) { - ManagedQueue InternalQueue(hContext, hDevice); - char *Handle = nullptr; - UR_CALL(pMemBuffer->getHandle(hDevice, Handle)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr)); - } - } - - ur_result_t result = getContext()->interceptor->insertMemBuffer(pMemBuffer); - *phBuffer = ur_cast(pMemBuffer.get()); - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemGetInfo -__urdlllocal ur_result_t UR_APICALL urMemGetInfo( - ur_mem_handle_t - hMemory, ///< [in] handle to the memory object being queried. - ur_mem_info_t propName, ///< [in] type of the info to retrieve. - size_t - propSize, ///< [in] the number of bytes of memory pointed to by pPropValue. - void * - pPropValue, ///< [out][optional][typename(propName, propSize)] array of bytes holding - ///< the info. - ///< If propSize is less than the real number of bytes needed to return - ///< the info then the ::UR_RESULT_ERROR_INVALID_SIZE error is returned and - ///< pPropValue is not used. - size_t * - pPropSizeRet ///< [out][optional] pointer to the actual size in bytes of the queried propName. 
-) { - auto pfnGetInfo = getContext()->urDdiTable.Mem.pfnGetInfo; - - if (nullptr == pfnGetInfo) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urMemGetInfo"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hMemory)) { - UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet); - switch (propName) { - case UR_MEM_INFO_CONTEXT: { - return ReturnValue(MemBuffer->Context); - } - case UR_MEM_INFO_SIZE: { - return ReturnValue(size_t{MemBuffer->Size}); - } - default: { - return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; - } - } - } else { - UR_CALL( - pfnGetInfo(hMemory, propName, propSize, pPropValue, pPropSizeRet)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemRetain -__urdlllocal ur_result_t UR_APICALL urMemRetain( - ur_mem_handle_t hMem ///< [in] handle of the memory object to get access -) { - auto pfnRetain = getContext()->urDdiTable.Mem.pfnRetain; - - if (nullptr == pfnRetain) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urMemRetain"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hMem)) { - MemBuffer->RefCount++; - } else { - UR_CALL(pfnRetain(hMem)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemRelease -__urdlllocal ur_result_t UR_APICALL urMemRelease( - ur_mem_handle_t hMem ///< [in] handle of the memory object to release -) { - auto pfnRelease = getContext()->urDdiTable.Mem.pfnRelease; - - if (nullptr == pfnRelease) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urMemRelease"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hMem)) { - if (--MemBuffer->RefCount != 0) { - return UR_RESULT_SUCCESS; - } - UR_CALL(MemBuffer->free()); - 
UR_CALL(getContext()->interceptor->eraseMemBuffer(hMem)); - } else { - UR_CALL(pfnRelease(hMem)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemBufferPartition -__urdlllocal ur_result_t UR_APICALL urMemBufferPartition( - ur_mem_handle_t - hBuffer, ///< [in] handle of the buffer object to allocate from - ur_mem_flags_t flags, ///< [in] allocation and usage information flags - ur_buffer_create_type_t bufferCreateType, ///< [in] buffer creation type - const ur_buffer_region_t - *pRegion, ///< [in] pointer to buffer create region information - ur_mem_handle_t - *phMem ///< [out] pointer to the handle of sub buffer created -) { - auto pfnBufferPartition = getContext()->urDdiTable.Mem.pfnBufferPartition; - - if (nullptr == pfnBufferPartition) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urMemBufferPartition"); - - if (auto ParentBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - if (ParentBuffer->Size < (pRegion->origin + pRegion->size)) { - return UR_RESULT_ERROR_INVALID_BUFFER_SIZE; - } - std::shared_ptr SubBuffer = std::make_shared( - ParentBuffer, pRegion->origin, pRegion->size); - UR_CALL(getContext()->interceptor->insertMemBuffer(SubBuffer)); - *phMem = reinterpret_cast(SubBuffer.get()); - } else { - UR_CALL(pfnBufferPartition(hBuffer, flags, bufferCreateType, pRegion, - phMem)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urMemGetNativeHandle -__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( - ur_mem_handle_t hMem, ///< [in] handle of the mem. - ur_device_handle_t hDevice, - ur_native_handle_t - *phNativeMem ///< [out] a pointer to the native handle of the mem. 
-) { - auto pfnGetNativeHandle = getContext()->urDdiTable.Mem.pfnGetNativeHandle; - - if (nullptr == pfnGetNativeHandle) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urMemGetNativeHandle"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hMem)) { - char *Handle = nullptr; - UR_CALL(MemBuffer->getHandle(hDevice, Handle)); - *phNativeMem = ur_cast(Handle); - } else { - UR_CALL(pfnGetNativeHandle(hMem, hDevice, phNativeMem)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferRead -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferRead( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object - bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) - size_t offset, ///< [in] offset in bytes in the buffer object - size_t size, ///< [in] size in bytes of data being read - void *pDst, ///< [in] pointer to host memory where data is to be read into - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. 
-) { - auto pfnMemBufferRead = getContext()->urDdiTable.Enqueue.pfnMemBufferRead; - - if (nullptr == pfnMemBufferRead) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferRead"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - ur_device_handle_t Device = GetDevice(hQueue); - char *pSrc = nullptr; - UR_CALL(MemBuffer->getHandle(Device, pSrc)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - hQueue, blockingRead, pDst, pSrc + offset, size, - numEventsInWaitList, phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferRead(hQueue, hBuffer, blockingRead, offset, size, - pDst, numEventsInWaitList, phEventWaitList, - phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferWrite -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWrite( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object - bool - blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) - size_t offset, ///< [in] offset in bytes in the buffer object - size_t size, ///< [in] size in bytes of data being written - const void - *pSrc, ///< [in] pointer to host memory where data is to be written from - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. 
-) { - auto pfnMemBufferWrite = getContext()->urDdiTable.Enqueue.pfnMemBufferWrite; - - if (nullptr == pfnMemBufferWrite) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferWrite"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - ur_device_handle_t Device = GetDevice(hQueue); - char *pDst = nullptr; - UR_CALL(MemBuffer->getHandle(Device, pDst)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - hQueue, blockingWrite, pDst + offset, pSrc, size, - numEventsInWaitList, phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size, - pSrc, numEventsInWaitList, phEventWaitList, - phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferReadRect -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferReadRect( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object - bool blockingRead, ///< [in] indicates blocking (true), non-blocking (false) - ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer - ur_rect_offset_t hostOrigin, ///< [in] 3D offset in the host region - ur_rect_region_t - region, ///< [in] 3D rectangular region descriptor: width, height, depth - size_t - bufferRowPitch, ///< [in] length of each row in bytes in the buffer object - size_t - bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being read - size_t - hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by - ///< dst - size_t - hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region - ///< pointed by dst - void *pDst, ///< [in] pointer to host memory where data is to be read into - uint32_t numEventsInWaitList, ///< [in] 
size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. -) { - auto pfnMemBufferReadRect = - getContext()->urDdiTable.Enqueue.pfnMemBufferReadRect; - - if (nullptr == pfnMemBufferReadRect) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferReadRect"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - char *SrcHandle = nullptr; - ur_device_handle_t Device = GetDevice(hQueue); - UR_CALL(MemBuffer->getHandle(Device, SrcHandle)); - - UR_CALL(EnqueueMemCopyRectHelper( - hQueue, SrcHandle, ur_cast(pDst), bufferOrigin, hostOrigin, - region, bufferRowPitch, bufferSlicePitch, hostRowPitch, - hostSlicePitch, blockingRead, numEventsInWaitList, phEventWaitList, - phEvent)); - } else { - UR_CALL(pfnMemBufferReadRect( - hQueue, hBuffer, blockingRead, bufferOrigin, hostOrigin, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pDst, numEventsInWaitList, phEventWaitList, phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferWriteRect -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferWriteRect( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(bufferOrigin, region)] handle of the buffer object - bool - blockingWrite, ///< [in] indicates blocking (true), non-blocking (false) - ur_rect_offset_t bufferOrigin, ///< [in] 3D offset in the buffer - ur_rect_offset_t 
hostOrigin, ///< [in] 3D offset in the host region - ur_rect_region_t - region, ///< [in] 3D rectangular region descriptor: width, height, depth - size_t - bufferRowPitch, ///< [in] length of each row in bytes in the buffer object - size_t - bufferSlicePitch, ///< [in] length of each 2D slice in bytes in the buffer object being - ///< written - size_t - hostRowPitch, ///< [in] length of each row in bytes in the host memory region pointed by - ///< src - size_t - hostSlicePitch, ///< [in] length of each 2D slice in bytes in the host memory region - ///< pointed by src - void - *pSrc, ///< [in] pointer to host memory where data is to be written from - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] points to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. 
-) { - auto pfnMemBufferWriteRect = - getContext()->urDdiTable.Enqueue.pfnMemBufferWriteRect; - - if (nullptr == pfnMemBufferWriteRect) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferWriteRect"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - char *DstHandle = nullptr; - ur_device_handle_t Device = GetDevice(hQueue); - UR_CALL(MemBuffer->getHandle(Device, DstHandle)); - - UR_CALL(EnqueueMemCopyRectHelper( - hQueue, ur_cast(pSrc), DstHandle, hostOrigin, bufferOrigin, - region, hostRowPitch, hostSlicePitch, bufferRowPitch, - bufferSlicePitch, blockingWrite, numEventsInWaitList, - phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferWriteRect( - hQueue, hBuffer, blockingWrite, bufferOrigin, hostOrigin, region, - bufferRowPitch, bufferSlicePitch, hostRowPitch, hostSlicePitch, - pSrc, numEventsInWaitList, phEventWaitList, phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferCopy -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopy( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBufferSrc, ///< [in][bounds(srcOffset, size)] handle of the src buffer object - ur_mem_handle_t - hBufferDst, ///< [in][bounds(dstOffset, size)] handle of the dest buffer object - size_t srcOffset, ///< [in] offset into hBufferSrc to begin copying from - size_t dstOffset, ///< [in] offset info hBufferDst to begin copying into - size_t size, ///< [in] size in bytes of data being copied - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. 
- ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. -) { - auto pfnMemBufferCopy = getContext()->urDdiTable.Enqueue.pfnMemBufferCopy; - - if (nullptr == pfnMemBufferCopy) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferCopy"); - - auto SrcBuffer = getContext()->interceptor->getMemBuffer(hBufferSrc); - auto DstBuffer = getContext()->interceptor->getMemBuffer(hBufferDst); - - UR_ASSERT((SrcBuffer && DstBuffer) || (!SrcBuffer && !DstBuffer), - UR_RESULT_ERROR_INVALID_MEM_OBJECT); - - if (SrcBuffer && DstBuffer) { - ur_device_handle_t Device = GetDevice(hQueue); - char *SrcHandle = nullptr; - UR_CALL(SrcBuffer->getHandle(Device, SrcHandle)); - - char *DstHandle = nullptr; - UR_CALL(DstBuffer->getHandle(Device, DstHandle)); - - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size, - numEventsInWaitList, phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset, - dstOffset, size, numEventsInWaitList, - phEventWaitList, phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferCopyRect -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferCopyRect( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBufferSrc, ///< [in][bounds(srcOrigin, region)] handle of the source buffer object - ur_mem_handle_t - hBufferDst, ///< [in][bounds(dstOrigin, region)] handle of the dest buffer object - ur_rect_offset_t srcOrigin, ///< [in] 3D offset in the source buffer - ur_rect_offset_t dstOrigin, ///< [in] 3D offset in the destination buffer - 
ur_rect_region_t - region, ///< [in] source 3D rectangular region descriptor: width, height, depth - size_t - srcRowPitch, ///< [in] length of each row in bytes in the source buffer object - size_t - srcSlicePitch, ///< [in] length of each 2D slice in bytes in the source buffer object - size_t - dstRowPitch, ///< [in] length of each row in bytes in the destination buffer object - size_t - dstSlicePitch, ///< [in] length of each 2D slice in bytes in the destination buffer object - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. 
-) { - auto pfnMemBufferCopyRect = - getContext()->urDdiTable.Enqueue.pfnMemBufferCopyRect; - - if (nullptr == pfnMemBufferCopyRect) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferCopyRect"); - - auto SrcBuffer = getContext()->interceptor->getMemBuffer(hBufferSrc); - auto DstBuffer = getContext()->interceptor->getMemBuffer(hBufferDst); - - UR_ASSERT((SrcBuffer && DstBuffer) || (!SrcBuffer && !DstBuffer), - UR_RESULT_ERROR_INVALID_MEM_OBJECT); - - if (SrcBuffer && DstBuffer) { - ur_device_handle_t Device = GetDevice(hQueue); - char *SrcHandle = nullptr; - UR_CALL(SrcBuffer->getHandle(Device, SrcHandle)); - - char *DstHandle = nullptr; - UR_CALL(DstBuffer->getHandle(Device, DstHandle)); - - UR_CALL(EnqueueMemCopyRectHelper( - hQueue, SrcHandle, DstHandle, srcOrigin, dstOrigin, region, - srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, false, - numEventsInWaitList, phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferCopyRect( - hQueue, hBufferSrc, hBufferDst, srcOrigin, dstOrigin, region, - srcRowPitch, srcSlicePitch, dstRowPitch, dstSlicePitch, - numEventsInWaitList, phEventWaitList, phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferFill -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferFill( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object - const void *pPattern, ///< [in] pointer to the fill pattern - size_t patternSize, ///< [in] size in bytes of the pattern - size_t offset, ///< [in] offset into the buffer - size_t size, ///< [in] fill size in bytes, must be a multiple of patternSize - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] 
pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. -) { - auto pfnMemBufferFill = getContext()->urDdiTable.Enqueue.pfnMemBufferFill; - - if (nullptr == pfnMemBufferFill) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferFill"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - char *Handle = nullptr; - ur_device_handle_t Device = GetDevice(hQueue); - UR_CALL(MemBuffer->getHandle(Device, Handle)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill( - hQueue, Handle + offset, patternSize, pPattern, size, - numEventsInWaitList, phEventWaitList, phEvent)); - } else { - UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset, - size, numEventsInWaitList, phEventWaitList, - phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemBufferMap -__urdlllocal ur_result_t UR_APICALL urEnqueueMemBufferMap( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hBuffer, ///< [in][bounds(offset, size)] handle of the buffer object - bool blockingMap, ///< [in] indicates blocking (true), non-blocking (false) - ur_map_flags_t mapFlags, ///< [in] flags for read, write, readwrite mapping - size_t offset, ///< [in] offset in bytes of the buffer region being mapped - size_t size, ///< [in] size in bytes of the buffer region being mapped - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events 
that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent, ///< [out][optional] return an event object that identifies this particular - ///< command instance. - void **ppRetMap ///< [out] return mapped pointer. TODO: move it before - ///< numEventsInWaitList? -) { - auto pfnMemBufferMap = getContext()->urDdiTable.Enqueue.pfnMemBufferMap; - - if (nullptr == pfnMemBufferMap) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemBufferMap"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hBuffer)) { - - // Translate the host access mode info. - MemBuffer::AccessMode AccessMode = MemBuffer::UNKNOWN; - if (mapFlags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION) { - AccessMode = MemBuffer::WRITE_ONLY; - } else { - if (mapFlags & UR_MAP_FLAG_READ) { - AccessMode = MemBuffer::READ_ONLY; - if (mapFlags & UR_MAP_FLAG_WRITE) { - AccessMode = MemBuffer::READ_WRITE; - } - } else if (mapFlags & UR_MAP_FLAG_WRITE) { - AccessMode = MemBuffer::WRITE_ONLY; - } - } - - UR_ASSERT(AccessMode != MemBuffer::UNKNOWN, - UR_RESULT_ERROR_INVALID_ARGUMENT); - - ur_device_handle_t Device = GetDevice(hQueue); - // If the buffer used host pointer, then we just reuse it. If not, we - // need to manually allocate a new host USM. - if (MemBuffer->HostPtr) { - *ppRetMap = MemBuffer->HostPtr + offset; - } else { - ur_context_handle_t Context = GetContext(hQueue); - ur_usm_desc_t USMDesc{}; - USMDesc.align = MemBuffer->getAlignment(); - ur_usm_pool_handle_t Pool{}; - UR_CALL(getContext()->interceptor->allocateMemory( - Context, nullptr, &USMDesc, Pool, size, AllocType::HOST_USM, - ppRetMap)); - } - - // Actually, if the access mode is write only, we don't need to do this - // copy. However, in that way, we cannot generate a event to user. So, - // we'll aways do copy here. 
- char *SrcHandle = nullptr; - UR_CALL(MemBuffer->getHandle(Device, SrcHandle)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - hQueue, blockingMap, *ppRetMap, SrcHandle + offset, size, - numEventsInWaitList, phEventWaitList, phEvent)); - - { - std::scoped_lock Guard(MemBuffer->Mutex); - UR_ASSERT(MemBuffer->Mappings.find(*ppRetMap) == - MemBuffer->Mappings.end(), - UR_RESULT_ERROR_INVALID_VALUE); - MemBuffer->Mappings[*ppRetMap] = {offset, size}; - } - } else { - UR_CALL(pfnMemBufferMap(hQueue, hBuffer, blockingMap, mapFlags, offset, - size, numEventsInWaitList, phEventWaitList, - phEvent, ppRetMap)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueMemUnmap -__urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap( - ur_queue_handle_t hQueue, ///< [in] handle of the queue object - ur_mem_handle_t - hMem, ///< [in] handle of the memory (buffer or image) object - void *pMappedPtr, ///< [in] mapped host address - uint32_t numEventsInWaitList, ///< [in] size of the event wait list - const ur_event_handle_t * - phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of - ///< events that must be complete before this command can be executed. - ///< If nullptr, the numEventsInWaitList must be 0, indicating that this - ///< command does not wait on any event to complete. - ur_event_handle_t * - phEvent ///< [out][optional] return an event object that identifies this particular - ///< command instance. 
-) { - auto pfnMemUnmap = getContext()->urDdiTable.Enqueue.pfnMemUnmap; - - if (nullptr == pfnMemUnmap) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urEnqueueMemUnmap"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hMem)) { - MemBuffer::Mapping Mapping{}; - { - std::scoped_lock Guard(MemBuffer->Mutex); - auto It = MemBuffer->Mappings.find(pMappedPtr); - UR_ASSERT(It != MemBuffer->Mappings.end(), - UR_RESULT_ERROR_INVALID_VALUE); - Mapping = It->second; - MemBuffer->Mappings.erase(It); - } - - // Write back mapping memory data to device and release mapping memory - // if we allocated a host USM. But for now, UR doesn't support event - // call back, we can only do blocking copy here. - char *DstHandle = nullptr; - ur_context_handle_t Context = GetContext(hQueue); - ur_device_handle_t Device = GetDevice(hQueue); - UR_CALL(MemBuffer->getHandle(Device, DstHandle)); - UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy( - hQueue, true, DstHandle + Mapping.Offset, pMappedPtr, Mapping.Size, - numEventsInWaitList, phEventWaitList, phEvent)); - - if (!MemBuffer->HostPtr) { - UR_CALL( - getContext()->interceptor->releaseMemory(Context, pMappedPtr)); - } - } else { - UR_CALL(pfnMemUnmap(hQueue, hMem, pMappedPtr, numEventsInWaitList, - phEventWaitList, phEvent)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelCreate -__urdlllocal ur_result_t UR_APICALL urKernelCreate( - ur_program_handle_t hProgram, ///< [in] handle of the program instance - const char *pKernelName, ///< [in] pointer to null-terminated string. - ur_kernel_handle_t - *phKernel ///< [out] pointer to handle of kernel object created. 
-) { - auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate; - - if (nullptr == pfnCreate) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelCreate"); - - UR_CALL(pfnCreate(hProgram, pKernelName, phKernel)); - UR_CALL(getContext()->interceptor->insertKernel(*phKernel)); - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelRetain -__urdlllocal ur_result_t UR_APICALL urKernelRetain( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to retain -) { - auto pfnRetain = getContext()->urDdiTable.Kernel.pfnRetain; - - if (nullptr == pfnRetain) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelRetain"); - - UR_CALL(pfnRetain(hKernel)); - - auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel); - UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - KernelInfo->RefCount++; - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelRelease -__urdlllocal ur_result_t urKernelRelease( - ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release -) { - auto pfnRelease = getContext()->urDdiTable.Kernel.pfnRelease; - - if (nullptr == pfnRelease) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelRelease"); - UR_CALL(pfnRelease(hKernel)); - - auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel); - UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--KernelInfo->RefCount == 0) { - UR_CALL(getContext()->interceptor->eraseKernel(hKernel)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSetArgValue -__urdlllocal ur_result_t UR_APICALL 
urKernelSetArgValue( - ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object - uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] - size_t argSize, ///< [in] size of argument type - const ur_kernel_arg_value_properties_t - *pProperties, ///< [in][optional] pointer to value properties. - const void - *pArgValue ///< [in] argument value represented as matching arg type. -) { - auto pfnSetArgValue = getContext()->urDdiTable.Kernel.pfnSetArgValue; - - if (nullptr == pfnSetArgValue) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelSetArgValue"); - - std::shared_ptr MemBuffer; - if (argSize == sizeof(ur_mem_handle_t) && - (MemBuffer = getContext()->interceptor->getMemBuffer( - *ur_cast(pArgValue)))) { - auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel); - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); - } else { - UR_CALL( - pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSetArgMemObj -__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj( - ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object - uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] - const ur_kernel_arg_mem_obj_properties_t - *pProperties, ///< [in][optional] pointer to Memory object properties. - ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object. 
-) { - auto pfnSetArgMemObj = getContext()->urDdiTable.Kernel.pfnSetArgMemObj; - - if (nullptr == pfnSetArgMemObj) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug("==== urKernelSetArgMemObj"); - - if (auto MemBuffer = getContext()->interceptor->getMemBuffer(hArgValue)) { - auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel); - std::scoped_lock Guard(KernelInfo->Mutex); - KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer); - } else { - UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue)); - } - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSetArgLocal -__urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal( - ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object - uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] - size_t - argSize, ///< [in] size of the local buffer to be allocated by the runtime - const ur_kernel_arg_local_properties_t - *pProperties ///< [in][optional] pointer to local buffer properties. 
-) { - auto pfnSetArgLocal = getContext()->urDdiTable.Kernel.pfnSetArgLocal; - - if (nullptr == pfnSetArgLocal) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug( - "==== urKernelSetArgLocal (argIndex={}, argSize={})", argIndex, - argSize); - - { - auto KI = getContext()->interceptor->getKernelInfo(hKernel); - std::scoped_lock Guard(KI->Mutex); - // TODO: get local variable alignment - auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal( - argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY); - KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ}; - argSize = argSizeWithRZ; - } - - ur_result_t result = - pfnSetArgLocal(hKernel, argIndex, argSize, pProperties); - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSetArgPointer -__urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer( - ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object - uint32_t argIndex, ///< [in] argument index in range [0, num args - 1] - const ur_kernel_arg_pointer_properties_t - *pProperties, ///< [in][optional] pointer to USM pointer properties. - const void * - pArgValue ///< [in][optional] Pointer obtained by USM allocation or virtual memory - ///< mapping operation. If null then argument value is considered null. 
-) { - auto pfnSetArgPointer = getContext()->urDdiTable.Kernel.pfnSetArgPointer; - - if (nullptr == pfnSetArgPointer) { - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; - } - - getContext()->logger.debug( - "==== urKernelSetArgPointer (argIndex={}, pArgValue={})", argIndex, - pArgValue); - - if (getContext()->interceptor->getOptions().DetectKernelArguments) { - auto KI = getContext()->interceptor->getKernelInfo(hKernel); - std::scoped_lock Guard(KI->Mutex); - KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()}; - } - - ur_result_t result = - pfnSetArgPointer(hKernel, argIndex, pProperties, pArgValue); - - return result; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Global table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetGlobalProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_global_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnAdapterGet = ur_sanitizer_layer::urAdapterGet; - - return result; -} -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Context table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t 
UR_APICALL urGetContextProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_context_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnCreate = ur_sanitizer_layer::urContextCreate; - pDdiTable->pfnRetain = ur_sanitizer_layer::urContextRetain; - pDdiTable->pfnRelease = ur_sanitizer_layer::urContextRelease; - - pDdiTable->pfnCreateWithNativeHandle = - ur_sanitizer_layer::urContextCreateWithNativeHandle; - - return result; -} -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Program table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetProgramProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_program_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - pDdiTable->pfnCreateWithIL = ur_sanitizer_layer::urProgramCreateWithIL; - pDdiTable->pfnCreateWithBinary = - ur_sanitizer_layer::urProgramCreateWithBinary; - pDdiTable->pfnCreateWithNativeHandle = - 
ur_sanitizer_layer::urProgramCreateWithNativeHandle; - pDdiTable->pfnBuild = ur_sanitizer_layer::urProgramBuild; - pDdiTable->pfnLink = ur_sanitizer_layer::urProgramLink; - pDdiTable->pfnRetain = ur_sanitizer_layer::urProgramRetain; - pDdiTable->pfnRelease = ur_sanitizer_layer::urProgramRelease; - - return UR_RESULT_SUCCESS; -} - -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Kernel table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_kernel_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnCreate = ur_sanitizer_layer::urKernelCreate; - pDdiTable->pfnRetain = ur_sanitizer_layer::urKernelRetain; - pDdiTable->pfnRelease = ur_sanitizer_layer::urKernelRelease; - pDdiTable->pfnSetArgValue = ur_sanitizer_layer::urKernelSetArgValue; - pDdiTable->pfnSetArgMemObj = ur_sanitizer_layer::urKernelSetArgMemObj; - pDdiTable->pfnSetArgLocal = ur_sanitizer_layer::urKernelSetArgLocal; - pDdiTable->pfnSetArgPointer = ur_sanitizer_layer::urKernelSetArgPointer; - - return result; -} -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Mem table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - 
::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetMemProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_mem_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnBufferCreate = ur_sanitizer_layer::urMemBufferCreate; - pDdiTable->pfnRetain = ur_sanitizer_layer::urMemRetain; - pDdiTable->pfnRelease = ur_sanitizer_layer::urMemRelease; - pDdiTable->pfnBufferPartition = ur_sanitizer_layer::urMemBufferPartition; - pDdiTable->pfnGetNativeHandle = ur_sanitizer_layer::urMemGetNativeHandle; - pDdiTable->pfnGetInfo = ur_sanitizer_layer::urMemGetInfo; - - return result; -} -/// @brief Exported function for filling application's ProgramExp table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_program_exp_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnBuildExp = 
ur_sanitizer_layer::urProgramBuildExp; - pDdiTable->pfnLinkExp = ur_sanitizer_layer::urProgramLinkExp; - - return result; -} -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's Enqueue table -/// with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetEnqueueProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_enqueue_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnMemBufferRead = ur_sanitizer_layer::urEnqueueMemBufferRead; - pDdiTable->pfnMemBufferWrite = ur_sanitizer_layer::urEnqueueMemBufferWrite; - pDdiTable->pfnMemBufferReadRect = - ur_sanitizer_layer::urEnqueueMemBufferReadRect; - pDdiTable->pfnMemBufferWriteRect = - ur_sanitizer_layer::urEnqueueMemBufferWriteRect; - pDdiTable->pfnMemBufferCopy = ur_sanitizer_layer::urEnqueueMemBufferCopy; - pDdiTable->pfnMemBufferCopyRect = - ur_sanitizer_layer::urEnqueueMemBufferCopyRect; - pDdiTable->pfnMemBufferFill = ur_sanitizer_layer::urEnqueueMemBufferFill; - pDdiTable->pfnMemBufferMap = ur_sanitizer_layer::urEnqueueMemBufferMap; - pDdiTable->pfnMemUnmap = ur_sanitizer_layer::urEnqueueMemUnmap; - pDdiTable->pfnKernelLaunch = ur_sanitizer_layer::urEnqueueKernelLaunch; - - return result; -} -/////////////////////////////////////////////////////////////////////////////// -/// @brief Exported function for filling application's USM table -/// 
with current process' addresses -/// -/// @returns -/// - ::UR_RESULT_SUCCESS -/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER -/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION -__urdlllocal ur_result_t UR_APICALL urGetUSMProcAddrTable( - ur_api_version_t version, ///< [in] API version requested - ur_usm_dditable_t - *pDdiTable ///< [in,out] pointer to table of DDI function pointers -) { - if (nullptr == pDdiTable) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (UR_MAJOR_VERSION(ur_sanitizer_layer::getContext()->version) != - UR_MAJOR_VERSION(version) || - UR_MINOR_VERSION(ur_sanitizer_layer::getContext()->version) > - UR_MINOR_VERSION(version)) { - return UR_RESULT_ERROR_UNSUPPORTED_VERSION; - } - - ur_result_t result = UR_RESULT_SUCCESS; - - pDdiTable->pfnDeviceAlloc = ur_sanitizer_layer::urUSMDeviceAlloc; - pDdiTable->pfnHostAlloc = ur_sanitizer_layer::urUSMHostAlloc; - pDdiTable->pfnSharedAlloc = ur_sanitizer_layer::urUSMSharedAlloc; - pDdiTable->pfnFree = ur_sanitizer_layer::urUSMFree; - - return result; -} - ur_result_t context_t::init(ur_dditable_t *dditable, const std::set &enabledLayerNames, [[maybe_unused]] codeloc_data codelocData) { @@ -1818,7 +25,6 @@ ur_result_t context_t::init(ur_dditable_t *dditable, if (enabledLayerNames.count("UR_LAYER_ASAN")) { enabledType = SanitizerType::AddressSanitizer; - interceptor = std::make_unique(); } else if (enabledLayerNames.count("UR_LAYER_MSAN")) { enabledType = SanitizerType::MemorySanitizer; } else if (enabledLayerNames.count("UR_LAYER_TSAN")) { @@ -1830,63 +36,10 @@ ur_result_t context_t::init(ur_dditable_t *dditable, return result; } - if (enabledType == SanitizerType::AddressSanitizer) { - if (!(dditable->VirtualMem.pfnReserve && dditable->VirtualMem.pfnMap && - dditable->VirtualMem.pfnGranularityGetInfo)) { - die("Some VirtualMem APIs are needed to enable UR_LAYER_ASAN"); - } - - if (!dditable->PhysicalMem.pfnCreate) { - die("Some PhysicalMem APIs are needed to enable UR_LAYER_ASAN"); - } - } - 
urDdiTable = *dditable; - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetGlobalProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Global); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetContextProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Context); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetKernelProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Kernel); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetProgramProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Program); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetKernelProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Kernel); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetMemProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Mem); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetProgramExpProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->ProgramExp); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetEnqueueProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->Enqueue); - } - - if (UR_RESULT_SUCCESS == result) { - result = ur_sanitizer_layer::urGetUSMProcAddrTable( - UR_API_VERSION_CURRENT, &dditable->USM); - } + initAsanInterceptor(); + result = asan_ddi_init(dditable); return result; } diff --git a/source/loader/layers/sanitizer/ur_sanitizer_layer.cpp b/source/loader/layers/sanitizer/ur_sanitizer_layer.cpp index b94235cdf0..191cb873db 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_layer.cpp +++ b/source/loader/layers/sanitizer/ur_sanitizer_layer.cpp @@ -11,7 +11,7 @@ */ #include "ur_sanitizer_layer.hpp" -#include "asan_interceptor.hpp" +#include "asan/asan_ddi.hpp" namespace ur_sanitizer_layer { context_t *getContext() { return context_t::get_direct(); } @@ -21,7 +21,17 @@ context_t::context_t() : logger(logger::create_logger("sanitizer", false, 
false, logger::Level::WARN)) {} -ur_result_t context_t::tearDown() { return UR_RESULT_SUCCESS; } +ur_result_t context_t::tearDown() { + switch (enabledType) { + case SanitizerType::AddressSanitizer: + destroyAsanInterceptor(); + break; + default: + break; + } + + return UR_RESULT_SUCCESS; +} /////////////////////////////////////////////////////////////////////////////// context_t::~context_t() {} diff --git a/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp b/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp index e7f704f8a8..7f1158c5ae 100644 --- a/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp +++ b/source/loader/layers/sanitizer/ur_sanitizer_layer.hpp @@ -13,14 +13,13 @@ #pragma once #include "logger/ur_logger.hpp" -#include "ur/ur.hpp" #include "ur_proxy_layer.hpp" #define SANITIZER_COMP_NAME "sanitizer layer" namespace ur_sanitizer_layer { -class SanitizerInterceptor; +class AsanInterceptor; enum class SanitizerType { None, @@ -35,7 +34,6 @@ class __urdlllocal context_t : public proxy_layer_context_t, public: ur_dditable_t urDdiTable = {}; logger::Logger logger; - std::unique_ptr interceptor; SanitizerType enabledType = SanitizerType::None; context_t(); @@ -52,4 +50,5 @@ class __urdlllocal context_t : public proxy_layer_context_t, }; context_t *getContext(); + } // namespace ur_sanitizer_layer