Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New plugins property to pass mmap buffer #27981

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#pragma once

#include "openvino/runtime/aligned_buffer.hpp"
#include "openvino/runtime/properties.hpp"
#include "openvino/runtime/threading/istreams_executor.hpp"

Expand Down Expand Up @@ -36,6 +37,12 @@ static constexpr Property<std::vector<PropertyName>, PropertyMutability::RO> cac
*/
static constexpr Property<bool, PropertyMutability::RO> caching_with_mmap{"CACHING_WITH_MMAP"};

/**
* @brief Property to get an ov::AlignedBuffer with the cached model
* @ingroup ov_dev_api_plugin_api
*/
static constexpr Property<std::shared_ptr<ov::AlignedBuffer>, PropertyMutability::RW> cached_model_buffer{"CACHED_MODEL_BUFFER"};

/**
* @brief Allow to create exclusive_async_requests with one executor
* @ingroup ov_dev_api_plugin_api
Expand Down
6 changes: 3 additions & 3 deletions src/inference/src/cache_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ICacheManager {
/**
* @brief Function passing created input stream
*/
using StreamReader = std::function<void(std::istream&)>;
using StreamReader = std::function<void(std::istream&, std::shared_ptr<ov::AlignedBuffer>)>;

/**
* @brief Callback when OpenVINO intends to read model from cache
Expand Down Expand Up @@ -143,10 +143,10 @@ class FileStorageCacheManager final : public ICacheManager {
std::make_shared<ov::SharedBuffer<std::shared_ptr<MappedMemory>>>(mmap->data(), mmap->size(), mmap);
OwningSharedStreamBuffer buf(shared_buffer);
std::istream stream(&buf);
reader(stream);
reader(stream, shared_buffer);
} else {
std::ifstream stream(blob_file_name, std::ios_base::binary);
reader(stream);
reader(stream, nullptr);
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/inference/src/dev/core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(
cacheContent.blobId,
coreConfig.get_enable_mmap() && ov::util::contains(plugin.get_property(ov::internal::supported_properties),
ov::internal::caching_with_mmap),
[&](std::istream& networkStream) {
[&](std::istream& networkStream, std::shared_ptr<ov::AlignedBuffer> model_buffer) {
OV_ITT_SCOPE(FIRST_INFERENCE,
ov::itt::domains::LoadTime,
"Core::load_model_from_cache::ReadStreamAndImport");
Expand Down Expand Up @@ -1459,6 +1459,9 @@ ov::SoPtr<ov::ICompiledModel> ov::CoreImpl::load_model_from_cache(
update_config[ov::weights_path.name()] = weights_path;
}
}
if (model_buffer) {
update_config[ov::internal::cached_model_buffer.name()] = model_buffer;
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved
}
compiled_model = context ? plugin.import_model(networkStream, context, update_config)
: plugin.import_model(networkStream, update_config);
});
Expand Down
136 changes: 136 additions & 0 deletions src/inference/tests/functional/caching_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,142 @@ TEST_P(CachingTest, Load_threads) {
std::cout << "Caching Load multiple threads test completed. Tried " << index << " times" << std::endl;
}

// Verifies that when the plugin advertises ov::internal::caching_with_mmap and
// mmap is enabled (the default), import_model() receives the cached blob as an
// ov::AlignedBuffer via the ov::internal::cached_model_buffer config entry.
TEST_P(CachingTest, Load_mmap) {
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        if (config.count(ov::internal::cached_model_buffer.name()))
            model_buffer = config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();
        // mmap path is active, so the buffer must be provided.
        EXPECT_TRUE(model_buffer);

        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));

    // Plugin reports mmap-capable caching support.
    ON_CALL(*mockPlugin, get_property(ov::internal::supported_properties.name(), _))
        .WillByDefault(Invoke([&](const std::string&, const ov::AnyMap&) {
            return std::vector<ov::PropertyName>{ov::internal::caching_properties.name(),
                                                 ov::internal::caching_with_mmap.name()};
        }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    // First load compiles and exports to cache; second load imports from cache.
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        m_testFunction(core);
        m_testFunction(core);
    });
    std::cout << "Caching Load_mmap test completed" << std::endl;
}

// Verifies that when the user explicitly disables mmap (ov::enable_mmap(false)),
// import_model() must NOT receive an ov::internal::cached_model_buffer even though
// the plugin advertises mmap-capable caching.
TEST_P(CachingTest, Load_mmap_is_disabled) {
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        if (config.count(ov::internal::cached_model_buffer.name()))
            model_buffer = config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();
        // mmap is disabled by the user, so no buffer may be passed.
        EXPECT_FALSE(model_buffer);

        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));
    // Plugin reports mmap-capable caching support; user setting must still win.
    ON_CALL(*mockPlugin, get_property(ov::internal::supported_properties.name(), _))
        .WillByDefault(Invoke([&](const std::string&, const ov::AnyMap&) {
            return std::vector<ov::PropertyName>{ov::internal::caching_properties.name(),
                                                 ov::internal::caching_with_mmap.name()};
        }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    // First load compiles and exports to cache; second load imports from cache.
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        core.set_property({ov::enable_mmap(false)});
        m_testFunction(core);
        m_testFunction(core);
    });
    std::cout << "Caching Load_mmap_is_disabled test completed" << std::endl;
}

// Verifies that when the plugin does NOT advertise ov::internal::caching_with_mmap
// (default supported_properties mock is used), import_model() must not receive an
// ov::internal::cached_model_buffer even though mmap is enabled by the user.
TEST_P(CachingTest, Load_mmap_is_not_supported_by_plugin) {
    ON_CALL(*mockPlugin, import_model(_, _)).WillByDefault(Invoke([&](std::istream& istr, const ov::AnyMap& config) {
        if (m_checkConfigCb) {
            m_checkConfigCb(config);
        }
        std::shared_ptr<ov::AlignedBuffer> model_buffer;
        if (config.count(ov::internal::cached_model_buffer.name()))
            model_buffer = config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();
        // Plugin lacks mmap caching support, so no buffer may be passed.
        EXPECT_FALSE(model_buffer);

        std::string name;
        istr >> name;
        char space;
        istr.read(&space, 1);
        std::lock_guard<std::mutex> lock(mock_creation_mutex);
        return create_mock_compiled_model(m_models[name], mockPlugin);
    }));
    EXPECT_CALL(*mockPlugin, get_property(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, query_model(_, _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::device::architecture.name(), _)).Times(AnyNumber());
    EXPECT_CALL(*mockPlugin, get_property(ov::internal::caching_properties.name(), _)).Times(AnyNumber());
    if (m_remoteContext) {
        return;  // skip the remote Context test for Multi plugin
    }
    m_post_mock_net_callbacks.emplace_back([&](MockICompiledModelImpl& net) {
        EXPECT_CALL(net, export_model(_)).Times(1);
    });
    MkDirGuard guard(m_cacheDir);
    // First load compiles and exports to cache; second load imports from cache.
    EXPECT_CALL(*mockPlugin, compile_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, compile_model(A<const std::shared_ptr<const ov::Model>&>(), _)).Times(1);
    EXPECT_CALL(*mockPlugin, import_model(_, _, _)).Times(0);
    EXPECT_CALL(*mockPlugin, import_model(_, _)).Times(1);
    testLoad([&](ov::Core& core) {
        core.set_property({{ov::cache_dir.name(), m_cacheDir}});
        core.set_property({ov::enable_mmap(true)});
        m_testFunction(core);
        m_testFunction(core);
    });
    std::cout << "Caching Load_mmap_is_not_supported_by_plugin test completed" << std::endl;
}

#if defined(ENABLE_OV_IR_FRONTEND)

static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<TestParam, std::string>>& obj) {
Expand Down
9 changes: 8 additions & 1 deletion src/plugins/intel_cpu/src/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,8 +565,16 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& model_str
decript_from_string = true;
}

auto _config = config;
std::shared_ptr<ov::AlignedBuffer> model_buffer;
ilya-lavrenov marked this conversation as resolved.
Show resolved Hide resolved
if (_config.count(ov::internal::cached_model_buffer.name())) {
model_buffer = _config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();
_config.erase(ov::internal::cached_model_buffer.name());
}

ModelDeserializer deserializer(
model_stream,
model_buffer,
[this](const std::shared_ptr<ov::AlignedBuffer>& model, const std::shared_ptr<ov::AlignedBuffer>& weights) {
return get_core()->read_model(model, weights);
},
Expand All @@ -579,7 +587,6 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& model_str
Config::ModelType modelType = getModelType(model);
conf.applyRtInfo(model);
// check ov::loaded_from_cache property and erase it to avoid exception in readProperties.
auto _config = config;
const auto& it = _config.find(ov::loaded_from_cache.name());
bool loaded_from_cache = false;
if (it != _config.end()) {
Expand Down
13 changes: 8 additions & 5 deletions src/plugins/intel_cpu/src/utils/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ void ModelSerializer::operator<<(const std::shared_ptr<ov::Model>& model) {

////////// ModelDeserializer //////////

ModelDeserializer::ModelDeserializer(std::istream& model_stream, ModelBuilder fn, const CacheDecrypt& decrypt_fn, bool decript_from_string)
: m_istream(model_stream), m_model_builder(std::move(fn)), m_decript_from_string(decript_from_string) {
ModelDeserializer::ModelDeserializer(std::istream& model_stream,
std::shared_ptr<ov::AlignedBuffer> model_buffer,
ModelBuilder fn,
const CacheDecrypt& decrypt_fn,
bool decript_from_string)
: m_istream(model_stream), m_model_builder(std::move(fn)), m_decript_from_string(decript_from_string), m_model_buffer(model_buffer) {
if (m_decript_from_string) {
m_cache_decrypt.m_decrypt_str = decrypt_fn.m_decrypt_str;
} else {
Expand All @@ -42,9 +46,8 @@ ModelDeserializer::ModelDeserializer(std::istream& model_stream, ModelBuilder fn
// Intentionally left empty: no extra info is attached to the model during deserialization here.
void ModelDeserializer::set_info(pugi::xml_node& root, std::shared_ptr<ov::Model>& model) {}

void ModelDeserializer::operator>>(std::shared_ptr<ov::Model>& model) {
if (auto mmap_buffer = dynamic_cast<OwningSharedStreamBuffer*>(m_istream.rdbuf())) {
auto buffer = mmap_buffer->get_buffer();
process_mmap(model, buffer);
if (m_model_buffer) {
process_mmap(model, m_model_buffer);
} else {
process_stream(model);
}
Expand Down
7 changes: 6 additions & 1 deletion src/plugins/intel_cpu/src/utils/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ class ModelDeserializer {
public:
typedef std::function<std::shared_ptr<ov::Model>(const std::shared_ptr<ov::AlignedBuffer>&, const std::shared_ptr<ov::AlignedBuffer>&)> ModelBuilder;

ModelDeserializer(std::istream& model, ModelBuilder fn, const CacheDecrypt& encrypt_fn, bool decript_from_string);
ModelDeserializer(std::istream& model,
std::shared_ptr<ov::AlignedBuffer> model_buffer,
ModelBuilder fn,
const CacheDecrypt& encrypt_fn,
bool decript_from_string);

virtual ~ModelDeserializer() = default;

Expand All @@ -48,6 +52,7 @@ class ModelDeserializer {
ModelBuilder m_model_builder;
CacheDecrypt m_cache_decrypt;
bool m_decript_from_string;
std::shared_ptr<ov::AlignedBuffer> m_model_buffer;
};

} // namespace intel_cpu
Expand Down
6 changes: 6 additions & 0 deletions src/plugins/intel_gpu/src/plugin/plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,12 @@ std::shared_ptr<ov::ICompiledModel> Plugin::import_model(std::istream& model,
_orig_config.erase(it);
}

std::shared_ptr<ov::AlignedBuffer> model_buffer;
if (_orig_config.count(ov::internal::cached_model_buffer.name())) {
model_buffer = _orig_config.at(ov::internal::cached_model_buffer.name()).as<std::shared_ptr<ov::AlignedBuffer>>();
_orig_config.erase(ov::internal::cached_model_buffer.name());
}

ExecutionConfig config = m_configs_map.at(device_id);
config.set_user_property(_orig_config);
config.apply_user_properties(context_impl->get_engine().get_device_info());
Expand Down
Loading