Skip to content

Commit

Permalink
[NPU] Remove template in ext wrapper and fuse functions (#27511)
Browse files Browse the repository at this point in the history
### Details:
 - *Remove the template parameter from `ZeGraphExtWrappers` (ze_graph_ext_wrappers.hpp)*
 - *Remove ze_graph_ext_wrappers_interface.hpp*
 - *Add more low-level debug logging*
 - *Update the level-zero-ext repo commit to use version 1.9*

### Tickets:
 - *156387*

---------

Signed-off-by: Xin Wang <[email protected]>
  • Loading branch information
XinWangIntel authored Nov 15, 2024
1 parent 6489755 commit 3e63de0
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 486 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "intel_npu/config/config.hpp"
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

Expand Down Expand Up @@ -54,7 +54,7 @@ class DriverCompilerAdapter final : public ICompilerAdapter {
std::string serializeConfig(const Config& config, ze_graph_compiler_version_info_t compilerVersion) const;

std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;
std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;

ze_device_graph_properties_t _deviceGraphProperties = {};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

#include "intel_npu/common/igraph.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

class DriverGraph final : public IGraph {
public:
DriverGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
Expand All @@ -37,7 +37,7 @@ class DriverGraph final : public IGraph {
private:
bool release_blob(const Config& config);

std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

Logger _logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "openvino/runtime/so_ptr.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

Expand All @@ -28,7 +28,7 @@ class PluginCompilerAdapter final : public ICompilerAdapter {
private:
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
ov::SoPtr<ICompiler> _compiler;

Logger _logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
#include "intel_npu/icompiler.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "openvino/runtime/so_ptr.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"
#include "ze_graph_ext_wrappers.hpp"

namespace intel_npu {

class PluginGraph final : public IGraph {
public:
PluginGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const ov::SoPtr<ICompiler>& compiler,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
Expand All @@ -38,7 +38,7 @@ class PluginGraph final : public IGraph {
~PluginGraph() override;

private:
std::shared_ptr<ZeGraphExtWrappersInterface> _zeGraphExt;
std::shared_ptr<ZeGraphExtWrappers> _zeGraphExt;
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;

const ov::SoPtr<ICompiler> _compiler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,148 +10,60 @@
#include <type_traits>
#include <utility>

#include "intel_npu/network_metadata.hpp"
#include "intel_npu/utils/logger/logger.hpp"
#include "intel_npu/utils/zero/zero_init.hpp"
#include "intel_npu/utils/zero/zero_types.hpp"
#include "ze_graph_ext_wrappers_interface.hpp"

namespace intel_npu {

#define NotSupportQuery(T) (T == ZE_GRAPH_EXT_VERSION_1_2)

// ext version 1.3 or 1.4: supports the APIs pfnQueryNetworkCreate, pfnQueryNetworkDestroy and
// pfnQueryNetworkGetSupportedLayers
#define SupportAPIGraphQueryNetworkV1(T) (T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4)

// ext version >= 1.5, support API (pfnCreate2, pfnQueryNetworkCreate2, pfnQueryContextMemory)
#define SupportAPIGraphQueryNetworkV2(T) ((!NotSupportQuery(T) && !SupportAPIGraphQueryNetworkV1(T)))

// For ext version >= 1.5, the pfnCreate2 API is available
#define NotSupportGraph2(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4)

// A bug inside the driver makes the "pfnGraphGetArgumentMetadata" call not safe for use prior to
// "ze_graph_dditable_ext_1_6_t".
// See: E#117498
#define NotSupportArgumentMetadata(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4 || \
T == ZE_GRAPH_EXT_VERSION_1_5)

#define UseCopyForNativeBinary(T) \
(T == ZE_GRAPH_EXT_VERSION_1_2 || T == ZE_GRAPH_EXT_VERSION_1_3 || T == ZE_GRAPH_EXT_VERSION_1_4 || \
T == ZE_GRAPH_EXT_VERSION_1_5 || T == ZE_GRAPH_EXT_VERSION_1_6)
using SerializedIR = std::pair<size_t, std::shared_ptr<uint8_t>>;

/**
* Adapter to use CiD through ZeroAPI
*/
template <ze_graph_ext_version_t TableExtension>
class ZeGraphExtWrappers final : public ZeGraphExtWrappersInterface {
class ZeGraphExtWrappers {
public:
ZeGraphExtWrappers(const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct);
ZeGraphExtWrappers(const ZeGraphExtWrappers&) = delete;
ZeGraphExtWrappers& operator=(const ZeGraphExtWrappers&) = delete;
~ZeGraphExtWrappers();

std::unordered_set<std::string> queryGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const override;
const std::string& buildFlags) const;
ze_graph_handle_t getGraphHandle(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags) const override;
const uint32_t& flags) const;

ze_graph_handle_t getGraphHandle(const std::vector<uint8_t>& network) const override;
ze_graph_handle_t getGraphHandle(const std::vector<uint8_t>& network) const;

NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const override;
NetworkMetadata getNetworkMeta(ze_graph_handle_t graphHandle) const;

_ze_result_t destroyGraph(ze_graph_handle_t graphHandle) override;
_ze_result_t destroyGraph(ze_graph_handle_t graphHandle);

void getGraphBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob,
const uint8_t*& blobPtr,
size_t& blobSize) const override;
size_t& blobSize) const;

void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const override;
void setGraphArgumentValue(ze_graph_handle_t graphHandle, uint32_t argi_, const void* argv) const;

void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const override;
void initializeGraph(ze_graph_handle_t graphHandle, const Config& config) const;

private:
template <ze_graph_ext_version_t T = TableExtension, std::enable_if_t<!NotSupportQuery(T), bool> = true>
std::unordered_set<std::string> getQueryResultFromSupportedLayers(
ze_result_t result,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<NotSupportArgumentMetadata(T), bool> = true>
void getMetadata(ze_graph_handle_t graphHandle,
uint32_t index,
std::vector<IODescriptor>& inputs,
std::vector<IODescriptor>& outputs) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<!NotSupportArgumentMetadata(T), bool> = true>
void getMetadata(ze_graph_handle_t graphHandle,
uint32_t index,
std::vector<IODescriptor>& inputs,
std::vector<IODescriptor>& outputs) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& blob,
const uint8_t*& blobPtr,
size_t& blobSize) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<!UseCopyForNativeBinary(T), bool> = true>
void getNativeBinary(ze_graph_handle_t graphHandle,
std::vector<uint8_t>& /* unusedBlob */,
const uint8_t*& blobPtr,
size_t& blobSize) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV2(T), bool> = true>
ze_result_t queryNetworkCreateV2(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

// ext version >= 1.5, support API (pfnCreate2, pfnQueryNetworkCreate2, pfnQueryContextMemory)
template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV2(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV1(T), bool> = true>
ze_result_t queryNetworkCreateV1(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
ze_graph_query_network_handle_t& hGraphQueryNetwork) const;

// ext version 1.3 or 1.4: supports the APIs pfnQueryNetworkCreate, pfnQueryNetworkDestroy and
// pfnQueryNetworkGetSupportedLayers
template <ze_graph_ext_version_t T = TableExtension,
typename std::enable_if_t<SupportAPIGraphQueryNetworkV1(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

// For ext version < 1.3
template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<NotSupportQuery(T), bool> = true>
std::unordered_set<std::string> queryImpl(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags) const;

template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<NotSupportGraph2(T), bool> = true>
void createGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags,
ze_graph_handle_t* graph) const;

template <ze_graph_ext_version_t T = TableExtension, typename std::enable_if_t<!NotSupportGraph2(T), bool> = true>
void createGraph(std::pair<size_t, std::shared_ptr<uint8_t>> serializedIR,
const std::string& buildFlags,
const uint32_t& flags,
ze_graph_handle_t* graph) const;

void initialize_graph_through_command_list(ze_graph_handle_t graphHandle, const Config& config) const;

std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;
uint32_t _graphExtVersion;

Logger _logger;
};
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,7 @@ DriverCompilerAdapter::DriverCompilerAdapter(const std::shared_ptr<ZeroInitStruc

_logger.info("DriverCompilerAdapter creating adapter using graphExtVersion");

switch (graphExtVersion) {
case ZE_GRAPH_EXT_VERSION_1_3:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_3>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_4:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_4>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_5:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_5>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_6:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_6>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_7:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_7>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_8:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_8>>(_zeroInitStruct);
break;
default:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_2>>(_zeroInitStruct);
break;
}
_zeGraphExt = std::make_shared<ZeGraphExtWrappers>(_zeroInitStruct);

_logger.info("initialize DriverCompilerAdapter complete, using graphExtVersion: %d.%d",
ZE_MAJOR_VERSION(graphExtVersion),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace intel_npu {

DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
DriverGraph::DriverGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,7 @@ PluginCompilerAdapter::PluginCompilerAdapter(const std::shared_ptr<ZeroInitStruc

_logger.info("PluginCompilerAdapter creating adapter using graphExtVersion");

switch (graphExtVersion) {
case ZE_GRAPH_EXT_VERSION_1_3:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_3>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_4:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_4>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_5:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_5>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_6:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_6>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_7:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_7>>(_zeroInitStruct);
break;
case ZE_GRAPH_EXT_VERSION_1_8:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_8>>(_zeroInitStruct);
break;
default:
_zeGraphExt = std::make_shared<ZeGraphExtWrappers<ZE_GRAPH_EXT_VERSION_1_2>>(_zeroInitStruct);
break;
}
_zeGraphExt = std::make_shared<ZeGraphExtWrappers>(_zeroInitStruct);

_logger.info("initialize PluginCompilerAdapter complete, using graphExtVersion: %d.%d",
ZE_MAJOR_VERSION(graphExtVersion),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace intel_npu {

PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappersInterface>& zeGraphExt,
PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
const ov::SoPtr<ICompiler>& compiler,
const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct,
ze_graph_handle_t graphHandle,
Expand Down
Loading

0 comments on commit 3e63de0

Please sign in to comment.