Adds PReLU pattern & fixes a bug that prevented its legalization

PiperOrigin-RevId: 292230500
Change-Id: I365427126e09d0d80a5f0a839bd712d70255f8b3
Sachin Joglekar 2020-01-29 15:26:27 -08:00 committed by TensorFlower Gardener
parent 71c6f97e2d
commit 05ee75360e
5 changed files with 87 additions and 13 deletions


@@ -1301,6 +1301,19 @@ OpFoldResult AbsOp::fold(ArrayRef<Attribute> operands) {
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
OpFoldResult NegOp::fold(ArrayRef<Attribute> operands) {
Type result_type = getType();
// Only constant fold for tensor of f32 is implemented.
if (!IsF32ShapedType(result_type)) return nullptr;
auto compute = [](APFloat value) -> APFloat { return llvm::neg(value); };
return ConstFoldUnaryOp(result_type, operands[0], compute);
}
//===----------------------------------------------------------------------===//
// SinOp
//===----------------------------------------------------------------------===//
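Aside (illustrative, not part of this diff): the PReLU rewrite added later in this commit wraps the matched alpha constant in a tfl.neg, so giving NegOp a folder appears to be what lets that negation be constant-folded away before legalization. A minimal NumPy sketch of the elementwise f32 negation the fold performs:

import numpy as np

# Constant operand as it appears in the matched graph (f32 tensor).
neg_alpha = np.array([-0.2, -0.2, -0.2], dtype=np.float32)

# Folding tfl.neg of a constant reduces to elementwise negation.
folded = np.negative(neg_alpha)
print(folded)  # [0.2 0.2 0.2]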


@@ -1810,6 +1810,8 @@ def TFL_NegOp: TFL_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
let results = (outs AnyTensor:$y);
let hasOptions = 0b1;
let hasFolder = 1;
}
def TFL_PackOp : TFL_Op<"pack", [NoSideEffect, SameOperandsAndResultsScale]> {


@@ -758,6 +758,33 @@ func @leaky_relu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[RESULT:[0-9].*]] = "tfl.maximum"
}
// CHECK-LABEL: prelu_fusion
func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<3xf32>
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.prelu"
}
// CHECK-LABEL: prelu_not_fused
// Rank of alpha should be one less than input for PReLU, which is not the case.
func @prelu_not_fused(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
%alpha = constant dense<-0.2> : tensor<f32>
%0 = "tfl.relu"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%1 = "tfl.neg"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%2 = "tfl.relu"(%1) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%3 = "tfl.mul"(%alpha, %2) {fused_activation_function = "NONE"} : (tensor<f32>, tensor<2x3xf32>) -> tensor<2x3xf32>
%4 = "tfl.add"(%0, %3) {fused_activation_function = "NONE"} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
return %4 : tensor<2x3xf32>
// CHECK: %[[RESULT:[0-9].*]] = "tfl.relu"
}
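A quick numeric aside (illustrative, not part of the test file): the rank constraint reflects tfl.prelu expecting one alpha slope per trailing channel of the input, so a rank-1 alpha qualifies while the rank-0 alpha above is rejected even though it would also broadcast:

import numpy as np

x = np.zeros((2, 3), dtype=np.float32)

# prelu_fusion case: alpha has rank 1 == rank(x) - 1 and broadcasts per channel.
alpha_ok = np.full((3,), 0.2, dtype=np.float32)
print((x * alpha_ok).shape)   # (2, 3) -> one slope per channel

# prelu_not_fused case: a rank-0 alpha also broadcasts, but the rank check
# (0 != rank(x) - 1) fails, so the pattern declines to fuse.
alpha_scalar = np.float32(0.2)
print(alpha_scalar.ndim, x.ndim - 1)  # 0 1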
// CHECK-LABEL: NotfuseAddIntoConv2d_MultipleUsers
func @NotfuseAddIntoConv2d_MultipleUsers(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
%cst = constant dense<1.5> : tensor<16xf32>


@@ -54,6 +54,10 @@ def HasSameStaticShapesPred : CPred<"HasSameStaticShapes($0.getDefiningOp())">;
def HasSameStaticShapes : Constraint<HasSameStaticShapesPred, "op must have static same input shapes">;
def HasNotSameStaticShapes : Constraint<Neg<HasSameStaticShapesPred>, "op must have not static same input shapes">;
// Checks if the value has only one user.
// TODO(karimnosseir): Move to a common place?
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
//===----------------------------------------------------------------------===//
// Nullary ops patterns.
//===----------------------------------------------------------------------===//
@@ -198,16 +202,20 @@ def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>;
def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>;
// Multi-pattern consisting of matching stand-alone op or op followed by relu.
// TODO(karimnosseir): Can the activation part here be removed by modifying the
// very similar pass in optimize_patterns.td?
multiclass FusedBinaryActivationFuncOpPat<dag FromOp, dag ToOp> {
def : Pat<(FromOp AnyTensor:$l, AnyTensor:$r),
(ToOp $l, $r, TFL_AF_None)>;
foreach actFnPair = [[TF_ReluOp, TFL_AF_Relu],
[TF_Relu6Op, TFL_AF_Relu6]] in {
def : Pat<(actFnPair[0] (FromOp $lhs, $rhs)),
(ToOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (FromOp:$bin_out $lhs, $rhs)),
(ToOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $bin_out)]>;
// TODO: Maybe move these below to general pass?
def : Pat<(actFnPair[0] (ToOp $lhs, $rhs, TFL_AF_None)),
(ToOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (ToOp:$bin_out $lhs, $rhs, TFL_AF_None)),
(ToOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $bin_out)]>;
}
}
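An aside on the new HasOneUse constraints (illustrative, not from the patch): if the binary op's result has another consumer, folding the activation into that op would change what the other consumer sees, for example:

import numpy as np

a = np.array([-3.0, 2.0], dtype=np.float32)
b = np.array([1.0, 1.0], dtype=np.float32)

s = a + b                  # binary op result with two users
relu_s = np.maximum(s, 0)  # one user applies the activation
other = s * 2.0            # a second user wants the raw sum

# Fusing the activation into the add would silently change `other`:
fused = np.maximum(a + b, 0)
print(other)         # [-4.  6.]
print(fused * 2.0)   # [ 0.  6.]  -> differs, so fusion must be skipped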


@@ -26,26 +26,31 @@ def F32ElementsAttr : ElementsAttrBase<
def ExtractSingleElementAsFloat : NativeCodeCall<
"ExtractSingleElementAsFloat($_self.cast<ElementsAttr>())">;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
//===----------------------------------------------------------------------===//
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
// Multi-pattern consisting of matching stand-alone convolution op followed by
// activation op.
multiclass FuseActFnIntoConvOpPat<dag ActFnOp, dag ActFnAttr> {
def : Pat<(ActFnOp (TFL_Conv2DOp $input, $filter, $bias,
def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w)),
(TFL_Conv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w)>;
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp $input, $filter, $bias,
$padding, $stride_h, $stride_w),
[(HasOneUse $conv_out)]>;
def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias,
$h_factor, $w_factor, TFL_AF_None,
$padding, $stride_h, $stride_w,
$multiplier)),
(TFL_DepthwiseConv2DOp $input, $filter, $bias,
$h_factor, $w_factor, ActFnAttr,
$padding, $stride_h, $stride_w,
$multiplier)>;
$multiplier),
[(HasOneUse $conv_out)]>;
}
// TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused
@@ -61,9 +66,6 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
class CanFuseConvOrDepthwiseConv<string is_depthwise> : Constraint<
CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>;
// Checks if the value has only one user.
def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
// If we see a binary op (add, sub) op adding a constant value to a convolution
// op with constant bias, we can fuse the binary op into the convolution op by
// constant folding the bias and the binary op's constant operand. The following
@@ -291,8 +293,9 @@ multiclass FusedBinaryActivationFuncOpPat<dag BinaryOp> {
foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu],
[TFL_Relu6Op, TFL_AF_Relu6],
[TFL_Relu1Op, TFL_AF_Relu1]] in {
def : Pat<(actFnPair[0] (BinaryOp $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1])>;
def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)),
(BinaryOp $lhs, $rhs, actFnPair[1]),
[(HasOneUse $binary_out)]>;
}
}
@@ -382,6 +385,27 @@ def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1,
(EqualOperands $input1, $input2),
(HasOneUse $mul_out)]>;
// Checks if the operand0's rank is one less than operand1's rank.
def PReluAlphaRankCheck : Constraint<
CPred<"$0.getType().cast<ShapedType>().getRank() == "
"$1.getType().cast<ShapedType>().getRank() - 1">>;
// PReLU pattern from Keras:
// f(x) = Relu(x) + (-alpha * Relu(-x))
def : Pat<(TFL_AddOp
(TFL_ReluOp:$relu_out $input1),
(TFL_MulOp:$mul_out
(TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)),
$neg_alpha,
TFL_AF_None),
TFL_AF_None),
(TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)),
[(EqualOperands $input1, $input2),
(PReluAlphaRankCheck $neg_alpha, $input1),
(HasOneUse $relu_out),
(HasOneUse $mul_out),
(HasOneUse $input_neg_out)]>;
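To make the rewrite concrete (an illustrative NumPy sketch of this reading of the pattern, not code from the commit): the matched subgraph computes Relu(x) + neg_alpha * Relu(-x), where the graph stores -alpha, so the replacement negates that constant again to recover the alpha fed to tfl.prelu:

import numpy as np

x = np.array([[1.0, -2.0, 3.0],
              [-4.0, 5.0, -6.0]], dtype=np.float32)
neg_alpha = np.full((3,), -0.2, dtype=np.float32)  # constant stored in the graph

# Matched Keras-style subgraph: Relu(x) + neg_alpha * Relu(-x)
matched = np.maximum(x, 0) + neg_alpha * np.maximum(-x, 0)

# Replacement: tfl.prelu(x, alpha) with alpha = -neg_alpha
alpha = -neg_alpha
prelu = np.where(x >= 0, x, alpha * x)

assert np.allclose(matched, prelu)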
// The constant folding in this pass might produce constant in the tf dialect.
// This rule is to legalize these constant to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;