From 09a10c139aa6c8f0b5873fa2e9578451d8d9f7d7 Mon Sep 17 00:00:00 2001
From: Karim Nosir
Date: Fri, 20 Dec 2019 08:40:19 -0800
Subject: [PATCH] move binary ops before reshape

If one operand of the binary op is a constant whose shape is broadcastable
to the shape of the other operand, and the reshape result has only one use,
the order of the two ops can be switched. This implements the equivalent of
the MoveBinaryOperatorBeforeReshape transformation in TOCO.

PiperOrigin-RevId: 286584501
Change-Id: I467aad941a590d2b97678f5a6762a6beca262cfe
---
 tensorflow/compiler/mlir/lite/ir/tfl_ops.td    |  5 +-
 .../compiler/mlir/lite/tests/optimize.mlir     | 19 ------
 .../compiler/mlir/lite/transforms/optimize.cc  | 15 -----
 .../mlir/lite/transforms/optimize_patterns.td  | 67 +++++--------------
 4 files changed, 18 insertions(+), 88 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
index 4df893c76e0..6cba2413b83 100644
--- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
+++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td
@@ -1181,8 +1181,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
   let builders = [TFL_BroadcastableBinaryBuilder];
 }
 
-def TFL_GreaterOp : TFL_Op<"greater", [
-    Broadcastable, NoSideEffect, NoQuantizableResult]> {
+def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
   let summary = "Greater operator";
 
   let description = [{
@@ -1195,8 +1194,6 @@
 
   let results = (outs AnyTensor:$output);
 
-  let builders = [TFL_ComparisonBinaryBuilder];
-
   let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
 
   let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
 
diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
index f1178302b9e..bab643309fe 100644
--- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir
@@ -318,25 +318,6 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
   // CHECK: return %[[fc]]
 }
 
-// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
-func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
-  %cst = constant dense<3.0> : tensor<40x40xf32>
-  %cst2 = constant dense<2.0> : tensor<40xf32>
-  %shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
-  %shape2 = constant dense<[40, 40]> : tensor<2xi32>
-
-  %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
-  %1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
-  %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
-  %3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
-
-  return %3 : tensor<40x40xf32>
-
-  // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
-  // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
-  // CHECK: return %[[fc]]
-}
-
 // CHECK-LABEL: @FuseFullyConnectedRelu
 func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
   %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
"DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 47e342f1397..1313bae97a1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -80,21 +80,6 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); } -// Returns whether if `type1` dimensions are the same as the ending dimensions -// of `type2`. This is more restricted than broadcastable. -bool IsTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); - if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank()) - return false; - auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend(); - auto i2 = full_type.getShape().rbegin(); - for (; i1 != e1; ++i1, ++i2) { - if (*i1 != *i2) return false; - } - return true; -} - bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool is_depthwise) { // Make sure the val tensor has shape where all dimensions are 1 except diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 129d0708eba..1a22c80d38c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -249,17 +249,9 @@ multiclass L2NormalizePatterns { foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] in defm : L2NormalizePatterns; -//===----------------------------------------------------------------------===// -// Binary ops patterns. -//===----------------------------------------------------------------------===// def AreBroadcastableTypes : ConstraintgetType(), $1->getType())">>; -def IsTailOfShape : ConstraintgetType(), $1->getType())">>; - -def HaveSameType : ConstraintgetType(), $1->getType()">>; - // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { @@ -274,58 +266,18 @@ multiclass FuseTileBroadcastIntoFollowingBinary { [(AreBroadcastableTypes $operand, $input)]>; } -// Multi-pattern consisting of matching stand-alone op or op followed by relu. -multiclass FusedBinaryActivationFuncOpPat { - foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], - [TFL_Relu6Op, TFL_AF_Relu6], - [TFL_Relu1Op, TFL_AF_Relu1]] in { - def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), - (BinaryOp $lhs, $rhs, actFnPair[1])>; - } -} - -foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { - defm : FuseTileBroadcastIntoFollowingBinary; - - // Instantiated FusedBinary patterns for the from-to pairs of ops. - defm : FusedBinaryActivationFuncOpPat; - - // Move binary op before reshape: reshape -> binary => binary -> reshape. - // This is valid only when the binary operand is constant and the shape is the - // tail of the other op and the intermediate result isn't used by other ops. 
-  def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
-                      (ConstantOp:$rhs $a), TFL_AF_None),
-            (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
-            [(IsTailOfShape $rhs, $lhs)]>;
-}
-
-foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
-                    TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
-                    TFL_GreaterEqualOp] in {
-  // Move binary op before reshape: reshape -> binary => binary -> reshape.
-  // This is valid only when the binary operand is constant and the shape is the
-  // tail of the other op and the intermediate result isn't used by other ops.
-  def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
-                      (ConstantOp:$rhs $a)),
-            (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
-            [(IsTailOfShape $rhs, $lhs)]>;
-}
+foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
+  in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
 
 // Returns shape of a ranked tensor.
 // if called without a ranked tensor it will fail.
 def GetShape: NativeCodeCall<"GetShape($0)">;
 
-// Convert squeeze to reshape
 def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
           (TFL_ReshapeOp $input,
            (ConstantOp (GetShape $squeeze_op))),
           [(AnyStaticShapeTensor $squeeze_op)]>;
 
-// Remove double reshapes if the final shape is the same as the input shape.
-def : Pat<(TFL_ReshapeOp:$output (TFL_ReshapeOp $input, $shape1), $shape2),
-          (replaceWithValue $input),
-          [(HaveSameType $output, $input)]>;
-
 class ValueEquals<string val> : Constraint<CPred<
   "$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
   "*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
@@ -342,6 +294,21 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
           (TFL_Relu1Op $input),
           [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
 
+// Multi-pattern consisting of matching stand-alone op or op followed by relu.
+multiclass FusedBinaryActivationFuncOpPat<Op BinaryOp> {
+  foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
+                       [TFL_Relu6Op, TFL_AF_Relu6],
+                       [TFL_Relu1Op, TFL_AF_Relu1]] in {
+    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
+              (BinaryOp $lhs, $rhs, actFnPair[1])>;
+  }
+}
+
+// Instantiated FusedBinary patterns for the from-to pairs of ops.
+foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
+                     TFL_MulOp, TFL_SubOp] in
+  defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
+
 // The constant folding in this pass might produce constant in the tf dialect.
 // This rule is to legalize these constant to the tfl dialect.
 def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
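
For reference, a minimal before/after sketch of the reordering described in the
commit message, borrowing the tensor shapes used by the removed
FuseFullyConnectedReshapeAddConst test. The function argument %arg0 (a
tensor<40x40xf32>) and the use of tfl.add are assumed for illustration; the
snippet is not part of the patch itself.

  // Before: the constant-operand add consumes the reshape result.
  %shape = constant dense<[1, 40, 40]> : tensor<3xi32>
  %cst = constant dense<2.0> : tensor<40xf32>
  %0 = "tfl.reshape"(%arg0, %shape) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
  %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>

  // After: the add is applied to the un-reshaped value and the reshape moves
  // after it. This is only done when the constant's shape (here 40) is a tail
  // of the reshape output shape and the reshape result has no other uses.
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40xf32>) -> tensor<40x40xf32>
  %1 = "tfl.reshape"(%0, %shape) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>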