Skip to content

Commit

Permalink
[compiler] refactor anonymous patterns to named patterns (#81)
Browse files Browse the repository at this point in the history
* [compiler] refactor anonymous patterns to named patterns

* update
  • Loading branch information
qingyunqu authored Nov 15, 2023
1 parent 16da0f9 commit 06b18bf
Show file tree
Hide file tree
Showing 3 changed files with 980 additions and 964 deletions.
111 changes: 1 addition & 110 deletions compiler/include/byteir/Dialect/mhlo/Transforms/CanonicalizeExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,122 +24,13 @@
namespace mlir {
class MLIRContext;

namespace tensor {
class InsertSliceOp;
}

namespace mhlo {
class AddOp;
class ClampOp;
class ConvertOp;
class CompareOp;
class CustomCallOp;
class TransposeOp;
class BroadcastInDimOp;
class ConcatenateOp;
class DynamicBroadcastInDimOp;
class DynamicConvOp;
class DynamicGatherOp;
class ReduceWindowOp;
class ReshapeOp;
class MulOp;
class SliceOp;
class ReverseOp;
class GatherOp;

// Most of these will push back to upstream
// So this file only includes patterns, not a pass.

///
/// foldBroadcastInDimConstWithBinary
///
/// BroadcastInDim could be folded in some special cases. Ex.
///
/// const
/// \
/// broadcast_in_dim const
/// \ /
/// mul
LogicalResult foldBroadcastInDimConstWithBinary(mhlo::BroadcastInDimOp op,
PatternRewriter &rewriter);

///
/// Fold concatenate of continuous slices
///
LogicalResult foldConcatWithContinuousSlices(mhlo::ConcatenateOp op,
PatternRewriter &rewriter);

// fold multi op with zero
LogicalResult foldMultiplyZero(mhlo::MulOp op, PatternRewriter &rewriter);

// fold binary op with large constant op
template <typename Op, template <typename> typename Func>
LogicalResult foldLargeBinaryOp(Op op, PatternRewriter &rewriter);

// mhlo.dynamic_conv => mhlo.convolution canonicalization
LogicalResult simplifyDynamicConvToConv(mhlo::DynamicConvOp op,
PatternRewriter &rewriter);

// constant folding for mhlo.concatenate with large result
LogicalResult foldLargeConcatenate(mhlo::ConcatenateOp op,
PatternRewriter &rewriter);

LogicalResult foldTransposeNonSplat(mhlo::TransposeOp op,
PatternRewriter &rewriter);

LogicalResult foldBeneficialConstantConvertOp(mhlo::ConvertOp op,
PatternRewriter &rewriter);

LogicalResult foldLargeCompareOp(mhlo::CompareOp op, PatternRewriter &rewriter);

// const + broadcast_in_dim => const + broadcast_in_dim
LogicalResult canonicalizeBroadcastInDimConst(mhlo::BroadcastInDimOp op,
PatternRewriter &rewriter);

// simplify an addOp of two chain of insert_slice's
// into a chain of insert_slice's
// when those insert_slice's are
// 1) not overlaped
// 2) along a single axis
// 3) sharing a zero Dest
LogicalResult simplifyAddInsertSlicesToInsertSlices(mhlo::AddOp op,
PatternRewriter &rewriter);

// simplify a chain of insert_slice's into a concat
// when those insert_slice's are
// 1) not overlaped
// 2) along a single axis
// 3) covering the entire Dest
LogicalResult simplifyFullInsertSlicesToConcat(mlir::tensor::InsertSliceOp op,
PatternRewriter &rewriter);

// simplify byteir.addn => mhlo.add
LogicalResult simplifyByteIRAddNToAdd(mhlo::CustomCallOp op,
PatternRewriter &rewriter);

LogicalResult foldLargeSliceOp(mhlo::SliceOp op, PatternRewriter &rewriter);

LogicalResult foldConcatWithSlicesAndRehape(mhlo::ConcatenateOp op,
PatternRewriter &rewriter);

// concat(broadcast_in_dim(x), broadcast_in_dim(x)) => broadcast_in_dim
LogicalResult canonicalizeConcatWithBroadcast(mhlo::ConcatenateOp op,
PatternRewriter &rewriter);

LogicalResult eliminateRedundantConvertFromI1(mhlo::ConvertOp op,
PatternRewriter &rewriter);

LogicalResult simplifyCumsumToIota(mhlo::ReduceWindowOp op,
PatternRewriter &rewriter);

// transpose(reshape(transpose(x))) => reshape(x)
LogicalResult simplifyTransposeReshapeTranspose(mhlo::TransposeOp op,
PatternRewriter &rewriter);

LogicalResult foldReverseWithConstant(mhlo::ReverseOp op,
PatternRewriter &rewriter);

LogicalResult foldGatherWithInput(mhlo::GatherOp op, PatternRewriter &rewriter);
void populateFoldMultiplyZeroPattern(RewritePatternSet &patterns);

// populate canonicalizeExt patterns
void populateCanonicalizeExtPatterns(RewritePatternSet &patterns,
Expand Down
Loading

0 comments on commit 06b18bf

Please sign in to comment.