Update MLIR converter code to TF 2.9
lgeiger committed Apr 12, 2022
1 parent 8595502 commit 9e94da6
Showing 35 changed files with 279 additions and 247 deletions.
16 changes: 8 additions & 8 deletions larq_compute_engine/mlir/BUILD
@@ -32,7 +32,7 @@ td_library(
srcs = ["transforms/op_removal_patterns.td"],
includes = ["/external/org_tensorflow"],
deps = [
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
)
@@ -43,7 +43,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
@@ -54,7 +54,7 @@ td_library(
includes = ["/external/org_tensorflow"],
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
@@ -182,7 +182,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_activations_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
],
)
@@ -199,7 +199,7 @@ gentbl_cc_library(
td_file = "transforms/bitpack_weights_patterns.td",
deps = [
":lce_ops_td_file",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
@@ -288,7 +288,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/tensorflow",
],
alwayslink = 1,
@@ -308,7 +308,7 @@ cc_library(
deps = [
":larq_compute_engine",
"//larq_compute_engine/core:types",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite",
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
"@org_tensorflow//tensorflow/compiler/mlir/lite:validators",
@@ -429,7 +429,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:FuncDialect",
],
alwayslink = 1,
)
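Note: the dependency renames above track the upstream LLVM split of the former Standard dialect. FuncOp and the other function-level ops now live in the dedicated func dialect, so the TableGen dependency becomes FuncTdFiles and the C++ dependency becomes FuncDialect; the corresponding source-level renames appear in the files below.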
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/ir/lce_ops.td
@@ -56,7 +56,7 @@ def LarqDialect : Dialect {
//===----------------------------------------------------------------------===//

// Base class for the operation in this dialect
class LQ_Op<string mnemonic, list<OpTrait> traits = []> :
class LQ_Op<string mnemonic, list<Trait> traits = []> :
Op<LarqDialect, mnemonic, traits> {

let extraClassDeclaration = [{
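This tracks the upstream ODS rename of the TableGen OpTrait class to Trait; the trait list itself is unchanged.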
11 changes: 5 additions & 6 deletions larq_compute_engine/mlir/lce_mlir_opt.cc
@@ -1,18 +1,17 @@
#include "larq_compute_engine/mlir/ir/lce_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Support/MlirOptMain.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

int main(int argc, char** argv) {
mlir::registerTransformsPasses();
mlir::DialectRegistry registry;
registry.insert<mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect,
registry.insert<mlir::arith::ArithmeticDialect, mlir::func::FuncDialect,
mlir::quant::QuantizationDialect, mlir::TF::TensorFlowDialect,
mlir::TFL::TensorFlowLiteDialect, mlir::lq::LarqDialect>();
- return failed(mlir::MlirOptMain(argc, argv,
-                                 "Larq Compute Engine pass driver\n", registry,
-                                 /*preloadDialectsInContext=*/false));
+ return failed(mlir::MlirOptMain(
+     argc, argv, "Larq Compute Engine pass driver\n", registry));
}
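Two upstream moves are reflected here: the MlirOptMain header relocated from mlir/Support/ to mlir/Tools/mlir-opt/, and the registry-taking overload dropped the explicit preloadDialectsInContext argument (it is a defaulted parameter upstream). As a minimal sketch, assuming the TF 2.9 LLVM pin, a standalone driver looks like this, with the dialect list trimmed to just func for illustration:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"  // was mlir/Support/MlirOptMain.h

int main(int argc, char** argv) {
  // Register only the dialects the passes under test need.
  mlir::DialectRegistry registry;
  registry.insert<mlir::func::FuncDialect>();
  // Overload without the trailing preloadDialectsInContext flag.
  return failed(mlir::MlirOptMain(argc, argv, "Example pass driver\n", registry));
}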
11 changes: 6 additions & 5 deletions larq_compute_engine/mlir/python/common.cc
@@ -36,10 +36,11 @@ LCETarget GetLCETarget(const std::string& target_str) {
}
}

- Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
+ Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module,
+                     int* num_inputs) {
*num_inputs = 0;
- mlir::FuncOp entry_function = nullptr;
- for (auto func : module->get().getOps<mlir::FuncOp>()) {
+ mlir::func::FuncOp entry_function = nullptr;
+ for (auto func : module->get().getOps<mlir::func::FuncOp>()) {
if (auto tf_attrs =
func->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function")) {
// TODO(jaesung): There could be multiple entry functions. Let's handle
@@ -70,13 +71,13 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) {
}

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
- mlir::OwningModuleRef* module, mlir::MLIRContext& context,
+ mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
llvm::Optional<tensorflow::Session*> session, const int num_inputs,
const bool should_quantize, const bool mark_as_post_training_quant) {
- mlir::TFL::QuantizationSpecs quant_specs;
+ mlir::quant::QuantizationSpecs quant_specs;
if (should_quantize) {
// Normally we'd only set `inference_type` to QINT8 when there are
// fake_quant nodes in the graph. However this did not work reliably, and
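Beyond the FuncOp move, two more TF 2.9 renames appear in this file: the OwningModuleRef alias was removed upstream in favor of the explicit OwningOpRef<mlir::ModuleOp>, and QuantizationSpecs moved from the mlir::TFL to the mlir::quant namespace. A minimal sketch with the new spellings, assuming the TF 2.9 LLVM pin (the countFuncs helper is hypothetical, not part of this repo):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OwningOpRef.h"

// Count the functions in a module using the renamed types; illustrative only.
int countFuncs(mlir::OwningOpRef<mlir::ModuleOp>& module) {
  int n = 0;
  for (auto func : module->getOps<mlir::func::FuncOp>()) {
    (void)func;  // only counting
    ++n;
  }
  return n;
}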
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/python/common.h
@@ -9,10 +9,10 @@ namespace tensorflow {

LCETarget GetLCETarget(const std::string& target_str);

- Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs);
+ Status GetNumInputs(mlir::OwningOpRef<mlir::ModuleOp>* module, int* num_inputs);

pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer(
- mlir::OwningModuleRef* module, mlir::MLIRContext& context,
+ mlir::OwningOpRef<mlir::ModuleOp>* module, mlir::MLIRContext& context,
const LCETarget target, const pybind11::object& default_ranges,
const std::unordered_set<std::string>& saved_model_tags,
llvm::StringRef saved_model_dir,
2 changes: 1 addition & 1 deletion larq_compute_engine/mlir/tests/bitpack-weights.mlir
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-lce-bitpack-weights -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @bitpack_bconv2d_filters
- func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
+ func.func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> {
%cst = arith.constant dense<1.0> : tensor<16x3x3x3xf32>
%0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
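This and the remaining test-file updates below are the same mechanical rename: with FuncOp now in the func dialect, MLIR assembly spells a function as func.func rather than bare func.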
4 changes: 2 additions & 2 deletions larq_compute_engine/mlir/tests/const-fold.mlir
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @quantize
- func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
+ func.func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
%pos = arith.constant dense< 0.5> : tensor<1x1x2x32xf32>
%neg = arith.constant dense<-0.5> : tensor<1x1x2x32xf32>
%0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32>
@@ -14,7 +14,7 @@ func @quantize() -> (tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32>) {
}

// CHECK-LABEL: @dequantize
- func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
+ func.func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) {
%pos = arith.constant dense<0> : tensor<1x1x2x1xi32>
%neg = arith.constant dense<-1> : tensor<1x1x2x1xi32>
%0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32>
12 changes: 6 additions & 6 deletions larq_compute_engine/mlir/tests/fuse_padding.mlir
@@ -1,7 +1,7 @@
// RUN: lce-tf-opt %s -tfl-fuse-padding -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @fuse_pad_into_conv_valid
- func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
+ func.func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
@@ -14,7 +14,7 @@ func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x1
}

// CHECK-LABEL: @fuse_padv2_into_conv_valid
- func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
+ func.func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<0.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
@@ -28,7 +28,7 @@ func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64
}

// CHECK-LABEL: @fuse_pad_into_dwconv_valid
- func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
+ func.func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
@@ -41,7 +41,7 @@ func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x6
}

// CHECK-LABEL: @do_not_fuse_padv2_into_conv_wrong_pad_value
- func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
+ func.func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
@@ -54,7 +54,7 @@ func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>)
}

// CHECK-LABEL: @do_not_fuse_pad_into_conv_same
- func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
+ func.func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<f32>
%cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32>
@@ -67,7 +67,7 @@ func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x6
}

// CHECK-LABEL: @do_not_fuse_pad_into_dwconv_channelpad
- func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
+ func.func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> {
%cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [1, 3]]> : tensor<4x2xi32>
%cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32>
%cst2 = arith.constant dense<1.0> : tensor<16xf32>
8 changes: 4 additions & 4 deletions larq_compute_engine/mlir/tests/legalize-lce.mlir
@@ -2,7 +2,7 @@
// RUN: lce-tf-opt %s -tfl-legalize-lce -lce-translate-tfl -verify-diagnostics | FileCheck %s --check-prefix=TRANSLATE

// CHECK-LABEL: @legalize_bconv2d
- func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
+ func.func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: none) -> tensor<256x30x30x16xf32> {
%0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %arg3, %arg4) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>

@@ -14,7 +14,7 @@ func @legalize_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf3
}

// CHECK-LABEL: @legalize_bmax_pool2d
- func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
+ func.func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32> {
%0 = "lq.BMaxPool2d"(%arg0) {filter_height = 2 : i32, filter_width = 2 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x3xi32>) -> tensor<256x16x16x3xi32>
return %0 : tensor<256x16x16x3xi32>

@@ -26,7 +26,7 @@ func @legalize_bmax_pool2d(%arg0: tensor<256x32x32x3xi32>) -> tensor<256x16x16x3
}

// CHECK-LABEL: @legalize_quantize
- func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
+ func.func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32> {
%0 = "lq.Quantize"(%arg0) {} : (tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi32>
return %0 : tensor<256x32x32x2xi32>

@@ -38,7 +38,7 @@ func @legalize_quantize(%arg0: tensor<256x32x32x64xf32>) -> tensor<256x32x32x2xi
}

// CHECK-LABEL: @legalize_dequantize
- func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
+ func.func @legalize_dequantize(%arg0: tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32> {
%0 = "lq.Dequantize"(%arg0) {} : (tensor<256x32x32x2xi32>) -> tensor<256x32x32x64xf32>
return %0 : tensor<256x32x32x64xf32>

10 changes: 5 additions & 5 deletions larq_compute_engine/mlir/tests/op-removal.mlir
@@ -1,39 +1,39 @@
// RUN: lce-tf-opt %s -lce-op-removal-tf -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @snapshot
- func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
+ func.func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "tf.Snapshot"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
// Should be converted to Identity and then from Identity to value
// CHECK-NEXT: return %arg0 : tensor<3xi32>
}

// CHECK-LABEL: @stop_gradient
- func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
+ func.func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> {
%0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
// Should be converted to Identity and then from Identity to value
// CHECK-NEXT: return %arg0 : tensor<3xi32>
}

// CHECK-LABEL: @check_numerics
- func @check_numerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+ func.func @check_numerics(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tf.CheckNumerics"(%arg0) {message = ""}: (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// Should be converted to Identity and then from Identity to value
// CHECK-NEXT: return %arg0 : tensor<3xf32>
}

// CHECK-LABEL: @placeholder_with_default
- func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
+ func.func @placeholder_with_default(%arg0: tensor<3xf32>) -> tensor<3xf32> {
%0 = "tf.PlaceholderWithDefault"(%arg0): (tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
// Should be converted to Identity and then from Identity to value
// CHECK-NEXT: return %arg0 : tensor<3xf32>
}

// CHECK-LABEL: @identity
- func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) {
+ func.func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) {
%0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32>
%1:2 = "tf.IdentityN"(%arg1,%arg2) : (tensor<20xi32>, tensor<30xi32>) -> (tensor<20xi32>, tensor<30xi32>)
return %0, %1#0, %1#1: tensor<10xi32>, tensor<20xi32>, tensor<30xi32>
[The remaining changed files are not shown here.]
