From 7a6fa7dbc0c2a382d03092e718960a9f3f93635b Mon Sep 17 00:00:00 2001 From: TensorFlower Gardener <gardener@tensorflow.org> Date: Fri, 6 Nov 2020 04:41:00 -0800 Subject: [PATCH] Merge pull request #41790 from lgeiger:fix-matmul-fusion PiperOrigin-RevId: 341026586 Change-Id: Idb578a4cf48c82abaad2f811612c4df4c2f2752c --- .../compiler/mlir/lite/tests/optimize.mlir | 4 +- .../compiler/mlir/lite/transforms/optimize.cc | 38 +++++++------- .../compiler/mlir/lite/utils/validators.cc | 16 ++++++ .../compiler/mlir/lite/utils/validators.h | 4 ++ .../mlir/tensorflow/transforms/optimize.td | 10 ++-- tensorflow/lite/python/lite_v2_test.py | 49 +++++++++++++++++++ 6 files changed, 98 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index bedf77f726a..8385d868c6b 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -286,7 +286,7 @@ func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { return %1 : tensor<4x2xf32> -// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> // CHECK: %[[CONSTANT0:.*]] = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %[[CONSTANT0]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} // CHECK: return %[[RES]] : tensor<4x2xf32> @@ -384,7 +384,7 @@ func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> te return %1 : tensor<4x2xf32> -// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %[[CONSTANT:.*]] = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> // CHECK: %[[RES:.*]] = "tfl.fully_connected"(%arg0, %[[CONSTANT]], %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> // CHECK: return %[[RES]] : tensor<4x2xf32> } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 3c11fe2b610..4317db5957b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -438,26 +438,28 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern<TFL::MulOp> { return failure(); if (fc_op.fused_activation_function() != "NONE") return failure(); - // Broadcast the constant operand of Mul if it isn't compatible to the - // filter input. We only support broadcasting the operand along the depth - // dimension, when the operand's depth is 1. 
-    Value new_const_val = constant_val;
-    if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter.getType())) {
-      auto original_shape = cst.getType().getShape();
-      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
-                                                     original_shape.end());
-      normalized_shape.push_back(1);
-      auto new_cst = cst.reshape(RankedTensorType::get(
-          normalized_shape, cst.getType().getElementType()));
-      Type new_type = new_cst.getType();
-      if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
-        return failure();
-      }
-      auto new_op =
-          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
-      new_const_val = new_op.getResult();
+    // Only fuse the multiplier if all dimensions other than the depth
+    // dimension are equal to 1, since otherwise
+    // `matmul(x, filter) * cst != matmul(x, filter * cst)`
+    // even if `filter` and `cst` are broadcastable.
+    auto shape = cst.getType().getShape();
+    if (!IsDimensionsDegenerateExceptLastOne(shape)) return failure();
+
+    int64_t element_size = shape.empty() ? 1 : shape[shape.size() - 1];
+    // Expand and transpose the multiplier since weights use the
+    // OHWI data format in TFLite.
+    int64_t normalized_shape[2] = {element_size, 1};
+    auto new_cst = cst.reshape(RankedTensorType::get(
+        normalized_shape, cst.getType().getElementType()));
+    Type new_type = new_cst.getType();
+    if (!IsBroadcastableElementsAttrAndType(new_type, filter.getType())) {
+      return failure();
     }
+    auto new_op =
+        rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+    Value new_const_val = new_op.getResult();
+
     // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands,
     // TF::MulOp is used to fold the constant.
     // TODO(b/139192933): switch to the TFL constant folding
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc
index f863eeed0d6..8b38ff45998 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.cc
+++ b/tensorflow/compiler/mlir/lite/utils/validators.cc
@@ -78,5 +78,21 @@ bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b) {
   return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
 }
 
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
+  if (elements_shape.empty()) return true;
+
+  for (auto dim : elements_shape.drop_back(1)) {
+    if (dim != 1) return false;
+  }
+  return true;
+}
+
+bool IsDimensionsDegenerateExceptLastOne(Attribute val) {
+  if (auto ranked_type = val.getType().dyn_cast<RankedTensorType>()) {
+    return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
+  }
+  return false;
+}
+
 }  // namespace TFL
 }  // namespace mlir
diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h
index 247947c3adc..406d90eeeb5 100644
--- a/tensorflow/compiler/mlir/lite/utils/validators.h
+++ b/tensorflow/compiler/mlir/lite/utils/validators.h
@@ -70,6 +70,10 @@ inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
 /// Returns whether the given `a` and `b` have broadcast-compatible
 /// types.
 bool IsBroadcastableElementsAttrs(mlir::Attribute a, mlir::Attribute b);
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::Attribute val);
+// Returns true if every element of the given shape is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
 
 }  // end namespace TFL
 }  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
index 75d2bc06482..dfce5078e68 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td
@@ -17,8 +17,12 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
 include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
 
 def IsDataFormatNHWC : ConstantAttr<TF_ConvnetDataFormatAttr, "NHWC">;
-def BroadcastableElements :
-    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
+
+// Only fuse the multiplier if all dimensions other than the channel dimension
+// are equal to 1.
+def CanFuseMulAndConv2D :
+    Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1) && TFL::IsDimensionsDegenerateExceptLastOne($1)">>;
+
 def F32ElementsAttr : ElementsAttrBase<
   CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
 def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0.getDefiningOp())">>;
@@ -40,7 +44,7 @@ def FuseMulAndConv2D :
                  (location $mul)),
      $strides, $use_cudnn, $padding, $explicit_padding,
      $data_format, $dilations, (location $conv)),
-    [(BroadcastableElements $filter_value, $mul_value), (HasOneUse $conv)]>;
+    [(CanFuseMulAndConv2D $filter_value, $mul_value), (HasOneUse $conv)]>;
 
 // This rule does the following pattern match and rewrite:
 //
diff --git a/tensorflow/lite/python/lite_v2_test.py b/tensorflow/lite/python/lite_v2_test.py
index db38287d9c2..ad5811de054 100644
--- a/tensorflow/lite/python/lite_v2_test.py
+++ b/tensorflow/lite/python/lite_v2_test.py
@@ -1304,5 +1304,54 @@ class UnknownShapes(lite_v2_test_util.ModelTest):
                       str(error.exception))
 
 
+class AffineOpThenMulFusionTest(lite_v2_test_util.ModelTest):
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False),
+                                  ('should_not_fuse_2x2', [2, 2], False))
+  @test_util.run_v2_only
+  def testFullyConnectedFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into fully connected."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 2])
+      m_value = [7., 8.] if sum(multiplier_shape) < 4 else [7., 8., 9., 10.]
+      m = tf.constant(m_value, shape=multiplier_shape)
+      return tf.matmul(x, w) * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 2])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  @parameterized.named_parameters(('should_fuse_1d', [2], True),
+                                  ('should_fuse_1x2', [1, 2], True),
+                                  ('should_not_fuse_2x1', [2, 1], False))
+  @test_util.run_v2_only
+  def testConvFusion(self, multiplier_shape, can_fuse):
+    """Test fusion of (x ∗ w) * m into conv2d."""
+
+    @tf.function
+    def func(x):
+      w = tf.constant([3., 4., 5., 6.], shape=[2, 1, 1, 2])
+      m = tf.constant([7., 8.], shape=multiplier_shape)
+      return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') * m
+
+    input_data = tf.constant([1., 2.], shape=[1, 1, 2, 1])
+    self._checkAffineFusion(func, input_data, 1 if can_fuse else 2)
+
+  def _checkAffineFusion(self, func, input_data, expected_number_of_ops):
+    concrete_func = func.get_concrete_function(input_data)
+    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
+    tflite_model = converter.convert()
+
+    interpreter = Interpreter(model_content=tflite_model)
+    self.assertLen(interpreter._get_ops_details(), expected_number_of_ops)
+
+    expected_value = func(input_data)
+    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
+    self.assertAllClose(expected_value.numpy(), actual_value)
+
+
 if __name__ == '__main__':
   test.main()
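
Note (not part of the patch): the new fusion condition can be sanity-checked
numerically. A minimal numpy sketch, with reviewer-chosen example values, shows
that `matmul(x, w) * m == matmul(x, w * m)` holds exactly when every dimension
of `m` except the last one is 1, which is what
`IsDimensionsDegenerateExceptLastOne` enforces:

    import numpy as np

    x = np.array([[1., 2.], [3., 4.]])   # batch of two rows
    w = np.array([[3., 4.], [5., 6.]])   # [input, output], as in tf.matmul

    # All dims 1 except the last: m scales whole output columns, so it can
    # be folded into the weights before the matmul.
    m_ok = np.array([[7., 8.]])          # shape [1, 2]
    assert np.allclose((x @ w) * m_ok, x @ (w * m_ok))

    # A non-degenerate leading dim scales rows of the *product* (one factor
    # per batch example), which no rescaled constant filter can reproduce.
    m_bad = np.array([[7.], [8.]])       # shape [2, 1]
    assert not np.allclose((x @ w) * m_bad, x @ (w * m_bad))

The same per-channel argument applies to the Conv2D multiplier, which is why
the TableGen constraint `CanFuseMulAndConv2D` adds the identical shape check.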
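
Note (not part of the patch): the fold direction for the fully-connected case
can be checked the same way. `tfl.fully_connected` stores its filter as
[output_channels, input_channels] and computes roughly `y = x @ filter.T`, so a
per-output-channel multiplier must scale filter *rows*; reshaping it to
`{element_size, 1}` does exactly that, whereas the previous trailing broadcast
scaled input columns. The example values below are reviewer-chosen to be
consistent with the updated CHECK constants in optimize.mlir (the test's actual
inputs are outside the diff context):

    import numpy as np

    x = np.array([[1., 2.]])
    filter_oi = np.array([[1., 2.], [3., 4.]])  # [output, input] layout
    m = np.array([1., 2.])                      # one factor per output channel

    wrong = filter_oi * m                 # columns scaled: [[1., 4.], [3., 8.]]
    fused = filter_oi * m.reshape(-1, 1)  # rows scaled:    [[1., 2.], [6., 8.]]

    assert np.allclose((x @ filter_oi.T) * m, x @ fused.T)
    assert not np.allclose((x @ filter_oi.T) * m, x @ wrong.T)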