From 7a6fa7dbc0c2a382d03092e718960a9f3f93635b Mon Sep 17 00:00:00 2001
From: TensorFlower Gardener <gardener@tensorflow.org>
Date: Fri, 6 Nov 2020 04:41:00 -0800
Subject: [PATCH] Merge pull request #41790 from lgeiger:fix-matmul-fusion

PiperOrigin-RevId: 341026586
Change-Id: Idb578a4cf48c82abaad2f811612c4df4c2f2752c
---
 .../compiler/mlir/lite/tests/optimize.mlir    |  4 +-
 .../compiler/mlir/lite/transforms/optimize.cc | 38 +++++++-------
 .../compiler/mlir/lite/utils/validators.cc    | 16 ++++++
 .../compiler/mlir/lite/utils/validators.h     |  4 ++
 .../mlir/tensorflow/transforms/optimize.td    | 10 ++--
 tensorflow/lite/python/lite_v2_test.py        | 49 +++++++++++++++++++
 6 files changed, 98 insertions(+), 23 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index bedf77f726a..8385d868c6b 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -286,7 +286,7 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
 
   return %1 : tensor<4x2xf32>
 
-// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK:  %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
 // CHECK:  %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"}
 // CHECK:  return %[[RES]] : tensor<4x2xf32>
@@ -384,7 +384,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te
 
   return %1 : tensor<4x2xf32>
 
-// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK:  %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
 // CHECK:  %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
 // CHECK:  return %[[RES]] : tensor<4x2xf32>
 }
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index 3c11fe2b610..4317db5957b 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -438,26 +438,28 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> {
       return failure();
     if (fc_op.fused_activation_function() != "NONE") return failure();
 
-    // Broadcast the constant operand of Mul if it isn't compatible to the
-    // filter input. We only support broadcasting the operand along the depth
-    // dimension, when the operand's depth is 1.
-    Value new_const_val = constant_val;
-    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
-      auto original_shape = cst.getType().getShape();
-      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
-                                                     original_shape.end());
-      normalized_shape.push_back(1);
-      auto new_cst = cst.reshape(RankedTensorType::get(
-          normalized_shape, cst.getType().getElementType()));
-      Type new_type = new_cst.getType();
-      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
-        return failure();
-      }
-      auto new_op =
-          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
-      new_const_val = new_op.getResult();
+    // Only fuse multiplier if all dimensions other than the depth dimension
+    // are equal to 1 since otherwise
+    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
+    // even if `filter` and `cst` are be broadcastable.
+    auto shape = cst.getType().getShape();
+    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
+
+    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
+    // Expand and transpose the multiplier since weights are using the
+    // OHWI data format in TFLite.
+    int64_t normalized_shape[2] = {element_size, 1};
+    auto new_cst = cst.reshape(RankedTensorType::get(
+        normalized_shape, cst.getType().getElementType()));
+    Type new_type = new_cst.getType();
+    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
+      return failure();
     }
 
+    auto new_op =
+        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+    Value new_const_val = new_op.getResult();
+
     // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
     // TF::MulOp is used to fold the constant.
     // TODO(b/139192933): switch to the TFL constant folding
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc
index f863eeed0d6..8b38ff45998 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.cc
+++ b/tensorflow/compiler/mlir/lite/utils/validators.cc
@@ -78,5 +78,21 @@ bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
   return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
 }
 
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
+  if (elements_shape.empty()) return true;
+
+  for (auto dim : elements_shape.drop_back(1)) {
+    if (dim != 1) return false;
+  }
+  return true;
+}
+
+bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
+  if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
+    return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
+  }
+  return false;
+}
+
 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index 247947c3adc..406d90eeeb5 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -70,6 +70,10 @@ inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
 /// Returns whether the given `a` and `b` have broadcast-compatible
 /// types.
 bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
+// Returns true if every element is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
 
 }  // end namespace TFL
 }  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
index 75d2bc06482..dfce5078e68 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
@@ -17,8 +17,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
-def BroadcastableElements :
-    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
+
+// Only fuse multiplier if all dimensions other than the channel dimension
+// are equal to 1.
+def CanFuseMulAndConv2D :
+    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1) && TFL::IsDimensionsDegenerateExceptLastOne($1)">>;
+
 def F32ElementsAttr : ElementsAttrBase<
     CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
 def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
@@ -40,7 +44,7 @@ def FuseMulAndConv2D :
                     (location $mul)),
           $strides, $use_cudnn, $padding, $explicit_padding, $data_format,
           $dilations, (location $conv)),
-      [(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>;
+      [(CanFuseMulAndConv2D $filter_value, $mul_value), (HasOneUse $conv)]>;
 
 // This rule does the following pattern match and rewrite:
 //
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index db38287d9c2..ad5811de054 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -1304,5 +1304,54 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
         str(error.exception))
 
 
+class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False),
+                                  ('should_not_fuse_2x2', [2, 2], False))
+  @test_util.run_v2_only
+  def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into fullyconnected."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
+      m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
+      m = tf.constant(m_value, shape=multiplier_shape)
+      return tf.matmul(x, w) * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 2])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False))
+  @test_util.run_v2_only
+  def testConvFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into conv2d."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
+      m = tf.constant([7., 8.], shape=multiplier_shape)
+      return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
+    concrete_func = func.get_concrete_function(input_data)
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+
+    interpreter = Interpreter(model_content=tflite_model)
+    assert len(interpreter._get_ops_details()) == expected_number_of_ops
+
+    expected_value = func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value.numpy(), actual_value)
+
+
 if __name__ == '__main__':
   test.main()