diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index d1c0dd20c05..2815afd14b9 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -958,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
 // Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32>
 // Fusing: return
 }
+
+func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2xf32>
+  %cst_1 = constant dense<2.0> : tensor<2x2xf32>
+  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: ReorderAddWithConstant
+  // CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32>
+  // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 82d9a76fab3..a3244f31053 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -457,3 +457,21 @@ def : Pat<(TFL_AddOp
 // The constant folding in this pass might produce constant in the tf dialect.
 // This rule is to legalize these constant to the tfl dialect.
 def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
+
+// Reorders adds to allow constant folding.
+// Add --> Add $input, $constantA
+//    \--> $constantB
+// To
+// Add --> $input
+//    \--> Add ($constantA, $constantB)
+foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
+  def : Pat<(TFL_AddOp
+             (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
+             (ConstantOp $b), ActFun),
+            (TFL_AddOp $input,
+             (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
+             ActFun),
+            [(HasOneUse $first_output)]>;
+}
+
+
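[Editor's note, not part of the patch] For readers less familiar with DRR: mlir-tblgen compiles the Pat above into a generated C++ RewritePattern. A rough hand-written equivalent is sketched below. The struct name is hypothetical and the TFL::AddOp accessor and builder signatures are approximations, so treat this as an illustration of the match/rewrite logic rather than the generated code.

    // Sketch: rewrite add(add(%input, %cst_a) {NONE}, %cst_b) {ActFun} into
    // add(%input, add(%cst_a, %cst_b) {NONE}) {ActFun}, so that the
    // constant-only inner add can be constant-folded away.
    struct ReorderAddWithConstantSketch : public OpRewritePattern<TFL::AddOp> {
      using OpRewritePattern<TFL::AddOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(TFL::AddOp op,
                                    PatternRewriter &rewriter) const override {
        // Mirror the HasOneUse and TFL_AF_None constraints of the DRR pattern:
        // the inner add must feed only this op and carry no fused activation.
        auto inner = llvm::dyn_cast_or_null<TFL::AddOp>(op.lhs().getDefiningOp());
        if (!inner || !inner.getResult().hasOneUse()) return failure();
        if (inner.fused_activation_function() != "NONE") return failure();
        DenseElementsAttr cst_a, cst_b;
        if (!matchPattern(inner.rhs(), m_Constant(&cst_a)) ||
            !matchPattern(op.rhs(), m_Constant(&cst_b)))
          return failure();
        // Add the two constants together; constant folding collapses this op.
        Value summed = rewriter.create<TFL::AddOp>(
            op.getLoc(), op.rhs().getType(), inner.rhs(), op.rhs(),
            rewriter.getStringAttr("NONE"));
        // Re-attach the outer activation function on the remaining add.
        rewriter.replaceOpWithNewOp<TFL::AddOp>(
            op, op.getType(), inner.lhs(), summed,
            op.fused_activation_functionAttr());
        return success();
      }
    };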
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index fbbfa8e95e5..1d973b8059a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
@@ -1963,6 +1965,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -4960,6 +4964,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape]>,
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let hasFolder = 1;
 }
 
 def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -8119,6 +8125,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> {
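[Editor's note, not part of the patch] `let hasFolder = 1` only declares the hook: for each of these four ops, mlir-tblgen adds a fold method to the generated op class (in tf_ops.h.inc), and the tf_ops.cc changes below supply the definitions. For a single-result op the emitted declaration is essentially:

    // Added to the generated op class (e.g. TF::AddV2Op) by mlir-tblgen.
    // `operands` holds one Attribute per operand: the operand's constant
    // value if it is known to be constant, or null otherwise.
    OpFoldResult fold(ArrayRef<Attribute> operands);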
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 1b915e3d5fc..0c706ac24a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -494,6 +494,48 @@ LogicalResult FoldOperandsPermutation(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Folders for removing trivial arithmetic ops.
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Utility methods that return the identity value to use for selected ops.
+
+APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
+APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
+APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
+APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
+
+// Folder that returns the LHS of an arithmetic op if the RHS is a constant
+// known to be the identity (e.g. x + 0).
+template <typename OP>
+OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
+  DenseFPElementsAttr rhs_value;
+  auto constant_val = arithmetic_op.y();
+  if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
+    return {};
+  }
+  auto result_op_type = arithmetic_op.getResult().getType();
+  auto lhs_type = arithmetic_op.x().getType();
+  if (!result_op_type.template isa<RankedTensorType>() ||
+      !lhs_type.template isa<RankedTensorType>() ||
+      !result_op_type.template cast<RankedTensorType>().hasStaticShape()) {
+    return {};
+  }
+  // We only handle the non-broadcastable case.
+  if (result_op_type != lhs_type) {
+    return {};
+  }
+  auto identity_val = GetIdentity(arithmetic_op);
+  for (auto it = rhs_value.float_value_begin();
+       it != rhs_value.float_value_end(); ++it) {
+    if (*it != identity_val) return {};
+  }
+
+  return arithmetic_op.x();
+}
+}  // namespace
+
 namespace {
 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
 }  // namespace
@@ -525,6 +567,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
 }
 
+OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AllOp
 //===----------------------------------------------------------------------===//
@@ -1271,6 +1317,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<DivWithSqrtDivisor>(context);
 }
 
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // DynamicStitchOp
 //===----------------------------------------------------------------------===//
@@ -1936,6 +1986,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // NegOp
 //===----------------------------------------------------------------------===//
@@ -2904,6 +2962,10 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<SubOfNeg>(context);
 }
 
+OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // SumOp
 //===----------------------------------------------------------------------===//
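[Editor's note, not part of the patch] These folders are never registered with a pass by hand; MLIR calls Operation::fold during canonicalization and constant folding, which dispatches to the fold overrides above. A minimal sketch of the effect through OpBuilder::createOrFold, with hypothetical caller code (`builder`, `loc`, `x`, and `zero` are assumed to exist):

    // If `zero` is an all-zeros constant whose static shape exactly matches
    // the type of `x`, TrivialArithmeticOpFolder returns `x` and no tf.AddV2
    // operation is materialized at all.
    Value result = builder.createOrFold<TF::AddV2Op>(loc, x, zero);
    // Otherwise `result` is a freshly created tf.AddV2.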
"tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialSub + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialMul + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: RemoveTrivialDiv + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32> +} + +func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + + // CHECK-LABEL: DontRemoveTrivialAdd + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: return %[[RESULT]] : tensor<2x2xf32> +} + +func @DontRemoveTrivialAdd2(%arg0: tensor, %arg1: tensor<2x2xf32>) -> tensor { + %cst = constant dense<0.0> : tensor<2x2xf32> + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor + %1 = "tf.AddV2"(%0, %cst) : (tensor , tensor<2x2xf32>) -> tensor + return %1 :tensor + + // CHECK-LABEL: DontRemoveTrivialAdd2 + // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32> + // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor, tensor<2x2xf32>) -> tensor + // CHECK: return %[[RESULT]] : tensor +}