Adds PReLU pattern & fixes a bug that prevented its legalization
PiperOrigin-RevId: 292230500
Change-Id: I365427126e09d0d80a5f0a839bd712d70255f8b3

parent 71c6f97e2d
commit 05ee75360e
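For context (this paragraph and sketch are editorial, not part of the commit): the new rewrite targets the Keras-style PReLU decomposition f(x) = Relu(x) + (-alpha * Relu(-x)) and folds it into a single tfl.prelu op. A minimal NumPy sketch, assuming standard PReLU semantics, that checks the identity the pattern relies on:

    import numpy as np

    def relu(x):
        return np.maximum(x, 0.0)

    def prelu(x, alpha):
        # PReLU: pass positives through, scale negatives by alpha.
        return np.where(x >= 0, x, alpha * x)

    x = np.array([[-3.0, -0.5, 2.0], [1.0, -2.0, 0.0]], dtype=np.float32)
    alpha = np.float32(0.2)

    decomposed = relu(x) + (-alpha) * relu(-x)       # the Keras-style graph being matched
    print(np.allclose(decomposed, prelu(x, alpha)))  # True

For x >= 0 the second term vanishes; for x < 0 the first term vanishes and (-alpha) * (-x) = alpha * x, which is exactly PReLU.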
@@ -1301,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
   return ConstFoldUnaryOp(result_type, operands[0], compute);
 }
 
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
+  Type result_type = getType();
+  // Only constant fold for tensor of f32 is implemented.
+  if (!IsF32ShapedType(result_type)) return nullptr;
+
+  auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
+  return ConstFoldUnaryOp(result_type, operands[0], compute);
+}
+
 //===----------------------------------------------------------------------===//
 // SinOp
 //===----------------------------------------------------------------------===//
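The hunk above adds a constant folder for tfl.neg (the .td change below turns it on via hasFolder). A toy Python sketch of the folding rule, assuming the semantics shown in the C++ above; this is not the real ConstFoldUnaryOp:

    import numpy as np

    def fold_tfl_neg(values, dtype):
        # Toy model of the new folder: fold only f32 tensors (the IsF32ShapedType
        # guard above); otherwise report "no fold" by returning None.
        if dtype != np.float32:
            return None
        return np.negative(values)  # elementwise negation, like llvm::neg per element

    alpha = np.array([-0.2, -0.2, -0.2], dtype=np.float32)
    print(fold_tfl_neg(alpha, alpha.dtype))          # [0.2 0.2 0.2]
    print(fold_tfl_neg(np.array([1, 2]), np.int32))  # None (left unfolded)

This folder is what lets the tfl.neg that the PReLU rewrite inserts around the constant alpha collapse into a constant, so the fused graph can be legalized.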
@@ -1810,6 +1810,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
   let results = (outs AnyTensor:$y);
 
   let hasOptions = 0b1;
+
+  let hasFolder = 1;
 }
 
 def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {
@@ -758,6 +758,33 @@ func @leaky_relu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: %[[RESULT:[0-9].*]] = "tfl.maximum"
 }
 
+// CHECK-LABEL: prelu_fusion
+func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %alpha = constant dense<-0.2> : tensor<3xf32>
+  %0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %4 : tensor<2x3xf32>
+
+  // CHECK: %[[RESULT:[0-9].*]] = "tfl.prelu"
+}
+
+// CHECK-LABEL: prelu_not_fused
+// Rank of alpha should be one less than input for PReLU, which is not the case.
+func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  %alpha = constant dense<-0.2> : tensor<f32>
+  %0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  %3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %4 : tensor<2x3xf32>
+
+  // CHECK: %[[RESULT:[0-9].*]] = "tfl.relu"
+}
+
 // CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
 func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
   %cst = constant dense<1.5> : tensor<16xf32>
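The two new tests above exercise the rank constraint on alpha: fusion only happens when alpha's rank is exactly one less than the input's rank. A small sketch of that check (hypothetical helper name, mirroring the PReluAlphaRankCheck constraint added further below):

    def prelu_alpha_rank_ok(alpha_shape, input_shape):
        # Mirrors the PReluAlphaRankCheck constraint added in this change:
        # alpha's rank must be exactly one less than the input's rank.
        return len(alpha_shape) == len(input_shape) - 1

    print(prelu_alpha_rank_ok((3,), (2, 3)))  # True:  prelu_fusion is rewritten to tfl.prelu
    print(prelu_alpha_rank_ok((), (2, 3)))    # False: prelu_not_fused keeps the relu/neg/mul/add graph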
@@ -54,6 +54,10 @@ def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
 def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
 def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
 
+// Checks if the value has only one user.
+// TODO(karimnosseir): Move to a common place?
+def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
+
 //===----------------------------------------------------------------------===//
 // Nullary ops patterns.
 //===----------------------------------------------------------------------===//
@@ -198,16 +202,20 @@ def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
 def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
 
 // Multi-pattern consisting of matching stand-alone op or op followed by relu.
 // TODO(karimnosseir): Can the activation part here be removed by modifying the
 // very similar pass in optimize_patterns.td?
 multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
   def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
             (ToOp $l, $r, TFL_AF_None)>;
   foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
                        [TF_Relu6Op, TFL_AF_Relu6]] in {
-    def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
-              (ToOp $lhs, $rhs, actFnPair[1])>;
+    def : Pat<(actFnPair[0] (FromOp:$bin_out $lhs, $rhs)),
+              (ToOp $lhs, $rhs, actFnPair[1]),
+              [(HasOneUse $bin_out)]>;
     // TODO: Maybe move these below to general pass?
-    def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
-              (ToOp $lhs, $rhs, actFnPair[1])>;
+    def : Pat<(actFnPair[0] (ToOp:$bin_out $lhs, $rhs, TFL_AF_None)),
+              (ToOp $lhs, $rhs, actFnPair[1]),
+              [(HasOneUse $bin_out)]>;
   }
 }
@@ -26,26 +26,31 @@ def F32ElementsAttr : ElementsAttrBase<
 def ExtractSingleElementAsFloat : NativeCodeCall<
     "ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
 
+// Checks if the value has only one user.
+def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
+
 //===----------------------------------------------------------------------===//
 // Ternary ops patterns.
 //===----------------------------------------------------------------------===//
 // Multi-pattern consisting of matching stand-alone convolution op followed by
 // activation op.
 multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
-  def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
+  def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
                          $h_factor, $w_factor, TFL_AF_None,
                          $padding, $stride_h, $stride_w)),
             (TFL_Conv2DOp $input, $filter, $bias,
                           $h_factor, $w_factor, ActFnAttr,
-                          $padding, $stride_h, $stride_w)>;
-  def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
+                          $padding, $stride_h, $stride_w),
+            [(HasOneUse $conv_out)]>;
+  def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
                          $h_factor, $w_factor, TFL_AF_None,
                          $padding, $stride_h, $stride_w,
                          $multiplier)),
             (TFL_DepthwiseConv2DOp $input, $filter, $bias,
                           $h_factor, $w_factor, ActFnAttr,
                           $padding, $stride_h, $stride_w,
-                          $multiplier)>;
+                          $multiplier),
+            [(HasOneUse $conv_out)]>;
 }
 
 // TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
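The :$conv_out bindings and [(HasOneUse $conv_out)] constraints above (and the analogous $bin_out/$binary_out guards in the other pattern files) make activation fusion fire only when the convolution or binary result has a single user. A toy sketch of the intent, assuming a simple use-list model (hypothetical helper, not the MLIR API):

    def can_fuse_activation(producer_result_uses):
        # Fusing the activation rewrites the producer's output in place, so it is
        # only safe when the activation is the sole reader of that output;
        # otherwise other users would observe the activated value instead of the
        # raw result.
        return len(producer_result_uses) == 1

    print(can_fuse_activation(["tfl.relu"]))             # True:  fuse the activation
    print(can_fuse_activation(["tfl.relu", "tfl.add"]))  # False: keep the ops separate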
@@ -61,9 +66,6 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
 class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
   CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
 
-// Checks if the value has only one user.
-def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
-
 // If we see a binary op (add, sub) op adding a constant value to a convolution
 // op with constant bias, we can fuse the binary op into the convolution op by
 // constant folding the bias and the binary op's constant operand. The following
@@ -291,8 +293,9 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
   foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
                        [TFL_Relu6Op, TFL_AF_Relu6],
                        [TFL_Relu1Op, TFL_AF_Relu1]] in {
-    def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
-              (BinaryOp $lhs, $rhs, actFnPair[1])>;
+    def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
+              (BinaryOp $lhs, $rhs, actFnPair[1]),
+              [(HasOneUse $binary_out)]>;
   }
 }
 
@@ -382,6 +385,27 @@ def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
            (EqualOperands $input1, $input2),
            (HasOneUse $mul_out)]>;
 
+// Checks if the operand0's rank is one less than operand1's rank.
+def PReluAlphaRankCheck : Constraint<
+  CPred<"$0.getType().cast<ShapedType>().getRank() == "
+        "$1.getType().cast<ShapedType>().getRank() - 1">>;
+
+// PReLU pattern from Keras:
+// f(x) = Relu(x) + (-alpha * Relu(-x))
+def : Pat<(TFL_AddOp
+           (TFL_ReluOp:$relu_out $input1),
+           (TFL_MulOp:$mul_out
+              (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
+              $neg_alpha,
+              TFL_AF_None),
+           TFL_AF_None),
+          (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
+          [(EqualOperands $input1, $input2),
+           (PReluAlphaRankCheck $neg_alpha, $input1),
+           (HasOneUse $relu_out),
+           (HasOneUse $mul_out),
+           (HasOneUse $input_neg_out)]>;
+
 // The constant folding in this pass might produce constant in the tf dialect.
 // This rule is to legalize these constant to the tfl dialect.
 def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
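Note how the pattern above handles the sign of alpha: the constant reachable in the matched graph is -alpha (it is the multiplicand of the mul), so the replacement builds TFL_PReluOp with (TFL_NegOp $neg_alpha), and the new NegOp folder collapses that neg-of-a-constant into a plain constant so the result can be legalized. A NumPy sketch of that bookkeeping, assuming standard PReLU semantics:

    import numpy as np

    x = np.array([[-1.0, 4.0, -2.0]], dtype=np.float32)
    neg_alpha = np.array([-0.2, -0.2, -0.2], dtype=np.float32)  # dense<-0.2>, as in the test

    # What the matched graph computes: Relu(x) + (-alpha * Relu(-x)).
    matched = np.maximum(x, 0.0) + neg_alpha * np.maximum(-x, 0.0)

    # What the replacement computes: tfl.prelu(x, neg(neg_alpha)); the neg of the
    # constant is folded away at compile time by the new NegOp folder.
    alpha = np.negative(neg_alpha)
    rewritten = np.where(x >= 0, x, alpha * x)

    print(np.allclose(matched, rewritten))  # True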