Skip to content

Commit

Permalink
[GPU] Fix default internal buffers sizes for SDPA
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Aug 28, 2024
1 parent d78c565 commit 790a001
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ static std::vector<size_t> get_internal_buffer_sizes(const sdpa_params& sdpa_par
return {blocks_indexes_buf_size};
} else {
if (sdpa_params.has_dynamic_tensors() || kernel_type == KernelsTypes::MULTI_TOKENS) {
return {1, 1};
const auto default_bytes_count = BytesPerElement(get_softmax_acc_type());
return {default_bytes_count, default_bytes_count};
} else {
TransposedDimensionAccessHelperBase dims_q(sdpa_params.inputs[0], sdpa_params.input0_order);
const auto& output = sdpa_params.outputs[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ const std::vector<std::vector<InputShape>> shapes{
{ov::Shape{1, 1, 7, 7}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 10, 10}}}
},
},
// normal case, shapes of q,k,v are same, static shapes
{
// q shape
{ov::test::InputShape{ov::PartialShape{1, 8, 100, 128},
{ov::Shape{1, 8, 100, 128}}}
},
// kv shape
{ov::test::InputShape{ov::PartialShape{1, 8, 100, 128},
{ov::Shape{1, 8, 100, 128}}}
},
// attn shape: [B, 1, -1, L0+L1]
{ov::test::InputShape{ov::PartialShape{1, 1, 100, 100},
{ov::Shape{1, 1, 100, 100}}}
},
},
};

const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
Expand Down

0 comments on commit 790a001

Please sign in to comment.