Skip to content

Commit

Permalink
support broadcast in dim which shape[dim]==1 for stablehlo
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Apr 14, 2024
1 parent 45eaeaa commit 6fd3f53
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
58 changes: 58 additions & 0 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
llvm::seq<int64_t>(leadingRank, minRank + leadingRank));
auto lhsShape = lhsRankTy.getShape();
auto rhsShape = rhsRankTy.getShape();

if (lhsRank < rhsRank) {
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
Expand All @@ -169,6 +170,63 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
broadcastDims);
}

lhsShape = lhs.getType().cast<RankedTensorType>().getShape();
rhsShape = rhs.getType().cast<RankedTensorType>().getShape();
assert(lhsShape.size() == rhsShape.size());

int64_t resultRank = lhsShape.size();

// check shape compatibility, check if we should broadcast

// first, we should got a new shape. Check from (0, shape - 2)
// TODO: add dot 1dx1d, 1dx2d, 2dx1d
SmallVector<int64_t> lhsBroadcastDims;
SmallVector<int64_t> rhsBroadcastDims;
SmallVector<int64_t> newBatchShape;
for (size_t i = 0; i < resultRank - 2; i++) {
if (lhsShape[i] != rhsShape[i]) {
if (lhsShape[i] == 1) {
lhsBroadcastDims.push_back(i);
newBatchShape.push_back(rhsShape[i]);
} else if (rhsShape[i] == 1) {
rhsBroadcastDims.push_back(i);
newBatchShape.push_back(lhsShape[i]);
} else {
assert(false && "shape mismatch");
}
} else {
newBatchShape.push_back(lhsShape[i]);
}
}

auto lhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
auto rhsDimSizes =
*hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);

broadcastDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, resultRank));

if (!lhsBroadcastDims.empty()) {
SmallVector<int64_t> lhsNewShape(newBatchShape);
lhsNewShape.insert(lhsNewShape.end(),
lhsShape.begin() + lhsShape.size() - 2, lhsShape.end());
for (auto i : lhsBroadcastDims) {
lhsDimSizes[i] = rhsDimSizes[i];
}
lhs = getBroadcastTensor(rewriter, op, lhs, lhsNewShape, lhsDimSizes,
broadcastDims);
}
if (!rhsBroadcastDims.empty()) {
SmallVector<int64_t> rhsNewShape(newBatchShape);
rhsNewShape.insert(rhsNewShape.end(),
rhsShape.begin() + rhsShape.size() - 2, rhsShape.end());
for (auto i : rhsBroadcastDims) {
rhsDimSizes[i] = lhsDimSizes[i];
}
rhs = getBroadcastTensor(rewriter, op, rhs, rhsNewShape, rhsDimSizes,
broadcastDims);
}

inpLhs = lhs;
inpRhs = rhs;
}
Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,7 @@
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool2dStaticModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
Expand Down

0 comments on commit 6fd3f53

Please sign in to comment.