From 079a0bcc556201abdf435826a3f267ad47e7d275 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 19 Dec 2019 14:10:21 -0800 Subject: [PATCH] move binary ops before reshape If the other operand of the binary op is a constant and its shape is broadcastable to the other operand and also the reshape has only one use, the order of these two ops can be switched. This implements the MoveBinaryOperatorBeforeReshape pass in TOCO. PiperOrigin-RevId: 286460837 Change-Id: I27da6cba2b3bb48f3fe3da74b2ce7c1c9ecf2e77 --- tensorflow/compiler/mlir/lite/ir/tfl_ops.td | 5 +- .../compiler/mlir/lite/tests/optimize.mlir | 19 ++++++ .../compiler/mlir/lite/transforms/optimize.cc | 15 +++++ .../mlir/lite/transforms/optimize_patterns.td | 67 ++++++++++++++----- 4 files changed, 88 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6cba2413b83..4df893c76e0 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1181,7 +1181,8 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> { let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> { +def TFL_GreaterOp : TFL_Op<"greater", [ + Broadcastable, NoSideEffect, NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -1194,6 +1195,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> { let results = (outs AnyTensor:$output); + let builders = [TFL_ComparisonBinaryBuilder]; + let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index bab643309fe..f1178302b9e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -318,6 +318,25 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf // CHECK: return %[[fc]] } +// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst +func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %cst = constant dense<3.0> : tensor<40x40xf32> + %cst2 = constant dense<2.0> : tensor<40xf32> + %shape1 = constant dense<[1, 40, 40]> : tensor<3xi32> + %shape2 = constant dense<[40, 40]> : tensor<2xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32> + %3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32> + + return %3 : tensor<40x40xf32> + + // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) + // CHECK: return %[[fc]] +} + // CHECK-LABEL: @FuseFullyConnectedRelu func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1313bae97a1..47e342f1397 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -80,6 +80,21 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); } +// Returns whether if `type1` dimensions are the same as the ending dimensions +// of `type2`. This is more restricted than broadcastable. +bool IsTailOfShape(Type type1, Type type2) { + auto tail_type = type1.dyn_cast(); + auto full_type = type2.dyn_cast(); + if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank()) + return false; + auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend(); + auto i2 = full_type.getShape().rbegin(); + for (; i1 != e1; ++i1, ++i2) { + if (*i1 != *i2) return false; + } + return true; +} + bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool is_depthwise) { // Make sure the val tensor has shape where all dimensions are 1 except diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 1a22c80d38c..129d0708eba 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -249,9 +249,17 @@ multiclass L2NormalizePatterns { foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] in defm : L2NormalizePatterns; +//===----------------------------------------------------------------------===// +// Binary ops patterns. +//===----------------------------------------------------------------------===// def AreBroadcastableTypes : ConstraintgetType(), $1->getType())">>; +def IsTailOfShape : ConstraintgetType(), $1->getType())">>; + +def HaveSameType : ConstraintgetType(), $1->getType()">>; + // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { @@ -266,18 +274,58 @@ multiclass FuseTileBroadcastIntoFollowingBinary { [(AreBroadcastableTypes $operand, $input)]>; } -foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] - in defm : FuseTileBroadcastIntoFollowingBinary; +// Multi-pattern consisting of matching stand-alone op or op followed by relu. +multiclass FusedBinaryActivationFuncOpPat { + foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], + [TFL_Relu6Op, TFL_AF_Relu6], + [TFL_Relu1Op, TFL_AF_Relu1]] in { + def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), + (BinaryOp $lhs, $rhs, actFnPair[1])>; + } +} + +foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { + defm : FuseTileBroadcastIntoFollowingBinary; + + // Instantiated FusedBinary patterns for the from-to pairs of ops. + defm : FusedBinaryActivationFuncOpPat; + + // Move binary op before reshape: reshape -> binary => binary -> reshape. + // This is valid only when the binary operand is constant and the shape is the + // tail of the other op and the intermediate result isn't used by other ops. + def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs $a), TFL_AF_None), + (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape), + [(IsTailOfShape $rhs, $lhs)]>; +} + +foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, + TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp, + TFL_GreaterEqualOp] in { + // Move binary op before reshape: reshape -> binary => binary -> reshape. + // This is valid only when the binary operand is constant and the shape is the + // tail of the other op and the intermediate result isn't used by other ops. + def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs $a)), + (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape), + [(IsTailOfShape $rhs, $lhs)]>; +} // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; +// Convert squeeze to reshape def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))), [(AnyStaticShapeTensor $squeeze_op)]>; +// Convert remove double reshapes if the final shape is as same the input shape. +def : Pat<(TFL_ReshapeOp:$output (TFL_ReshapeOp $input, $shape1), $shape2), + (replaceWithValue $input), + [(HaveSameType $output, $input)]>; + class ValueEquals : Constraint().getNumElements() == 1 &&" "*$0.cast().getValues().begin() == " # val>>; @@ -294,21 +342,6 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, (TFL_Relu1Op $input), [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; -// Multi-pattern consisting of matching stand-alone op or op followed by relu. -multiclass FusedBinaryActivationFuncOpPat { - foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], - [TFL_Relu6Op, TFL_AF_Relu6], - [TFL_Relu1Op, TFL_AF_Relu1]] in { - def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), - (BinaryOp $lhs, $rhs, actFnPair[1])>; - } -} - -// Instantiated FusedBinary patterns for the from-to pairs of ops. -foreach BinaryOps = [TFL_AddOp, TFL_DivOp, - TFL_MulOp, TFL_SubOp] in - defm : FusedBinaryActivationFuncOpPat; - // The constant folding in this pass might produce constant in the tf dialect. // This rule is to legalize these constant to the tfl dialect. def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;