Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
[MLIR] Use enum class for broadcast hint. (#4238)
Browse files Browse the repository at this point in the history
Co-authored-by: Scott Cyphers <[email protected]>
  • Loading branch information
ayzhuang and diyessi authored Jan 29, 2020
1 parent d3a8e62 commit 2651f73
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
18 changes: 9 additions & 9 deletions src/contrib/mlir/backend/pass/affine_lowerer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,39 +1324,39 @@ namespace
}
attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;

int broadcastHint = -2;
BroadcastType broadcastHint = BroadcastType::ERROR;
if (vBias.rank() == 0)
{
// Scalar
broadcastHint = 2;
broadcastHint = BroadcastType::ROWCOLUMN;
}
else if (vBias.rank() == 2)
{
if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == 1)
{
broadcastHint = 1;
broadcastHint = BroadcastType::COLUMN;
}
else if (biasShape[0] == 1 && biasShape[1] == attrs.gemmAttrs2d.n)
{
broadcastHint = 0;
broadcastHint = BroadcastType::ROW;
}
else
else if (biasShape[0] == attrs.gemmAttrs2d.m && biasShape[1] == attrs.gemmAttrs2d.n)
{
broadcastHint = -1;
broadcastHint = BroadcastType::NONE;
}
}
else
{
if (biasShape[0] == attrs.gemmAttrs2d.m)
{
broadcastHint = 1;
broadcastHint = BroadcastType::COLUMN;
}
else if (biasShape[0] == attrs.gemmAttrs2d.n)
{
broadcastHint = 0;
broadcastHint = BroadcastType::ROW;
}
}
NGRAPH_CHECK(broadcastHint != -2, "Unhandled broadcast");
NGRAPH_CHECK(broadcastHint != BroadcastType::ERROR, "Unhandled broadcast");
attrs.gemmAttrs2d.broadcastHint = broadcastHint;

auto int64Ty = rewriter.getIntegerType(64);
Expand Down
11 changes: 10 additions & 1 deletion src/contrib/mlir/runtime/cpu/callback_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ namespace ngraph
SOFTMAX
};

enum class BroadcastType
{
NONE,
ROW,
COLUMN,
ROWCOLUMN,
ERROR
};

// These structs and union are used to pass attributes to callbacks.
template <int N>
struct poolAttrs
Expand All @@ -107,7 +116,7 @@ namespace ngraph
int64_t ldc;
float alpha;
float beta;
int64_t broadcastHint;
BroadcastType broadcastHint;
};

union opAttrs {
Expand Down
12 changes: 8 additions & 4 deletions src/contrib/mlir/runtime/cpu/cpu_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));

if (broadcastHint == 0)
if (broadcastHint == BroadcastType::ROW)
{
std::vector<float> ones(m, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor,
Expand All @@ -654,7 +654,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else if (broadcastHint == 1)
else if (broadcastHint == BroadcastType::COLUMN)
{
std::vector<float> ones(n, 1.0f);
cblas::cblas_sgemm(cblas::Layout::RowMajor,
Expand All @@ -672,7 +672,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else if (broadcastHint == 2)
else if (broadcastHint == BroadcastType::ROWCOLUMN)
{
std::vector<float> ones(m, 1.0f);
std::vector<float> bias(n, *matC);
Expand All @@ -691,7 +691,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else
else if (broadcastHint == BroadcastType::NONE)
{
std::vector<float> identity(n * n, 0.0f);
for (auto i = 0; i < n * n; i += n + 1)
Expand All @@ -713,6 +713,10 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
matOut,
std::max<size_t>(1, ldc));
}
else
{
NGRAPH_UNREACHABLE("Unsupported broadcast");
}
}

extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type)
Expand Down

0 comments on commit 2651f73

Please sign in to comment.