[GPU] Fix SDPA producing NaN after transpose_fusion pass #27629

Merged · 2 commits · Dec 2, 2024
@@ -107,7 +107,7 @@ JitConstants SDPAKernelBase::GetJitConstants(const sdpa_params& params) const {
     };
 
     auto use_index_calc_func = [&](const std::vector<int64_t> order, bool is_query = false) {
-        if (!params.input0_order.empty() && !is_default_order(params.input0_order))
+        if (!order.empty() && !is_default_order(order))
            return true;
 
        if (params.conf.broadcast_axis != -1)
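The one-line kernel change above is the core of the NaN fix: use_index_calc_func receives a per-input order, but the old code always inspected the captured params.input0_order (the query's order), so a transposed key or value input never switched the kernel to the order-aware indexing path and, per the PR title, the fused transpose then produced NaN. A minimal standalone sketch of this bug class — plain C++, not OpenVINO code, with is_default_order re-implemented here purely for illustration:

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustration-only re-implementation of the kernel's default-order check.
static bool is_default_order(const std::vector<int64_t>& order) {
    for (size_t i = 0; i < order.size(); i++)
        if (order[i] != static_cast<int64_t>(i))
            return false;
    return true;
}

int main() {
    std::vector<int64_t> input0_order{0, 1, 2, 3};  // query: default order
    std::vector<int64_t> input2_order{0, 2, 1, 3};  // value: transposed order

    // Pre-fix behavior: the parameter is ignored and the captured query
    // order is checked, no matter which input the lambda is asked about.
    auto use_index_calc_buggy = [&](const std::vector<int64_t>& order) {
        return !input0_order.empty() && !is_default_order(input0_order);
    };
    // Post-fix behavior: the order that was actually passed in is checked.
    auto use_index_calc_fixed = [&](const std::vector<int64_t>& order) {
        return !order.empty() && !is_default_order(order);
    };

    std::printf("buggy(value order) = %d\n", use_index_calc_buggy(input2_order));  // 0: transpose missed
    std::printf("fixed(value order) = %d\n", use_index_calc_fixed(input2_order));  // 1: index calc enabled
}

With the fix, the index-calculation path is enabled for whichever input is actually transposed, not just the query. The rest of the diff extends the functional test to cover exactly this case.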
@@ -25,7 +25,8 @@ typedef std::tuple<ov::element::Type,                // netPrecision
                    std::vector<InputShape>,                        // shape
                    bool,                                           // is_causal
                    bool,                                           // has_attn
-                   bool                                            // has_scale
+                   bool,                                           // has_scale
+                   std::vector<std::vector<int64_t>>               // input_transpose
                    > ScaledAttnGPUTestParams;
 
 class ScaledAttnLayerGPUTest : public testing::WithParamInterface<ScaledAttnGPUTestParams>,
@@ -36,6 +37,7 @@ class ScaledAttnLayerGPUTest : public testing::WithParamInterface<ScaledAttnGPUT
 protected:
     void SetUp() override;
     void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
+    void transpose_prepare(std::vector<InputShape>& shapes, const std::vector<std::vector<int64_t>>& input_transpose);
     bool is_causal;
     bool has_attn;
     bool has_scale;
@@ -44,11 +46,14 @@ class ScaledAttnLayerGPUTest : public testing::WithParamInterface<ScaledAttnGPUT
 std::string ScaledAttnLayerGPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnGPUTestParams>& obj) {
     ov::element::Type inType;
     std::vector<InputShape> inputShapes;
+    std::vector<std::vector<int64_t>> input_transpose;
     bool is_causal;
     bool has_attn;
     bool has_scale;
-    std::tie(inType, inputShapes, is_causal, has_attn, has_scale) = obj.param;
+    bool transpose_enable;
+    std::tie(inType, inputShapes, is_causal, has_attn, has_scale, input_transpose) = obj.param;
 
+    transpose_enable = (input_transpose.size() != 0);
     std::ostringstream result;
     result << "netPRC=" << inType << "_";
     result << "IS=";
@@ -65,24 +70,27 @@ std::string ScaledAttnLayerGPUTest::getTestCaseName(const testing::TestParamInfo
     result << "is_causal=" << is_causal << "_";
     result << "has_attn=" << has_attn << "_";
     result << "has_scale=" << has_scale << "_";
+    result << "with_transpose=" << transpose_enable << "_";
 
     return result.str();
 }
 
 void ScaledAttnLayerGPUTest::SetUp() {
     ov::element::Type inType;
     std::vector<InputShape> inputShapes;
+    std::vector<std::vector<int64_t>> input_transpose;
 
     targetDevice = ov::test::utils::DEVICE_GPU;
 
-    std::tie(inType, inputShapes, is_causal, has_attn, has_scale) = this->GetParam();
+    std::tie(inType, inputShapes, is_causal, has_attn, has_scale, input_transpose) = this->GetParam();
 
+    transpose_prepare(inputShapes, input_transpose);
     init_input_shapes(inputShapes);
     ov::ParameterVector inputParams;
     // q, k, v
     inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
     inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
-    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
+    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
     inputParams[0]->set_friendly_name("q");
     inputParams[1]->set_friendly_name("k");
     inputParams[2]->set_friendly_name("v");
@@ -96,7 +104,7 @@ void ScaledAttnLayerGPUTest::SetUp() {
         inputParams.back()->set_friendly_name("scale");
     } else {
         if (has_attn) {
-            inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
+            inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[3]));
             inputParams.back()->set_friendly_name("attention_mask");
         }
         if (has_scale) {
@@ -106,9 +114,31 @@ void ScaledAttnLayerGPUTest::SetUp() {
         }
     }
 
-    ov::OutputVector inputs;
+    ov::OutputVector inputParams_transpose;
     for (size_t i = 0; i < inputParams.size(); i++) {
-        inputs.push_back(inputParams[i]);
+        inputParams_transpose.push_back(inputParams[i]);
     }
+    if (input_transpose.size() != 0) {
+        // Route q/k/v through explicit Transpose nodes so the GPU plugin's
+        // transpose-fusion pass can fold them into SDPA.
+        auto transpose_a_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, input_transpose[0]);
+        auto transpose_a = std::make_shared<ov::op::v1::Transpose>(inputParams[0], transpose_a_const);
+        transpose_a->set_friendly_name("transpose_a");
+        inputParams_transpose[0] = transpose_a;
+
+        auto transpose_b_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, input_transpose[1]);
+        auto transpose_b = std::make_shared<ov::op::v1::Transpose>(inputParams[1], transpose_b_const);
+        transpose_b->set_friendly_name("transpose_b");
+        inputParams_transpose[1] = transpose_b;
+
+        auto transpose_c_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, input_transpose[2]);
+        auto transpose_c = std::make_shared<ov::op::v1::Transpose>(inputParams[2], transpose_c_const);
+        transpose_c->set_friendly_name("transpose_c");
+        inputParams_transpose[2] = transpose_c;
+    }
+
+    ov::OutputVector inputs;
+    for (size_t i = 0; i < inputParams_transpose.size(); i++) {
+        inputs.push_back(inputParams_transpose[i]);
+    }
 
     auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(inputs, is_causal);
@@ -141,17 +171,53 @@ void ScaledAttnLayerGPUTest::SetUp() {
     }
 }
 
+void ScaledAttnLayerGPUTest::transpose_prepare(std::vector<InputShape>& shapes,
+                                               const std::vector<std::vector<int64_t>>& input_transpose) {
+    auto transpose_pshape = [](InputShape& pshapes, const std::vector<int64_t>& order) {
+        auto transposed_pshape = ov::PartialShape::dynamic(pshapes.first.rank());
+        std::vector<ov::Shape> transposed_cshapes(pshapes.second);
+        auto& pshape = pshapes.first;
+        auto& cshape = pshapes.second;
+        for (size_t i = 0; i < order.size(); i++) {
+            transposed_pshape[i] = pshape[order[i]];
+            for (size_t j = 0; j < cshape.size(); j++) {
+                transposed_cshapes[j][i] = cshape[j][order[i]];
+            }
+        }
+
+        for (size_t i = 0; i < order.size(); i++) {
+            pshape[i] = transposed_pshape[i];
+            for (size_t j = 0; j < cshape.size(); j++) {
+                cshape[j][i] = transposed_cshapes[j][i];
+            }
+        }
+    };
+
+    if (shapes.empty()) {
+        return;
+    }
+
+    shapes.insert(shapes.begin() + 1, shapes[1]);
+    if (input_transpose.empty()) {
+        return;
+    }
+
+    for (size_t i = 0; i < input_transpose.size(); i++) {
+        transpose_pshape(shapes[i], input_transpose[i]);
+    }
+}
+
 void ScaledAttnLayerGPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
     std::vector<ov::Shape> shapes(3);
     shapes[0] = targetInputStaticShapes[0];
     shapes[1] = targetInputStaticShapes[1];
-    shapes[2] = targetInputStaticShapes[1];
+    shapes[2] = targetInputStaticShapes[2];
     if (!has_attn && has_scale) {
-        shapes.push_back(ov::Shape{});
+        shapes.push_back(ov::Shape{1});
     } else {
         if (has_attn) {
-            shapes.push_back(targetInputStaticShapes[2]);
+            shapes.push_back(targetInputStaticShapes[3]);
         }
         if (has_scale) {
             shapes.push_back(ov::Shape{1});
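A note on the new helper above: transpose_prepare first duplicates shapes[1] via shapes.insert(shapes.begin() + 1, shapes[1]) so that q, k, and v each own an InputShape entry (v starts from k's shape), which is why SetUp and generate_inputs now index inputDynamicShapes[2] for v and [3] for the attention mask. The transpose_pshape lambda then permutes both the dynamic PartialShape and every static target shape in place with the rule out[i] = in[order[i]]. A standalone sketch of that permutation, using a made-up static shape:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Made-up [batch, heads, seq_len, head_size] shape, permuted with the
    // {0, 2, 1, 3} order this suite applies to the value input.
    std::vector<int64_t> shape{2, 8, 16, 64};
    std::vector<int64_t> order{0, 2, 1, 3};

    std::vector<int64_t> transposed(shape.size());
    for (size_t i = 0; i < order.size(); i++)
        transposed[i] = shape[order[i]];  // same rule as transpose_pshape

    for (size_t i = 0; i < transposed.size(); i++)
        std::printf("%lld ", static_cast<long long>(transposed[i]));  // prints: 2 16 8 64
    std::printf("\n");
}

The permuted shapes describe the graph's Parameter inputs, and the Transpose nodes added in SetUp map them back to the layout SDPA expects.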
@@ -163,10 +229,11 @@ void ScaledAttnLayerGPUTest::generate_inputs(const std::vector<ov::Shape>& targe
 TEST_P(ScaledAttnLayerGPUTest, CompareWithRefs) {
     ov::element::Type inType;
     std::vector<InputShape> inputShapes;
+    std::vector<std::vector<int64_t>> input_transpose;
     bool is_causal;
     bool has_attn;
     bool has_scale;
-    std::tie(inType, inputShapes, is_causal, has_attn, has_scale) = this->GetParam();
+    std::tie(inType, inputShapes, is_causal, has_attn, has_scale, input_transpose) = this->GetParam();
     run();
 }

@@ -261,11 +328,15 @@ const std::vector<std::vector<InputShape>> shapes{
     },
 };
 
+const std::vector<std::vector<int64_t>> disable_transpose{};
+const std::vector<std::vector<int64_t>> enable_transpose{{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 2, 1, 3}};
+
 const auto params = testing::Combine(testing::Values(ov::element::f16 /*, ov::element::f32 */),
                                      testing::ValuesIn(shapes),
                                      testing::Values(true, false),
                                      testing::Values(true, false),
-                                     testing::Values(true, false));
+                                     testing::Values(true, false),
+                                     testing::ValuesIn({disable_transpose, enable_transpose}));
 
 INSTANTIATE_TEST_SUITE_P(smoke_ScaledAttn_GPU,
                          ScaledAttnLayerGPUTest,
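A closing note on the parameterization: disable_transpose leaves the original graph untouched, while enable_transpose keeps identity orders for q and k and pre-permutes only v with {0, 2, 1, 3}, so the test graph carries an explicit Transpose back to the expected layout. On GPU that Transpose is folded into SDPA by the transpose-fusion pass named in the PR title, which, per that title, is the configuration that previously produced NaN; the affected cases are identifiable in test names by the with_transpose=1 suffix.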