diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index a1222c83775..cc07abe1c8d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -40,6 +40,8 @@ class TensorFlowLiteDialect : public Dialect { public: explicit TensorFlowLiteDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "tfl"; } + // Registered hook to materialize a constant operation from a given attribute // value with the desired resultant type. Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 5abb4612509..23708c0c1f1 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -41,9 +41,9 @@ func @squeezeAndReshape(%arg0: tensor<1x1x10xf32>, %arg1: tensor) -> i %4 = "some_op"(%1, %3) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 return %4 : i32 // CHECK-LABEL: squeezeAndReshape -// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32> // CHECK: "tfl.squeeze"(%arg0) {squeeze_dims = [0]} : (tensor<1x1x10xf32>) -> tensor<1x10xf32> // CHECK: %1 = "tfl.squeeze"(%arg1) {squeeze_dims = []} : (tensor) -> tensor<*xf32> +// CHECK: %cst = constant dense<[2, 5]> : tensor<2xi32> // CHECK: %2 = "tfl.reshape"(%0, %cst) : (tensor<1x10xf32>, tensor<2xi32>) -> tensor<2x5xf32> // CHECK: %3 = "some_op"(%1, %2) : (tensor<*xf32>, tensor<2x5xf32>) -> i32 // CHECK: return @@ -74,7 +74,7 @@ func @dynamicReshapeI64Fold(%arg0: tensor<*xf32>) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> // CHECK-LABEL: dynamicReshapeI64Fold -// CHECK-NEXT: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32> +// CHECK: %[[cst:.*]] = constant dense<[1, 2]> : tensor<2xi32> // CHECK-NEXT: %[[reshape:.*]] = "tfl.reshape"(%arg0, %[[cst]]) : (tensor<*xf32>, tensor<2xi32>) -> tensor<1x2xf32> // CHECK-NEXT: return %[[reshape]] : tensor<1x2xf32> } @@ -114,10 +114,10 @@ func @softplus(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { return %0 : tensor<8x16xf32> // CHECK-LABEL: softplus -// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor -// CHECK-NEXT: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> -// CHECK-NEXT: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> -// CHECK-NEXT: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %[[exp:.*]] = "tfl.exp"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> +// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor +// CHECK: %[[add:.*]] = "tfl.add"(%[[exp]], %[[cst]]) {fused_activation_function = "NONE"} : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> +// CHECK: %[[log:.*]] = "tfl.log"(%[[add]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> } func @fakeQuantArgsFalse(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { @@ -676,8 +676,9 @@ func @matrix_diag_v2_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> +// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV2"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -709,8 +710,9 @@ func @matrix_diag_v3_no_match(%arg0: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK-SAME: [[VAL_0:%.*]]: tensor<8x16xf32>) -> tensor<8x16x16xf32> { // CHECK: [[VAL_1:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_2:%.*]] = constant dense<-1> : tensor<1xi32> +// CHECK: [[VAL_5:%.*]] = constant dense<-1> : tensor<1xi32> // CHECK: [[VAL_3:%.*]] = constant dense<0> : tensor<2xi32> -// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_2]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> +// CHECK: [[VAL_4:%.*]] = "tf.MatrixDiagV3"([[VAL_0]], [[VAL_1]], [[VAL_2]], [[VAL_5]], [[VAL_3]]) : (tensor<8x16xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<8x16x16xf32> // CHECK: return [[VAL_4]] : tensor<8x16x16xf32> } @@ -1273,7 +1275,8 @@ func @conv2d_backprop_input(%arg0: tensor<4xi32>, %arg1: tensor<3x3x1x32xf32>, % // CHECK: %[[CST:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> // CHECK: %[[ARG0:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[ARG1:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG0]], %arg2) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> - // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> + // CHECK: %[[CST_1:.*]] = constant dense<[2, 0, 1, 3]> : tensor<4xi32> + // CHECK: %[[ARG2:.*]] = "tfl.transpose"(%arg1, %[[CST_1]]) : (tensor<3x3x1x32xf32>, tensor<4xi32>) -> tensor<1x3x3x32xf32> // CHECK: %[[ARG3:.*]] = "tfl.transpose_conv"(%arg0, %[[ARG2]], %arg2) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<1x3x3x32xf32>, tensor<15x14x14x32xf32>) -> tensor<15x28x28x1xf32> // CHECK: %[[RESULT:.*]] = tfl.add %[[ARG1]], %[[ARG3]] {fused_activation_function = "NONE"} : tensor<15x28x28x1xf32> // CHECK: return %[[RESULT]] : tensor<15x28x28x1xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 56c92995122..8967a300b30 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // TF:llvm-project #include "mlir/Support/Functional.h" // TF:llvm-project #include "mlir/Support/LLVM.h" // TF:llvm-project +#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -465,8 +466,7 @@ PatternMatchResult ConvertTFMatrixDiagV3Op::matchAndRewrite( // TF Lite doesn't support Assert, we just drop the assert from the graph. PatternMatchResult ConvertTFAssertOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { - op->dropAllReferences(); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } @@ -653,7 +653,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { // Rewire the output. op->getResult(2).replaceAllUsesWith(lstm_op.getResult()); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } }; @@ -712,7 +712,7 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { // Rewire the output. op->getResult(1).replaceAllUsesWith(rnn_op.getResult()); - op->erase(); + rewriter.eraseOp(op); return matchSuccess(); } @@ -720,22 +720,44 @@ struct LegalizeUnidirectionalSequenceRnn : public RewritePattern { void LegalizeTF::runOnFunction() { OwningRewritePatternList patterns; - auto* ctx = &getContext(); + auto* context = &getContext(); auto func = getFunction(); // Add the generated patterns to the list. - populateWithGenerated(ctx, &patterns); + populateWithGenerated(context, &patterns); patterns.insert(ctx); + ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context); // Ophint python converter converted tf node pattern. patterns.insert(ctx); - applyPatternsGreedily(func, patterns); + LegalizeUnidirectionalSequenceRnn>(context); + + ConversionTarget target(*context); + // It is legal to have TF ops in the graph still which can be + // used later or in the case of SELECT were we allow TF ops in the final + // graph. + target.addLegalOp(); + target.addLegalOp(); + target.addDynamicallyLegalDialect( + Optional([](Operation* op) { + auto tfl_op = dyn_cast_or_null(op); + if (!tfl_op) return false; + return succeeded(tfl_op.VerifyTflRuntimeTypes(tfl_op.getOperation())); + })); + // Keep trying to convert. + // TODO(karimnosseir): This is similar to what apply greedy patterns does. + // Look if there is a function that tries until it converge. + // Currently unit-test doesn't do multiple tries, so we need this. + const int max_iterations = 15; + for (int i = 0; i < max_iterations; ++i) { + if (failed(applyPartialConversion(func, target, patterns))) { + return; + } + } } } // namespace