Skip to content

Commit

Permalink
Improve internal memory reuse
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Aug 26, 2024
1 parent 75b0ade commit 34d731c
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class wait_for_events_impl : public primitive_impl {
void init_kernels(const kernels_cache&, const kernel_impl_params&) override {}
void set_arguments(primitive_inst& /*instance*/) override {}
void set_arguments(primitive_inst& /*instance*/, kernel_arguments_data& /*args*/) override {}
std::vector<layout> get_internal_buffer_layouts(const kernel_impl_params& /*params*/) const override { return {}; }
std::vector<layout> get_internal_buffer_layouts() const override { return {}; }

event::ptr execute(const std::vector<event::ptr>& events, primitive_inst& instance) override {
auto& stream = instance.get_network().get_stream();
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/impls/ocl/border.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ struct border_impl : typed_primitive_impl_ocl<border> {
return args;
}

std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& /*params*/) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
const auto& prim_params = static_cast<const kernel_selector::border_params&>(*_kernel_data.params);
std::vector<layout> layouts;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ struct multi_stage_primitive : public typed_primitive_impl<PType> {
return _kernels;
}

std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& /*params*/) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
std::vector<layout> layouts;
for (auto& kd : _kernels_data) {
if (kd.internalBufferSizes.empty())
Expand Down
33 changes: 20 additions & 13 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
}
}

std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& params) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
auto add_internal_buffers = [](std::vector<layout>& layouts, const kernel_selector::KernelData& kd) {
if (kd.internalBufferSizes.empty())
return;
Expand All @@ -84,13 +84,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
};

std::vector<layout> layouts;
if (is_prefill_stage(params)) {
add_internal_buffers(layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
add_internal_buffers(layouts, _kernels_data[Stage::SDPA]);
} else {
add_internal_buffers(layouts, _kernels_data[Stage::KV_CACHE_UPDATE]);
add_internal_buffers(layouts, _kernels_data[Stage::PA_SDPA]);
}
add_internal_buffers(layouts, _kernels_data[Stage::SDPA]);
add_internal_buffers(layouts, _kernels_data[Stage::PA_SDPA]);

return layouts;
}
Expand Down Expand Up @@ -144,8 +139,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
return args;
}

std::set<size_t> get_lockable_internal_buffers(const kernel_impl_params& params) const override {
return is_prefill_stage(params) ? std::set<size_t>{ 0, 1, 2 } : std::set<size_t>{};
std::set<size_t> get_lockable_internal_buffers() const override {
return std::set<size_t>{ 0, 1, 2 }; /* SDPA and KV_CACHE_UPDATE indexes configuration */
};

void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
Expand All @@ -155,6 +150,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
for (size_t s = 0; s < stage; s++) {
kernel_offset += _kernels_data[s].kernels.size();
}

// Stages SDPA and KV_CACHE_UPDATE reuse the same internal buffers at prefill stage
size_t internal_buffers_offset = 0;
size_t internal_buffers_count = 0;
if (stage == Stage::PA_SDPA) {
internal_buffers_offset = _kernels_data[Stage::SDPA].internalBufferSizes.size();
internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size();
} else {
internal_buffers_count = _kernels_data[Stage::SDPA].internalBufferSizes.size();
}

for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) {
if (_kernels_data[stage].kernels[kd_idx].skip_execution)
continue;
Expand All @@ -168,9 +174,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
auto args = get_arguments(instance, stage, kd_idx);
args.scalars = &params.scalars;

for (const auto& m : instance.get_intermediates_memories()) {
args.intermediates.push_back(m);
}
const auto& intermediate_memories = instance.get_intermediates_memories();
args.intermediates.insert(args.intermediates.end(),
intermediate_memories.begin() + internal_buffers_offset,
intermediate_memories.begin() + internal_buffers_offset + internal_buffers_count);

stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
return _kernels;
}

std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& /*params*/) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
if (_kernel_data.internalBufferSizes.empty())
return {};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
}

protected:
std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& /*params*/) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
// TODO: the current implementation is supposed to use the same kernel version for both the indirect and default paths.
// Considering this, we may assume that both indirect/default kernels have exactly the same number of intermediate
// buffers and the same buffer sizes (since update_dispatch_data is called for both kernels too), and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ struct typed_primitive_onednn_impl : public typed_primitive_impl<PType> {
return event;
}

