move binary ops before reshape
If the other operand of the binary op is a constant and its shape is broadcastable to the other operand and also the reshape has only one use, the order of these two ops can be switched. This implements the MoveBinaryOperatorBeforeReshape pass in TOCO. PiperOrigin-RevId: 286584501 Change-Id: I467aad941a590d2b97678f5a6762a6beca262cfe
This commit is contained in:
parent
8a001a4521
commit
09a10c139a
@ -1181,8 +1181,7 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> {
|
||||
let builders = [TFL_BroadcastableBinaryBuilder];
|
||||
}
|
||||
|
||||
def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
Broadcastable, NoSideEffect, NoQuantizableResult]> {
|
||||
def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> {
|
||||
let summary = "Greater operator";
|
||||
|
||||
let description = [{
|
||||
@ -1195,8 +1194,6 @@ def TFL_GreaterOp : TFL_Op<"greater", [
|
||||
|
||||
let results = (outs AnyTensor:$output);
|
||||
|
||||
let builders = [TFL_ComparisonBinaryBuilder];
|
||||
|
||||
let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }];
|
||||
|
||||
let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }];
|
||||
|
@ -318,25 +318,6 @@ func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedReshapeAddConst
|
||||
func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
|
||||
%cst = constant dense<3.0> : tensor<40x40xf32>
|
||||
%cst2 = constant dense<2.0> : tensor<40xf32>
|
||||
%shape1 = constant dense<[1, 40, 40]> : tensor<3xi32>
|
||||
%shape2 = constant dense<[40, 40]> : tensor<2xi32>
|
||||
|
||||
%0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>)
|
||||
%1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32>
|
||||
%2 = "tfl.add"(%1, %cst2) {fused_activation_function = "NONE"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32>
|
||||
%3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32>
|
||||
|
||||
return %3 : tensor<40x40xf32>
|
||||
|
||||
// CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32>
|
||||
// CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]])
|
||||
// CHECK: return %[[fc]]
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @FuseFullyConnectedRelu
|
||||
func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> {
|
||||
%0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32>
|
||||
|
@ -80,21 +80,6 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) {
|
||||
return OpTrait::util::getBroadcastedType(a, b) != Type();
|
||||
}
|
||||
|
||||
// Returns whether if `type1` dimensions are the same as the ending dimensions
|
||||
// of `type2`. This is more restricted than broadcastable.
|
||||
bool IsTailOfShape(Type type1, Type type2) {
|
||||
auto tail_type = type1.dyn_cast<ShapedType>();
|
||||
auto full_type = type2.dyn_cast<ShapedType>();
|
||||
if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank())
|
||||
return false;
|
||||
auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend();
|
||||
auto i2 = full_type.getShape().rbegin();
|
||||
for (; i1 != e1; ++i1, ++i2) {
|
||||
if (*i1 != *i2) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val,
|
||||
bool is_depthwise) {
|
||||
// Make sure the val tensor has shape where all dimensions are 1 except
|
||||
|
@ -249,17 +249,9 @@ multiclass L2NormalizePatterns<dag FirstOp, dag SecondOp> {
|
||||
foreach L2NormalizePairs = [[TFL_MulOp, TFL_RsqrtOp], [TFL_DivOp, TFL_SqrtOp]]
|
||||
in defm : L2NormalizePatterns<L2NormalizePairs[0], L2NormalizePairs[1]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary ops patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def AreBroadcastableTypes : Constraint<CPred<
|
||||
"TFL::IsBroadcastableElementsAttrAndType($0->getType(), $1->getType())">>;
|
||||
|
||||
def IsTailOfShape : Constraint<CPred<
|
||||
"TFL::IsTailOfShape($0->getType(), $1->getType())">>;
|
||||
|
||||
def HaveSameType : Constraint<CPred<"$0->getType(), $1->getType()">>;
|
||||
|
||||
// Pattern for skipping Tile if it is mainly for broadcasting and the
|
||||
// Op is already supporting broadcasting.
|
||||
multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
@ -274,58 +266,18 @@ multiclass FuseTileBroadcastIntoFollowingBinary<dag BinaryOp> {
|
||||
[(AreBroadcastableTypes $operand, $input)]>;
|
||||
}
|
||||
|
||||
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in {
|
||||
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
|
||||
(BinaryOp $lhs, $rhs, actFnPair[1])>;
|
||||
}
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in {
|
||||
defm : FuseTileBroadcastIntoFollowingBinary<BinaryOp>;
|
||||
|
||||
// Instantiated FusedBinary patterns for the from-to pairs of ops.
|
||||
defm : FusedBinaryActivationFuncOpPat<BinaryOp>;
|
||||
|
||||
// Move binary op before reshape: reshape -> binary => binary -> reshape.
|
||||
// This is valid only when the binary operand is constant and the shape is the
|
||||
// tail of the other op and the intermediate result isn't used by other ops.
|
||||
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a), TFL_AF_None),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape),
|
||||
[(IsTailOfShape $rhs, $lhs)]>;
|
||||
}
|
||||
|
||||
foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp,
|
||||
TFL_MaximumOp, TFL_LessOp, TFL_LessEqualOp, TFL_GreaterOp,
|
||||
TFL_GreaterEqualOp] in {
|
||||
// Move binary op before reshape: reshape -> binary => binary -> reshape.
|
||||
// This is valid only when the binary operand is constant and the shape is the
|
||||
// tail of the other op and the intermediate result isn't used by other ops.
|
||||
def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)),
|
||||
(ConstantOp:$rhs $a)),
|
||||
(TFL_ReshapeOp (BinaryOp $input, $rhs), $shape),
|
||||
[(IsTailOfShape $rhs, $lhs)]>;
|
||||
}
|
||||
foreach BroadcastingOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp]
|
||||
in defm : FuseTileBroadcastIntoFollowingBinary<BroadcastingOp>;
|
||||
|
||||
// Returns shape of a ranked tensor.
|
||||
// if called without a ranked tensor it will fail.
|
||||
def GetShape: NativeCodeCall<"GetShape($0)">;
|
||||
|
||||
// Convert squeeze to reshape
|
||||
def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims),
|
||||
(TFL_ReshapeOp $input,
|
||||
(ConstantOp (GetShape $squeeze_op))),
|
||||
[(AnyStaticShapeTensor $squeeze_op)]>;
|
||||
|
||||
// Convert remove double reshapes if the final shape is as same the input shape.
|
||||
def : Pat<(TFL_ReshapeOp:$output (TFL_ReshapeOp $input, $shape1), $shape2),
|
||||
(replaceWithValue $input),
|
||||
[(HaveSameType $output, $input)]>;
|
||||
|
||||
class ValueEquals<string val> : Constraint<CPred<
|
||||
"$0.cast<DenseElementsAttr>().getNumElements() == 1 &&"
|
||||
"*$0.cast<DenseElementsAttr>().getValues<float>().begin() == " # val>>;
|
||||
@ -342,6 +294,21 @@ def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input,
|
||||
(TFL_Relu1Op $input),
|
||||
[(ValueEquals<"-1"> $NegOne), (ValueEquals<"1"> $One)]>;
|
||||
|
||||
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
|
||||
multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
|
||||
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
|
||||
[TFL_Relu6Op, TFL_AF_Relu6],
|
||||
[TFL_Relu1Op, TFL_AF_Relu1]] in {
|
||||
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
|
||||
(BinaryOp $lhs, $rhs, actFnPair[1])>;
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiated FusedBinary patterns for the from-to pairs of ops.
|
||||
foreach BinaryOps = [TFL_AddOp, TFL_DivOp,
|
||||
TFL_MulOp, TFL_SubOp] in
|
||||
defm : FusedBinaryActivationFuncOpPat<BinaryOps>;
|
||||
|
||||
// The constant folding in this pass might produce constant in the tf dialect.
|
||||
// This rule is to legalize these constant to the tfl dialect.
|
||||
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
|
||||
|
Loading…
Reference in New Issue
Block a user