[Sync] Support CUDA multi-stream work queue and other updates (#86)
yaochengji authored Dec 12, 2023
1 parent 56b8397 commit f99e390
Showing 22 changed files with 343 additions and 541 deletions.
50 changes: 50 additions & 0 deletions compiler/lib/Dialect/mhlo/Transforms/CanonicalizeExt.cpp
@@ -1139,6 +1139,55 @@ struct FoldLargeBinaryOp : OpRewritePattern<Op> {
 }
 };

+struct FoldClampOp : public OpRewritePattern<mhlo::ClampOp> {
+  using OpRewritePattern<mhlo::ClampOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(mhlo::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    mhlo::ConstantOp constOp =
+        op.getOperand().getDefiningOp<mhlo::ConstantOp>();
+    mhlo::ConstantOp minOp = op.getMin().getDefiningOp<mhlo::ConstantOp>();
+    mhlo::ConstantOp maxOp = op.getMax().getDefiningOp<mhlo::ConstantOp>();
+    if (!constOp || !minOp || !maxOp) {
+      return failure();
+    }
+
+    RankedTensorType operandType =
+        op.getOperand().getType().cast<RankedTensorType>();
+    ElementsAttr minValue = minOp.getValue();
+    ElementsAttr maxValue = maxOp.getValue();
+    if (minValue.getShapedType().getRank() == 0) {
+      minValue = DenseElementsAttr::get(operandType,
+                                        minValue.getValues<Attribute>()[0]);
+    }
+    if (maxValue.getShapedType().getRank() == 0) {
+      maxValue = DenseElementsAttr::get(operandType,
+                                        maxValue.getValues<Attribute>()[0]);
+    }
+
+    Attribute result;
+    if (operandType.getElementType().isa<FloatType>()) {
+      result = BinaryFolder<mhlo::ClampOp, FloatType, APFloat, Max<APFloat>>(
+          &op, ArrayRef<Attribute>{minValue, constOp.getValue()});
+      result = BinaryFolder<mhlo::ClampOp, FloatType, APFloat, Min<APFloat>>(
+          &op, ArrayRef<Attribute>{maxValue, result});
+
+    } else if (operandType.getElementType().isa<IntegerType>()) {
+      result = BinaryFolder<mhlo::ClampOp, IntegerType, APInt, Max<APSInt>>(
+          &op, ArrayRef<Attribute>{minValue, constOp.getValue()});
+      result = BinaryFolder<mhlo::ClampOp, IntegerType, APInt, Min<APSInt>>(
+          &op, ArrayRef<Attribute>{maxValue, result});
+    }
+    if (!result) {
+      return failure();
+    }
+
+    mhlo::ConstantOp newConstOp =
+        rewriter.create<mhlo::ConstantOp>(op->getLoc(), result);
+    rewriter.replaceOp(op, newConstOp.getOutput());
+    return success();
+  }
+};
+
 struct FoldLargeCompareOp : public OpRewritePattern<mhlo::CompareOp> {
   using OpRewritePattern<mhlo::CompareOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(mhlo::CompareOp op,
@@ -1921,6 +1970,7 @@ void mlir::mhlo::populateFoldLargeBinaryOpPatterns(
   patterns.add<FoldLargeBinaryOp<mhlo::MinOp, Min>>(ctx);
   patterns.add<FoldLargeBinaryOp<mhlo::PowOp, Pow>>(ctx);
   patterns.add<FoldLargeCompareOp>(ctx);
+  patterns.add<FoldClampOp>(ctx);
 }

 void mlir::mhlo::populateFoldBeneficialConstantConvertOpPattern(
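The new FoldClampOp pattern only fires when the clamp operand and both bounds are constants: rank-0 bounds are first splatted to the operand type, and the fold is computed as a Max against the lower bound followed by a Min against the upper bound. A minimal numpy sketch of those semantics (the fold_clamp helper is hypothetical, not code from this commit):

    # Hypothetical sketch of what FoldClampOp computes, for illustration only.
    import numpy as np

    def fold_clamp(x: np.ndarray, lo: np.ndarray, hi: np.ndarray) -> np.ndarray:
        # Rank-0 bounds are splatted to the operand shape, mirroring the
        # DenseElementsAttr::get(operandType, ...) broadcast in the pattern.
        if lo.ndim == 0:
            lo = np.broadcast_to(lo, x.shape)
        if hi.ndim == 0:
            hi = np.broadcast_to(hi, x.shape)
        # Two binary folds, as in the pattern: Max with the lower bound,
        # then Min with the upper bound.
        return np.minimum(np.maximum(x, lo), hi)

    print(fold_clamp(np.array([-2.0, 0.5, 3.0]), np.float32(0.0), np.float32(1.0)))
    # -> [0.  0.5 1. ]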
32 changes: 21 additions & 11 deletions compiler/numerical/hlo/canonicalize_ext.mlir

Large diffs are not rendered by default.

16 changes: 5 additions & 11 deletions compiler/python/byteir/dialects/cat/ir_processor.py
@@ -35,10 +35,7 @@ def __init__(self,
         self.module = None
         self.ait_reuse_recorder = {} # key: hash str, value: Tuple(dll_name, ait_module_path)
         self.compile_parallelism = min(compile_parallelism, MAX_COMPILATION_PARALLELISM)
-        if self.compile_parallelism > 1:
-            self.pool = multiprocessing.Pool(compile_parallelism)
-        else:
-            self.pool = None
+        self.pool = multiprocessing.Pool(compile_parallelism)
         self.byteir_cache = AITCache()
         self.verbose = verbose
         self.disable_byteir_ait_cache = disable_byteir_ait_cache
@@ -162,14 +159,11 @@ def ait_opt_pass(self, anchor_only=False, dump_ir=False):
         print("compile ait module using {} processes".format(min(len(work_items_not_in_cache), self.compile_parallelism)))
         t_st = time.time()
         for func_ir_str in work_items_not_in_cache:
-            if self.pool:
-                self.pool.apply_async(_parallel_ait_compile, (self.workdir, func_ir_str))
-            else:
-                _parallel_ait_compile(self.workdir, func_ir_str)
+            self.pool.apply_async(_parallel_ait_compile, (self.workdir, func_ir_str))
+            # _parallel_ait_compile(self.workdir, func_ir_str)

-        if self.pool:
-            self.pool.close()
-            self.pool.join()
+        self.pool.close()
+        self.pool.join()
         t_ed = time.time()
         print("compilation finished in {}s".format(t_ed-t_st))

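This diff removes the serial fallback: a Pool is now created even when compile_parallelism is 1, so every compile goes through apply_async and the close()/join() pair. A reduced, runnable sketch of that shape (the worker body and item names are placeholders, not the real AIT compile):

    # Reduced sketch of the simplified flow above; _parallel_ait_compile here
    # is a stand-in for the real worker in ir_processor.py.
    import multiprocessing

    def _parallel_ait_compile(workdir, func_ir_str):
        print("compiling", func_ir_str, "into", workdir)

    if __name__ == "__main__":
        pool = multiprocessing.Pool(1)  # a 1-process pool replaces the old serial branch
        for func_ir_str in ["func_a", "func_b"]:
            pool.apply_async(_parallel_ait_compile, ("./workdir", func_ir_str))
        pool.close()  # stop accepting new work
        pool.join()   # block until all queued compiles finish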
2 changes: 1 addition & 1 deletion compiler/python/version.txt
@@ -1 +1 @@
-1.7.0
+1.7.1.0
21 changes: 21 additions & 0 deletions compiler/test/Dialect/Transform/detensorizeInsertion.mlir
@@ -0,0 +1,21 @@
+// RUN: byteir-opt %s --insert-detensorize-transform="func-anchor=test_func_anchor match-prefix=test_prefix" --transform-dialect-interpreter --canonicalize --cse --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @elementwise
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>, %[[ARG1:.+]]: tensor<f32>)
+func.func @elementwise(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {test_func_anchor, other} {
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<f32>)
+                             outs(%arg1: tensor<f32>) -> tensor<f32>
+  // CHECK: %[[EXTRACT:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
+  // CHECK: %[[UNARY:.+]] = math.exp %[[EXTRACT]] : f32
+  %1 = linalg.elemwise_binary ins(%0, %arg0 : tensor<f32>, tensor<f32>)
+                              outs(%arg1: tensor<f32>) -> tensor<f32>
+  // CHECK: %[[BINARY:.+]] = arith.addf %[[UNARY]], %[[EXTRACT]] : f32
+  // CHECK: %[[RET:.+]] = tensor.from_elements %[[BINARY]] : tensor<f32>
+  return %1 : tensor<f32>
+  // CHECK: return %[[RET]] : tensor<f32>
+}
+
+// CHECK: transform.sequence failures(propagate) {
+// CHECK: ^bb0(%arg0: !pdl.operation):
+// CHECK: %0 = transform.structured.match attributes {test_prefix_0}
+// CHECK: transform.structured.detensorize %0
@@ -69,11 +69,3 @@ index 38198a91..fd4a40df 100644
   if (decompose)
     markDecomposedOpsAsIllegal(context, target, backendLegalOpsSet);
   return target;
-@@ -386,6 +432,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
-   target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
-   target.addIllegalOp<AtenTanhBackwardOp>();
-   target.addIllegalOp<AtenAddmmOp>();
-+  target.addIllegalOp<AtenEinsumOp>();
-   target.addIllegalOp<AtenMeanOp>();
-   target.addIllegalOp<AtenMeanDimOp>();
-   target.addIllegalOp<AtenNormScalarOptDimOp>();
