diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 4df893c76e0..6cba2413b83 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1181,8 +1181,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> { let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_GreaterOp : TFL_Op<"greater", [ - Broadcastable, NoSideEffect, NoQuantizableResult]> { +def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -1195,8 +1194,6 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let results = (outs AnyTensor:$output); - let builders = [TFL_ComparisonBinaryBuilder]; - let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index f1178302b9e..bab643309fe 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -318,25 +318,6 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf // CHECK: return %[[fc]] } -// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst -func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { - %cst = constant dense<3.0> : tensor<40x40xf32> - %cst2 = constant dense<2.0> : tensor<40xf32> - %shape1 = constant dense<[1, 40, 40]> : tensor<3xi32> - %shape2 = constant dense<[40, 40]> : tensor<2xi32> - - %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>) - %1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32> - %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32> - %3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32> - - return %3 : tensor<40x40xf32> - - // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32> - // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) - // CHECK: return %[[fc]] -} - // CHECK-LABEL: @FuseFullyConnectedRelu func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 47e342f1397..1313bae97a1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -80,21 +80,6 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { return OpTrait::util::getBroadcastedType(a, b) != Type(); } -// Returns whether if `type1` dimensions are the same as the ending dimensions -// of `type2`. This is more restricted than broadcastable. -bool IsTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); - if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank()) - return false; - auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend(); - auto i2 = full_type.getShape().rbegin(); - for (; i1 != e1; ++i1, ++i2) { - if (*i1 != *i2) return false; - } - return true; -} - bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool is_depthwise) { // Make sure the val tensor has shape where all dimensions are 1 except diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 129d0708eba..1a22c80d38c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -249,17 +249,9 @@ multiclass L2NormalizePatterns { foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]] in defm : L2NormalizePatterns; -//===----------------------------------------------------------------------===// -// Binary ops patterns. -//===----------------------------------------------------------------------===// def AreBroadcastableTypes : ConstraintgetType(), $1->getType())">>; -def IsTailOfShape : ConstraintgetType(), $1->getType())">>; - -def HaveSameType : ConstraintgetType(), $1->getType()">>; - // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { @@ -274,58 +266,18 @@ multiclass FuseTileBroadcastIntoFollowingBinary { [(AreBroadcastableTypes $operand, $input)]>; } -// Multi-pattern consisting of matching stand-alone op or op followed by relu. -multiclass FusedBinaryActivationFuncOpPat { - foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], - [TFL_Relu6Op, TFL_AF_Relu6], - [TFL_Relu1Op, TFL_AF_Relu1]] in { - def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), - (BinaryOp $lhs, $rhs, actFnPair[1])>; - } -} - -foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { - defm : FuseTileBroadcastIntoFollowingBinary; - - // Instantiated FusedBinary patterns for the from-to pairs of ops. - defm : FusedBinaryActivationFuncOpPat; - - // Move binary op before reshape: reshape -> binary => binary -> reshape. - // This is valid only when the binary operand is constant and the shape is the - // tail of the other op and the intermediate result isn't used by other ops. - def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), - (ConstantOp:$rhs $a), TFL_AF_None), - (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape), - [(IsTailOfShape $rhs, $lhs)]>; -} - -foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, - TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp, - TFL_GreaterEqualOp] in { - // Move binary op before reshape: reshape -> binary => binary -> reshape. - // This is valid only when the binary operand is constant and the shape is the - // tail of the other op and the intermediate result isn't used by other ops. - def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), - (ConstantOp:$rhs $a)), - (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape), - [(IsTailOfShape $rhs, $lhs)]>; -} +foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] + in defm : FuseTileBroadcastIntoFollowingBinary; // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; -// Convert squeeze to reshape def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))), [(AnyStaticShapeTensor $squeeze_op)]>; -// Convert remove double reshapes if the final shape is as same the input shape. -def : Pat<(TFL_ReshapeOp:$output (TFL_ReshapeOp $input, $shape1), $shape2), - (replaceWithValue $input), - [(HaveSameType $output, $input)]>; - class ValueEquals : Constraint().getNumElements() == 1 &&" "*$0.cast().getValues().begin() == " # val>>; @@ -342,6 +294,21 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, (TFL_Relu1Op $input), [(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>; +// Multi-pattern consisting of matching stand-alone op or op followed by relu. +multiclass FusedBinaryActivationFuncOpPat { + foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], + [TFL_Relu6Op, TFL_AF_Relu6], + [TFL_Relu1Op, TFL_AF_Relu1]] in { + def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)), + (BinaryOp $lhs, $rhs, actFnPair[1])>; + } +} + +// Instantiated FusedBinary patterns for the from-to pairs of ops. +foreach BinaryOps = [TFL_AddOp, TFL_DivOp, + TFL_MulOp, TFL_SubOp] in + defm : FusedBinaryActivationFuncOpPat; + // The constant folding in this pass might produce constant in the tf dialect. // This rule is to legalize these constant to the tfl dialect. def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;