From 7d6b60f1ed6eaf0c47e842b552a5c4f5c730d49f Mon Sep 17 00:00:00 2001
From: Feng Liu <fengliuai@google.com>
Date: Tue, 6 Aug 2019 20:50:51 -0700
Subject: [PATCH] Fold the MulOp into the preceding FullyConnectedOp

This CL folds the MulOp into the filter and bias of the FullyConnectedOp
that precedes it. If the bias is a NoneType, the MulOp is only folded into
the filter.
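
For example (a sketch based on the new test case, with operand names
abbreviated), the pair

  %0 = "tfl.fully_connected"(%input, %filter, %bias) {fused_activation_function = "NONE", ...}
  %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "RELU6"}

is rewritten to a single fully_connected whose filter and bias are
multiplied by %cst and which takes over the mul's fused activation
function:

  %0 = "tfl.fully_connected"(%input, %scaled_filter, %scaled_bias) {fused_activation_function = "RELU6", ...}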

PiperOrigin-RevId: 262056643
---
 tensorflow/compiler/mlir/lite/BUILD           |  1 +
 .../compiler/mlir/lite/tests/optimize.mlir    | 32 +++++++
 .../mlir/lite/transforms/legalize_patterns.td |  1 -
 .../compiler/mlir/lite/transforms/optimize.cc | 89 ++++++++++++++++++-
 4 files changed, 118 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 13ff445131b..e396a56bd62 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -252,6 +252,7 @@ cc_library(
         "@local_config_mlir//:Analysis",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 983cdf0cbd0..ee659cf8bd6 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -96,6 +96,38 @@ func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf
 
 }
 
+// CHECK-LABEL: @fuseMulIntoFullyConnected
+func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst1 = constant dense<2.0> : tensor<2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK:  %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
+// CHECK:  %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+// CHECK:  return %0 : tensor<4x2xf32>
+}
+
+// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
+func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK:  %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+// CHECK:  return %0 : tensor<4x2xf32>
+}
+
 // CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
 func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
   %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 0fd695f3c66..410cac51c95 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -29,7 +29,6 @@ class ExtractI32At<int i> : NativeCodeCall<
     "$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
     "].cast<IntegerAttr>().getInt())">;
 
-
 // Merge the two Attributes to a ArrayAttr;
 def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index d93c01a806c..a5ca0abcbd1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -21,12 +21,17 @@ limitations under the License.
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -45,14 +50,20 @@ struct Optimize : public FunctionPass<Optimize> {
   void runOnFunction() override;
 };
 
+// Returns whether the given type `a` is broadcast-compatible with `b`.
+bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
+  return OpTrait::util::getBroadcastedType(a, b) != Type();
+}
+
 // Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
 // types.
 bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
-  return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
+  return IsBroadcastableElementsAttrAndType(a.getType(), b.getType());
 }
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
-// Fuse Add with FullyConnected.
+
+// Fuse Add with preceding FullyConnected.
 // Note that this assumes that the bias in the fullyConnected
 // is always None.
 // TODO(b/136285429): Move to tablegen when variadic is supported
@@ -153,6 +164,76 @@ struct FuseFullyConnectedAndRelu : public RewritePattern {
   }
 };
 
+// Fuse Mul with preceding FullyConnected.
+// TODO(b/136285429): Move to tablegen when variadic is supported
+struct FuseFullyConnectedAndMul : public RewritePattern {
+  explicit FuseFullyConnectedAndMul(MLIRContext *context)
+      : RewritePattern(TFL::MulOp::getOperationName(),
+                       {"tfl.fully_connected", "tfl.mul", "std.constant"}, 4,
+                       context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    // Mul.
+    auto mul_op = cast<MulOp>(op);
+    DenseElementsAttr cst;
+    Value *constant_val = mul_op.rhs();
+    if (!matchPattern(constant_val, m_Constant(&cst))) {
+      return matchFailure();
+    }
+
+    // Fully Connected.
+    auto fc_op =
+        dyn_cast_or_null<TFL::FullyConnectedOp>(mul_op.lhs()->getDefiningOp());
+    if (!fc_op) return matchFailure();
+    Value *filter = fc_op.filter();
+    Value *bias = fc_op.bias();
+    if (!fc_op.fused_activation_function().equals("NONE")) return matchFailure();
+
+    // Broadcast the constant operand of Mul if it isn't compatible with the
+    // filter input. We only support broadcasting the operand along the depth
+    // dimension, when the operand's depth is 1.
+    Value *new_const_val = constant_val;
+    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
+      auto original_shape = cst.getType().getShape();
+      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
+                                                     original_shape.end());
+      normalized_shape.push_back(1);
+      auto new_cst = cst.reshape(rewriter.getTensorType(
+          normalized_shape, cst.getType().getElementType()));
+      Type new_type = new_cst.getType();
+      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
+        return matchFailure();
+      }
+      auto new_op =
+          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+      new_const_val = new_op.getResult();
+    }
+
+    // Rewrite.
+    Location loc = fc_op.getLoc();
+    auto af_none = rewriter.getStringAttr(fc_op.fused_activation_function());
+    auto new_filter =
+        rewriter.create<MulOp>(loc, filter, new_const_val, af_none);
+    // If the bias isn't None, it is multiplied by the constant as well.
+    if (!bias->getType().isa<NoneType>()) {
+      bias = rewriter.create<MulOp>(loc, bias, constant_val, af_none).output();
+    }
+
+    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
+        mul_op, mul_op.getType(),
+        /*input=*/fc_op.input(),
+        /*filter=*/new_filter.output(),
+        /*bias=*/bias,
+        /*fused_activation_function=*/
+        rewriter.getStringAttr(mul_op.fused_activation_function()),
+        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
+        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
+
+    return matchSuccess();
+  }
+};
+
 // StridedSlice can have complicated atributes like begin_axis_mask,
 // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
 // masks will complicate the strided_slice computation logic, we can simplify
@@ -238,11 +319,11 @@ void Optimize::runOnFunction() {
   OwningRewritePatternList patterns;
   auto *ctx = &getContext();
   auto func = getFunction();
+
   // Add the generated patterns to the list.
   TFL::populateWithGenerated(ctx, &patterns);
   patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
-                  PadStridedSliceDims>(ctx);
-
+                  FuseFullyConnectedAndMul, PadStridedSliceDims>(ctx);
   applyPatternsGreedily(func, std::move(patterns));
 }