Add transformations for TF/TFLite:

1) Reorder 2 successive adds with constants to allow constant folding.
2) Remove trivial Add/Sub/Mul/Div.

PiperOrigin-RevId: 309480295
Change-Id: Ie84f5fd4da79dc80a3a5ee806051bd63b5b49888
This commit is contained in:
Karim Nosir 2020-05-01 15:00:32 -07:00 committed by TensorFlower Gardener
parent 6b65281daa
commit 412b12ea8d
5 changed files with 182 additions and 0 deletions

View File

@ -958,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
// Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32>
// Fusing: return
}
func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%cst_1 = constant dense<2.0> : tensor<2x2xf32>
%0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: ReorderAddWithConstant
// CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32>
// CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
}

View File

@ -457,3 +457,21 @@ def : Pat<(TFL_AddOp
// The constant folding in this pass might produce constant in the tf dialect.
// This rule is to legalize these constant to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;
// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
// \--> $constantB
// To
// Add --> $input
// \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
def : Pat<(TFL_AddOp
(TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
(ConstantOp $b), ActFun),
(TFL_AddOp $input,
(TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
ActFun),
[(HasOneUse $first_output)]>;
}

View File

@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
@ -1963,6 +1965,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@ -4960,6 +4964,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasFolder = 1;
}
def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@ -8119,6 +8125,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> {

View File

@ -494,6 +494,48 @@ LogicalResult FoldOperandsPermutation(
return success();
}
//===----------------------------------------------------------------------===//
// Rewrite Pattern for removing trivial Arithmetic op.
//===----------------------------------------------------------------------===//
namespace {
// Utility methods that returns Identity value to use for selected ops.
APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
template <typename OP>
OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
DenseFPElementsAttr rhs_value;
auto constant_val = arithmetic_op.y();
if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
return {};
}
auto result_op_type = arithmetic_op.getResult().getType();
auto lhs_type = arithmetic_op.x().getType();
if (!result_op_type.template isa<ShapedType>() ||
!lhs_type.template isa<ShapedType>() ||
!result_op_type.template cast<ShapedType>().hasStaticShape()) {
return {};
}
// We only handle non-broadcastable case.
if (result_op_type != lhs_type) {
return {};
}
auto identity_val = GetIdentity(arithmetic_op);
for (auto it = rhs_value.float_value_begin();
it != rhs_value.float_value_end(); ++it) {
if (*it != identity_val) return {};
}
return arithmetic_op.x();
}
} // namespace
namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
} // namespace
@ -525,6 +567,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
}
OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
return TrivialArithmeticOpFolder<AddV2Op>(*this);
}
//===----------------------------------------------------------------------===//
// AllOp
//===----------------------------------------------------------------------===//
@ -1271,6 +1317,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<DivWithSqrtDivisor>(context);
}
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
return TrivialArithmeticOpFolder<DivOp>(*this);
}
//===----------------------------------------------------------------------===//
// DynamicStitchOp
//===----------------------------------------------------------------------===//
@ -1936,6 +1986,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
return success();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return TrivialArithmeticOpFolder<MulOp>(*this);
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
@ -2904,6 +2962,10 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
results.insert<SubOfNeg>(context);
}
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return TrivialArithmeticOpFolder<SubOp>(*this);
}
//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//

View File

@ -251,3 +251,84 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) ->
// CHECK-NEXT: return [[cst]] : tensor<2xi32>
return %0: tensor<2xi32>
}
func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialAdd
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialAddV2
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialSub
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialMul
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
%1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: RemoveTrivialDiv
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %1 : tensor<2x2xf32>
// CHECK-LABEL: DontRemoveTrivialAdd
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: return %[[RESULT]] : tensor<2x2xf32>
}
func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
%1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32> , tensor<2x2xf32>) -> tensor<?x?xf32>
return %1 :tensor<?x?xf32>
// CHECK-LABEL: DontRemoveTrivialAdd2
// CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
// CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x?xf32>
}