std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& /*params*/) const override {
std::vector<layout> get_internal_buffer_layouts_impl() const override {
if (_scratchpad_md.get_size() == 0)
return {};
return {{{1, 1, 1, (tensor::value_type)(_scratchpad_md.get_size())}, cldnn::data_types::u8, format::bfyx}};
Expand Down
10 changes: 5 additions & 5 deletions src/plugins/intel_gpu/src/graph/include/primitive_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ struct primitive_impl {
primitive_impl(nullptr, std::move(kernel_name), is_dynamic) {}
virtual ~primitive_impl() = default;

virtual std::vector<layout> get_internal_buffer_layouts(const kernel_impl_params& params) const = 0;
virtual std::set<size_t> get_lockable_internal_buffers(const kernel_impl_params& params) const { return {}; }
virtual std::vector<layout> get_internal_buffer_layouts() const = 0;
virtual std::set<size_t> get_lockable_internal_buffers() const { return {}; }
virtual void set_node_params(const program_node&) {}
virtual const std::string& get_type_info() const = 0;
virtual void set_arguments(primitive_inst& instance) = 0;
Expand Down Expand Up @@ -486,11 +486,11 @@ struct typed_primitive_impl : public primitive_impl {
return execute_impl(event, reinterpret_cast<typed_primitive_inst<PType>&>(instance));
}

std::vector<layout> get_internal_buffer_layouts(const kernel_impl_params& params) const override {
return get_internal_buffer_layouts_impl(params);
std::vector<layout> get_internal_buffer_layouts() const override {
return get_internal_buffer_layouts_impl();
}

virtual std::vector<layout> get_internal_buffer_layouts_impl(const kernel_impl_params& params) const {
virtual std::vector<layout> get_internal_buffer_layouts_impl() const {
return {};
}

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void paged_attention_inst::on_execute() {
if (!is_prefill_stage(*_impl_params))
return;

OPENVINO_ASSERT(_intermediates_memory.size() == 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");

const auto blocks_indexes_start_idx = 0;
const auto blocks_indexes_end_idx = 1;
Expand Down
19 changes: 7 additions & 12 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ event::ptr primitive_inst::realloc_if_needed() {
{
if (_impl == nullptr)
return ev;
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts(*_impl_params);
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts();
if (ibuf_layouts.empty())
return ev;
GPU_DEBUG_CODE(std::string memalloc_info = "");
Expand Down Expand Up @@ -1820,12 +1820,6 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
allocate_memory = _mem_allocated = available_allocate_memory(_impl_params->output_layouts);
}

// Do not allocate zero buffers for constant and reuse memory from program node
// if (_is_constant && get_output_layout().count() == 0) {
// _outputs[0] = node.as<data>().get_attached_memory_ptr();
// allocate_memory = false;
// }

if (allocate_memory) {
// When the output is a mutable_data primitive, and other users dependencies are only used for
// synchronization, the output memory of such a primitive will be fused with mutable_data
Expand Down Expand Up @@ -1876,7 +1870,7 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
memory::ptr primitive_inst::allocate_internal_buffer(size_t idx, bool reset) {
if (_impl == nullptr || _outputs.empty() || _outputs[0] == nullptr)
return nullptr;
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts(*_impl_params);
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts();
if (ibuf_layouts.empty())
return nullptr;

Expand Down Expand Up @@ -1905,11 +1899,12 @@ memory::ptr primitive_inst::allocate_internal_buffer(size_t idx, bool reset) {
}

int64_t available_device_mem_size = engine.get_device_info().max_global_mem_size - total_device_mem_size;
// TODO: check if this logic is needed
// check if there is any device mem input
if (engine.supports_allocation(allocation_type::usm_device)) {
for (const auto& dep : inst_deps) {
if (!dep.first->mem_allocated()) continue;
if (dep.first->output_memory().get_allocation_type() == allocation_type::usm_device) {
if (dep.first->output_memory_ptr() &&
dep.first->output_memory_ptr()->get_allocation_type() == allocation_type::usm_device) {
input_device_mem = true;
break;
}
Expand All @@ -1918,7 +1913,7 @@ memory::ptr primitive_inst::allocate_internal_buffer(size_t idx, bool reset) {
// allocate intermediate memory for the updated layout of buffer
auto layout = ibuf_layouts[idx];
auto alloc_type = allocation_type::unknown;
const auto& lockable_buffers_indexes = _impl->get_lockable_internal_buffers(*_impl_params);
const auto& lockable_buffers_indexes = _impl->get_lockable_internal_buffers();
auto need_lockable_allocation = lockable_buffers_indexes.find(idx) != lockable_buffers_indexes.end();
GPU_DEBUG_LOG << "[" << _node->id() << ": internal buf " << idx << "] "
<< layout.to_short_string() << " need_lockable_allocation=" << need_lockable_allocation << std::endl;
Expand Down Expand Up @@ -1956,7 +1951,7 @@ memory::ptr primitive_inst::allocate_internal_buffer(size_t idx, bool reset) {
void primitive_inst::allocate_internal_buffers(bool reset) {
if (_impl == nullptr || _outputs.empty() || _outputs[0] == nullptr)
return;
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts(*_impl_params);
const auto& ibuf_layouts = _impl->get_internal_buffer_layouts();
if (ibuf_layouts.empty())
return;

Expand Down

0 comments on commit 34d731c

Please sign in to comment.