From 412b12ea8d6e8ee50578a704194cb180bf2a317f Mon Sep 17 00:00:00 2001
From: Karim Nosir <karimnosseir@google.com>
Date: Fri, 1 May 2020 15:00:32 -0700
Subject: [PATCH] Add transformations for TF/TFLite: 1) Reorder 2 successive
 adds with constants to allow constant folding. 2) Remove trivial
 Add/Sub/Mul/Div.

PiperOrigin-RevId: 309480295
Change-Id: Ie84f5fd4da79dc80a3a5ee806051bd63b5b49888
---
 .../compiler/mlir/lite/tests/optimize.mlir    | 13 +++
 .../mlir/lite/transforms/optimize_patterns.td | 18 +++++
 .../mlir/tensorflow/ir/tf_generated_ops.td    |  8 ++
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 62 ++++++++++++++
 .../mlir/tensorflow/tests/constant-fold.mlir  | 81 +++++++++++++++++++
 5 files changed, 182 insertions(+)

diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index d1c0dd20c05..2815afd14b9 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -958,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
 // Fusing:  %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32>
 // Fusing:  return
 }
+
+func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2xf32>
+  %cst_1 = constant dense<2.0> : tensor<2x2xf32>
+  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: ReorderAddWithConstant
+  // CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32>
+  // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
+}
+
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
index 82d9a76fab3..a3244f31053 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td
@@ -457,3 +457,21 @@ def : Pat<(TFL_AddOp
 // The constant folding in this pass might produce constant in the tf dialect.
 // This rule is to legalize these constant to the tfl dialect.
 def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
+
+// Reorders adds to allow constant folding.
+// Add --> Add $input, $constantA
+//    \--> $constantB
+// To
+// Add --> $input
+//    \--> Add ($constantA, $constantB)
+foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
+  def : Pat<(TFL_AddOp
+              (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
+              (ConstantOp $b), ActFun),
+            (TFL_AddOp $input,
+              (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
+              ActFun),
+  [(HasOneUse $first_output)]>;
+}
+
+
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index fbbfa8e95e5..1d973b8059a 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
@@ -1963,6 +1965,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -4960,6 +4964,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap
   );
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+
+  let hasFolder = 1;
 }
 
 def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -8119,6 +8125,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 
   let hasCanonicalizer = 1;
+
+  let hasFolder = 1;
 }
 
 def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 1b915e3d5fc..0c706ac24a2 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -494,6 +494,48 @@ LogicalResult FoldOperandsPermutation(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rewrite Pattern for removing trivial Arithmetic op.
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Utility methods that return the identity value to use for selected ops.
+
+APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
+APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
+APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
+APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
+
+// Folder that returns the LHS of an arithmetic op if the RHS is a constant
+// known to be the identity value for that op (e.g. x + 0).
+template <typename OP>
+OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
+  DenseFPElementsAttr rhs_value;
+  auto constant_val = arithmetic_op.y();
+  if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
+    return {};
+  }
+  auto result_op_type = arithmetic_op.getResult().getType();
+  auto lhs_type = arithmetic_op.x().getType();
+  if (!result_op_type.template isa<ShapedType>() ||
+      !lhs_type.template isa<ShapedType>() ||
+      !result_op_type.template cast<ShapedType>().hasStaticShape()) {
+    return {};
+  }
+  // We only handle the non-broadcastable case (operand types match exactly).
+  if (result_op_type != lhs_type) {
+    return {};
+  }
+  auto identity_val = GetIdentity(arithmetic_op);
+  for (auto it = rhs_value.float_value_begin();
+       it != rhs_value.float_value_end(); ++it) {
+    if (*it != identity_val) return {};
+  }
+
+  return arithmetic_op.x();
+}
+}  // namespace
+
 namespace {
 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
 }  // namespace
@@ -525,6 +567,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
 }
 
+OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder<AddV2Op>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AllOp
 //===----------------------------------------------------------------------===//
@@ -1271,6 +1317,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<DivWithSqrtDivisor>(context);
 }
 
+OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder<DivOp>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // DynamicStitchOp
 //===----------------------------------------------------------------------===//
@@ -1936,6 +1986,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MulOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder<MulOp>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // NegOp
 //===----------------------------------------------------------------------===//
@@ -2904,6 +2962,10 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
   results.insert<SubOfNeg>(context);
 }
 
+OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
+  return TrivialArithmeticOpFolder<SubOp>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // SumOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index 2a34bbfacdc..abb2f4de40f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -251,3 +251,84 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) ->
   // CHECK-NEXT:    return [[cst]] : tensor<2xi32>
   return %0: tensor<2xi32>
 }
+
+func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<0.0> : tensor<2x2xf32>
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: RemoveTrivialAdd
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<0.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: RemoveTrivialAddV2
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<0.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: RemoveTrivialSub
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: RemoveTrivialMul
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<1.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: RemoveTrivialDiv
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
+  %cst = constant dense<0.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
+  %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %1 : tensor<2x2xf32>
+
+  // CHECK-LABEL: DontRemoveTrivialAdd
+  // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
+  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  // CHECK: return %[[RESULT]] : tensor<2x2xf32>
+}
+
+func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
+  %cst = constant dense<0.0> : tensor<2x2xf32>
+  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  %1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
+  return %1 :tensor<?x?xf32>
+
+  // CHECK-LABEL: DontRemoveTrivialAdd2
+  // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
+  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
+  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
+}