Commit: Fix rebase

vshampor committed Jan 10, 2025
1 parent 3c52f05 · commit 8a78963
Showing 3 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -69,7 +69,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     void load(BinaryInputBuffer& ib) override {
         parent::load(ib);
         ib >> make_data(&has_scores_output, sizeof(bool));
-        ib >> make_data(&has_rotation_coefficients, sizeof(bool));
+        ib >> make_data(&has_rotated_blocks, sizeof(bool));
         if (is_dynamic()) {
             auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
             auto kv_cache_update_kernel_impl = kv_cache_update_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_UPDATE].kernelName);
@@ -83,7 +83,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             auto pa_sdpa_kernel_impl = pa_sdpa_kernel_selector.GetImplementation(_kernels_data[Stage::PA_SDPA].kernelName);
             pa_sdpa_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::PA_SDPA]);
 
-            if (has_rotation_coefficients) {
+            if (has_rotated_blocks) {
                 auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
                 auto kv_cache_rotate_kernel_impl = kv_cache_rotate_kernel_selector.GetImplementation(_kernels_data[Stage::KV_CACHE_ROTATE].kernelName);
                 kv_cache_rotate_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[Stage::KV_CACHE_ROTATE]);
@@ -94,7 +94,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     void save(BinaryOutputBuffer& ob) const override {
         parent::save(ob);
         ob << make_data(&has_scores_output, sizeof(bool));
-        ob << make_data(&has_rotation_coefficients, sizeof(bool));
+        ob << make_data(&has_rotated_blocks, sizeof(bool));
     }
 
     std::vector<layout> get_internal_buffer_layouts_impl() const override {
@@ -347,7 +347,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
         std::vector<event::ptr> res_events;
         std::vector<event::ptr> dep_events = events;
-        if (has_rotation_coefficients) {
+        if (has_rotated_blocks) {
             execute_stage(dep_events, instance, res_events, Stage::KV_CACHE_ROTATE, is_mixed_mode);
             dep_events = res_events;
         }
@@ -472,7 +472,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             config.has_const_scale_val = false;
         }
 
-        config.has_rotation_coefficients_input = desc->has_rotation_coefficients;
+        config.has_rotated_blocks = desc->has_rotated_blocks;
 
        if (desc->heads_num != desc->kv_heads_num) {
            config.broadcast_axis = 1;
@@ -752,7 +752,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        for (const auto& input_layout : impl_param.input_layouts)
            input_tensors.emplace_back(convert_data_tensor(input_layout));
 
-        if (has_rotation_coefficients) {
+        if (has_rotated_blocks) {
            auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
            (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);
        }
@@ -792,22 +792,22 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
        auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
        kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));
 
-        if (desc->has_rotation_coefficients) {
+        if (desc->has_rotated_blocks) {
            auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
            auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
            kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params));
        }
 
        auto impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
        impl->has_scores_output = desc->has_scores_output();
-        impl->has_rotation_coefficients = desc->has_rotation_coefficients;
+        impl->has_rotated_blocks = desc->has_rotated_blocks;
 
        return impl;
    }
 
private:
    bool has_scores_output = false;
-    bool has_rotation_coefficients = false;
+    bool has_rotated_blocks = false;
};
 
namespace detail {
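
For context on the first two hunks above: the cached-blob save()/load() overrides must serialize the same fields in the same order, which is why the rename has to land in both. Below is a minimal standalone sketch of that invariant; OutBuf and InBuf are hypothetical stand-ins, not OpenVINO's BinaryOutputBuffer/BinaryInputBuffer.

// Standalone sketch (not OpenVINO code): why save() and load() must
// agree field-for-field on the serialized stream after the rename.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

struct OutBuf {  // hypothetical stand-in for BinaryOutputBuffer
    std::vector<std::uint8_t> bytes;
    void write(const void* p, std::size_t n) {
        const auto* b = static_cast<const std::uint8_t*>(p);
        bytes.insert(bytes.end(), b, b + n);
    }
};

struct InBuf {  // hypothetical stand-in for BinaryInputBuffer
    const std::vector<std::uint8_t>& bytes;
    std::size_t pos = 0;
    void read(void* p, std::size_t n) {
        std::memcpy(p, bytes.data() + pos, n);
        pos += n;
    }
};

struct Impl {
    bool has_scores_output = false;
    bool has_rotated_blocks = false;

    // Mirrors the diff: both flags are written in a fixed order...
    void save(OutBuf& ob) const {
        ob.write(&has_scores_output, sizeof(bool));
        ob.write(&has_rotated_blocks, sizeof(bool));
    }
    // ...and read back in the same order; a field present on only one
    // side (e.g. left stale by a rebase) would desynchronize the stream.
    void load(InBuf& ib) {
        ib.read(&has_scores_output, sizeof(bool));
        ib.read(&has_rotated_blocks, sizeof(bool));
    }
};

int main() {
    Impl a;
    a.has_rotated_blocks = true;

    OutBuf ob;
    a.save(ob);

    InBuf ib{ob.bytes};
    Impl b;
    b.load(ib);
    assert(b.has_rotated_blocks && !b.has_scores_output);  // round-trips
    return 0;
}

In this particular rename both fields happen to be bool, so a mismatch would stay byte-aligned by accident; with differently sized fields the stream would desynchronize outright, which is why the save/load pair is edited together.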
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -61,7 +61,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
     std::shared_ptr<ov::op::v0::Constant> alibi_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(alibi_idx));
     OPENVINO_ASSERT(alibi_const != nullptr);
     prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;
-    prim.has_rotation_coefficients = op->get_input_size() == 16;
+    prim.has_rotated_blocks = op->get_input_size() == 16;
 
     prim.num_outputs = 1;
 
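
The plugin-side hunk above derives the new flag purely from operation arity: a PagedAttentionExtension node with 16 inputs carries the extra rotation inputs, one with fewer does not. A small sketch of that detection pattern follows; Node and Primitive are hypothetical types, and only the get_input_size() == 16 check comes from the diff.

// Standalone sketch (hypothetical types, not the OpenVINO graph API):
// enabling an optional feature by counting the operation's inputs.
#include <cassert>
#include <cstddef>

struct Node {
    std::size_t num_inputs;  // stand-in for the op's input count
    std::size_t get_input_size() const { return num_inputs; }
};

struct Primitive {
    bool has_rotated_blocks = false;
};

// Mirrors the logic in CreatePagedAttentionExtensionOp: the flag is a
// function of arity only, so no separate attribute is needed on the op.
void configure(Primitive& prim, const Node& op) {
    prim.has_rotated_blocks = op.get_input_size() == 16;
}

int main() {
    Primitive p;
    configure(p, Node{16});
    assert(p.has_rotated_blocks);   // 16 inputs -> rotation inputs present
    configure(p, Node{13});
    assert(!p.has_rotated_blocks);  // fewer inputs -> no rotation
    return 0;
}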
2 changes: 1 addition & 1 deletion
@@ -752,7 +752,7 @@ struct PagedAttentionTest : public ::testing::TestWithParam<T> {
         pa_prim.scale_val = pam.get_default_scale();
         pa_prim.has_alibi = false;
         pa_prim.num_outputs = p.scores_output ? 2 : 1;
-        pa_prim.has_rotation_coefficients = p.rotation_config.apply_rotation;
+        pa_prim.has_rotated_blocks = p.rotation_config.apply_rotation;
 
         topology topology;
 
