diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 13ff445131b..e396a56bd62 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -252,6 +252,7 @@ cc_library(
         "@local_config_mlir//:Analysis",
         "@local_config_mlir//:IR",
         "@local_config_mlir//:Pass",
+        "@local_config_mlir//:StandardOps",
         "@local_config_mlir//:Support",
     ],
     alwayslink = 1,
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index 983cdf0cbd0..ee659cf8bd6 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -96,6 +96,38 @@ func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf
 }
 
+// CHECK-LABEL: @fuseMulIntoFullyConnected
+func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst1 = constant dense<2.0> : tensor<2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32>
+// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+// CHECK: return %0 : tensor<4x2xf32>
+}
+
+// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias
+func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> {
+  %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
+  %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+
+  %0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+  %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32>
+
+  return %1 : tensor<4x2xf32>
+
+// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00], [6.000000e+00, 8.000000e+00]]> : tensor<2x2xf32>
+// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32>
+// CHECK: return %0 : tensor<4x2xf32>
+}
+
 // CHECK-LABEL: @fuseMulIntoDepthwiseConv2d
 func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
   %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
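Review note on the expected constants above (commentary, not part of the patch): TFLite's fully_connected with DEFAULT weights_format computes output[b][u] = sum_i input[b][i] * filter[u][i] + bias[u], so a per-output-channel multiplier c folds in by scaling filter *row* u and bias element u by c[u]: filter [[1, 2], [3, 4]] with c = [1, 2] becomes [[1, 2], [6, 8]], and bias [2, 2] becomes [2, 4]. Below is a minimal standalone C++ check of that identity, assuming this row-major [num_units, input_size] filter convention; every name in it is illustrative and nothing here is added by the patch.

#include <cassert>

int main() {
  const float W[2][2] = {{1, 2}, {3, 4}};  // filter, rows = output units
  const float b[2] = {2, 2};               // bias
  const float c[2] = {1, 2};               // per-channel multiplier (tfl.mul rhs)
  const float x[2] = {5, 7};               // an arbitrary input row

  for (int u = 0; u < 2; ++u) {
    // Unfused graph: fully_connected first, then mul.
    const float unfused = (x[0] * W[u][0] + x[1] * W[u][1] + b[u]) * c[u];
    // Fused graph: scaled filter row and bias, as in the CHECK lines above.
    const float fused =
        x[0] * (W[u][0] * c[u]) + x[1] * (W[u][1] * c[u]) + b[u] * c[u];
    assert(unfused == fused);  // exact here: all intermediates are small integers
  }
  return 0;
}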
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
index 0fd695f3c66..410cac51c95 100644
--- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
+++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td
@@ -29,7 +29,6 @@ class ExtractI32At<int i> : NativeCodeCall<
     "$_builder.getI32IntegerAttr($_self.cast<ArrayAttr>().getValue()[" # i #
     "].cast<IntegerAttr>().getInt())">;
 
-// Merge the two Attributes to a ArrayAttr;
 def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
index d93c01a806c..a5ca0abcbd1 100644
--- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc
@@ -21,12 +21,17 @@ limitations under the License.
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Matchers.h"  // TF:local_config_mlir
 #include "mlir/IR/PatternMatch.h"  // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h"  // TF:local_config_mlir
 #include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/StandardOps/Ops.h"  // TF:local_config_mlir
 #include "mlir/Support/Functional.h"  // TF:local_config_mlir
+#include "mlir/Support/LLVM.h"  // TF:local_config_mlir
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
@@ -45,14 +50,20 @@ struct Optimize : public FunctionPass<Optimize> {
   void runOnFunction() override;
 };
 
+// Returns whether the given type `a` is broadcast-compatible with `b`.
+bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
+  return OpTrait::util::getBroadcastedType(a, b) != Type();
+}
+
 // Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible
 // types.
 bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) {
-  return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
+  return IsBroadcastableElementsAttrAndType(a.getType(), b.getType());
 }
 
 #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc"
-// Fuse Add with FullyConnected.
+
+// Fuse Add with preceding FullyConnected.
 // Note that this assumes that the bias in the fullyConnected
 // is always None.
 // TODO(b/136285429): Move to tablegen when variadic is supported
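Review note on the new IsBroadcastableElementsAttrAndType helper (commentary, not part of the patch): OpTrait::util::getBroadcastedType applies numpy-style broadcasting to the two types and returns a null Type when the shapes are incompatible, which is exactly what the helper tests for. As a mental model, here is a minimal sketch of that shape rule for static ranked shapes; the function below is illustrative, not MLIR's implementation, and it ignores dynamic dimensions and element types.

#include <cassert>
#include <cstdint>
#include <vector>

// Numpy-style broadcast of two static shapes. Returns false on incompatible
// shapes, mirroring the null Type returned by getBroadcastedType.
bool BroadcastShapes(std::vector<int64_t> a, std::vector<int64_t> b,
                     std::vector<int64_t> *out) {
  // Left-pad the shorter shape with 1s so the ranks line up.
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  out->clear();
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) return false;
    out->push_back(a[i] == 1 ? b[i] : a[i]);
  }
  return true;
}

int main() {
  std::vector<int64_t> out;
  assert(BroadcastShapes({2}, {2, 2}, &out));     // tensor<2> vs tensor<2x2>: ok
  assert(BroadcastShapes({2, 1}, {2, 2}, &out));  // tensor<2x1> vs tensor<2x2>: ok
  assert(!BroadcastShapes({3}, {2, 2}, &out));    // incompatible -> null Type
  return 0;
}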
@@ -153,6 +164,85 @@ struct FuseFullyConnectedAndRelu : public RewritePattern {
   }
 };
 
+// Fuse Mul with preceding FullyConnected.
+// TODO(b/136285429): Move to tablegen when variadic is supported
+struct FuseFullyConnectedAndMul : public RewritePattern {
+  explicit FuseFullyConnectedAndMul(MLIRContext *context)
+      : RewritePattern(TFL::MulOp::getOperationName(),
+                       {"tfl.fully_connected", "tfl.mul", "std.constant"}, 4,
+                       context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    // Mul.
+    auto mul_op = cast<TFL::MulOp>(op);
+    DenseElementsAttr cst;
+    Value *constant_val = mul_op.rhs();
+    if (!matchPattern(constant_val, m_Constant(&cst))) {
+      return matchFailure();
+    }
+
+    // Fully Connected.
+    auto fc_op = dyn_cast_or_null<TFL::FullyConnectedOp>(
+        mul_op.lhs()->getDefiningOp());
+    if (!fc_op) return matchFailure();
+    Value *filter = fc_op.filter();
+    Value *bias = fc_op.bias();
+    // Only fuse when the FullyConnected itself has no fused activation;
+    // otherwise the activation would be applied before the multiplication.
+    if (!fc_op.fused_activation_function().equals("NONE")) {
+      return matchFailure();
+    }
+
+    // Broadcast the constant operand of Mul if it isn't compatible to the
+    // filter input. We only support broadcasting the operand along the depth
+    // dimension, when the operand's depth is 1. A 1-D constant of shape
+    // [depth] carries one scale per output channel, so it is reshaped to
+    // [depth, 1] to broadcast against the filter's rows (output channels)
+    // rather than its columns (input depth).
+    Value *new_const_val = constant_val;
+    if (cst.getType().getRank() == 1 ||
+        !IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) {
+      auto original_shape = cst.getType().getShape();
+      llvm::SmallVector<int64_t, 4> normalized_shape(original_shape.begin(),
+                                                     original_shape.end());
+      normalized_shape.push_back(1);
+      auto new_cst = cst.reshape(rewriter.getTensorType(
+          normalized_shape, cst.getType().getElementType()));
+      Type new_type = new_cst.getType();
+      if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) {
+        return matchFailure();
+      }
+      auto new_op =
+          rewriter.create<ConstantOp>(mul_op.getLoc(), new_type, new_cst);
+      new_const_val = new_op.getResult();
+    }
+
+    // Rewrite. The activation of fc_op is known to be "NONE" here.
+    Location loc = fc_op.getLoc();
+    auto af_none = rewriter.getStringAttr(fc_op.fused_activation_function());
+    auto new_filter =
+        rewriter.create<TFL::MulOp>(loc, filter, new_const_val, af_none);
+    // If bias isn't None, it needs to be multiplied as well.
+    if (!bias->getType().isa<NoneType>()) {
+      bias = rewriter.create<TFL::MulOp>(loc, bias, constant_val, af_none)
+                 .output();
+    }
+
+    rewriter.replaceOpWithNewOp<TFL::FullyConnectedOp>(
+        mul_op, mul_op.getType(),
+        /*input=*/fc_op.input(),
+        /*filter=*/new_filter.output(),
+        /*bias=*/bias,
+        /*fused_activation_function=*/
+        rewriter.getStringAttr(mul_op.fused_activation_function()),
+        /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()),
+        /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims()));
+
+    return matchSuccess();
+  }
+};
+
 // StridedSlice can have complicated atributes like begin_axis_mask,
 // end_axis_mask, ellipsis_axis_mask, new_axis_mask, shrink_axis_mask. These
 // masks will complicate the strided_slice computation logic, we can simplify
@@ -238,11 +328,11 @@ void Optimize::runOnFunction() {
   OwningRewritePatternList patterns;
   auto *ctx = &getContext();
   auto func = getFunction();
+
   // Add the generated patterns to the list.
   TFL::populateWithGenerated(ctx, &patterns);
   patterns.insert<FuseFullyConnectedAndAdd, FuseFullyConnectedAndRelu,
-                  PadStridedSliceDims>(ctx);
-
+                  FuseFullyConnectedAndMul, PadStridedSliceDims>(ctx);
   applyPatternsGreedily(func, std::move(patterns));
 }
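Review note on the "NONE" guard in matchAndRewrite (commentary, not part of the patch): the rewrite moves the multiplication underneath the FullyConnected's activation, so it is only sound when that activation is the identity. A clamping activation such as RELU6 does not commute with scaling, as this small standalone counterexample shows (names are illustrative):

#include <algorithm>
#include <cassert>

float Relu6(float x) { return std::min(std::max(x, 0.0f), 6.0f); }

int main() {
  const float y = 1.0f;   // a fully_connected output value
  const float c = 10.0f;  // the tfl.mul multiplier
  // If the FullyConnected carried RELU6: activation first, then multiply.
  const float original = Relu6(y) * c;       // relu6(1) * 10 = 10
  // Fusing anyway would move the multiplier inside the activation.
  const float fused_wrongly = Relu6(y * c);  // relu6(10) = 6
  assert(original != fused_wrongly);
  return 0;
}

Carrying over the Mul's own fused activation (RELU6 in the tests) is fine, because it runs after the multiplication in both the original and the rewritten graph.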