Add transformations for TF/TFLite:

1) Reorder two successive adds with constants to allow constant folding.
2) Remove trivial Add/Sub/Mul/Div (x + 0, x - 0, x * 1, x / 1).

PiperOrigin-RevId: 309480295
Change-Id: Ie84f5fd4da79dc80a3a5ee806051bd63b5b49888
Parent: 6b65281daa
Commit: 412b12ea8d
@@ -958,3 +958,16 @@ func @FusingdivRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32>
  // Fusing: %[[div2:[0-9].*]] = tfl.div %[[relu]], %[[div1]] {fused_activation_function = "RELU6"} : tensor<1xf32>
  // Fusing: return
}

func @ReorderAddWithConstant(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<1.0> : tensor<2x2xf32>
  %cst_1 = constant dense<2.0> : tensor<2x2xf32>
  %0 = "tfl.add"(%arg0, %cst) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tfl.add"(%0, %cst_1) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: ReorderAddWithConstant
  // CHECK: %[[CONST:.*]] = constant dense<3.000000e+00> : tensor<2x2xf32>
  // CHECK: %[[RESULT:.*]] = tfl.add %arg0, %[[CONST]] {fused_activation_function = "NONE"} : tensor<2x2xf32>
}
@@ -457,3 +457,21 @@ def : Pat<(TFL_AddOp
// The constant folding in this pass might produce constants in the tf dialect.
// This rule legalizes these constants to the tfl dialect.
def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>;

// Reorders adds to allow constant folding.
// Add --> Add $input, $constantA
//     \--> $constantB
// To
// Add --> $input
//     \--> Add ($constantA, $constantB)
foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in {
  def : Pat<(TFL_AddOp
              (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None),
              (ConstantOp $b), ActFun),
            (TFL_AddOp $input,
              (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None),
              ActFun),
            [(HasOneUse $first_output)]>;
}
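To illustrate what the reordering pattern above buys, here is a minimal standalone C++ sketch (not part of this change, and not the MLIR implementation): the helper names AddTwiceOriginal and AddTwiceReordered are hypothetical, but they mirror (x + a) + b being rewritten to x + (a + b) so the constant sum can be folded once instead of running a second element-wise add.

// Hypothetical, self-contained sketch of the reassociation idea:
// (x + a) + b ==> x + (a + b) when a and b are constants.
#include <cstddef>
#include <cstdio>
#include <vector>

// Before reordering: two element-wise adds over the data.
std::vector<float> AddTwiceOriginal(const std::vector<float>& x, float a, float b) {
  std::vector<float> tmp(x.size()), out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) tmp[i] = x[i] + a;   // first add
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = tmp[i] + b; // second add
  return out;
}

// After reordering: the constants are combined once, then a single add runs.
std::vector<float> AddTwiceReordered(const std::vector<float>& x, float a, float b) {
  const float folded = a + b;  // this is what constant folding precomputes
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] + folded;
  return out;
}

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const auto original = AddTwiceOriginal(x, 1.0f, 2.0f);
  const auto reordered = AddTwiceReordered(x, 1.0f, 2.0f);
  for (std::size_t i = 0; i < x.size(); ++i)
    std::printf("%f %f\n", original[i], reordered[i]);  // identical outputs
  return 0;
}

Note that the TableGen pattern only fires when the intermediate add has a single use (the HasOneUse constraint), so the rewrite never duplicates work for other users of that value.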
@@ -160,6 +160,8 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let hasFolder = 1;
}

def TF_AllOp : TF_Op<"All", [NoSideEffect]> {
@@ -1963,6 +1965,8 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape]>,
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let hasFolder = 1;
}

def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -4960,6 +4964,8 @@ def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShap
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasFolder = 1;
}

def TF_MulNoNanOp : TF_Op<"MulNoNan", [NoSideEffect, ResultsBroadcastableShape]>,
@@ -8119,6 +8125,8 @@ def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape]>,
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;

  let hasCanonicalizer = 1;

  let hasFolder = 1;
}

def TF_SumOp : TF_Op<"Sum", [NoSideEffect]> {
@@ -494,6 +494,48 @@ LogicalResult FoldOperandsPermutation(
  return success();
}

//===----------------------------------------------------------------------===//
// Rewrite pattern for removing trivial arithmetic ops.
//===----------------------------------------------------------------------===//

namespace {
// Utility methods that return the identity value to use for the selected ops.

APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }

// Folder that returns the LHS of an arithmetic op if the RHS is a constant
// known to be the identity (e.g. x + 0).
template <typename OP>
OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
  DenseFPElementsAttr rhs_value;
  auto constant_val = arithmetic_op.y();
  if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
    return {};
  }
  auto result_op_type = arithmetic_op.getResult().getType();
  auto lhs_type = arithmetic_op.x().getType();
  if (!result_op_type.template isa<ShapedType>() ||
      !lhs_type.template isa<ShapedType>() ||
      !result_op_type.template cast<ShapedType>().hasStaticShape()) {
    return {};
  }
  // We only handle the non-broadcasting case.
  if (result_op_type != lhs_type) {
    return {};
  }
  auto identity_val = GetIdentity(arithmetic_op);
  for (auto it = rhs_value.float_value_begin();
       it != rhs_value.float_value_end(); ++it) {
    if (*it != identity_val) return {};
  }

  return arithmetic_op.x();
}
}  // namespace

namespace {
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc"
}  // namespace
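As a rough mental model of what TrivialArithmeticOpFolder checks, here is a hedged, standalone C++ sketch that works over plain vectors rather than MLIR attributes and types; the names OpKind, IdentityFor, and IsTrivial are illustrative only. It answers the same question the folder does: is the constant RHS entirely the op's identity element, with shapes matching exactly, so the op can be replaced by its left operand?

// Illustrative sketch only: mirrors the fold's identity check with plain
// std::vector<float> in place of DenseFPElementsAttr and ShapedType.
#include <cassert>
#include <vector>

enum class OpKind { kAdd, kSub, kMul, kDiv };

// 0 is the identity for Add/Sub (x + 0, x - 0); 1 is for Mul/Div (x * 1, x / 1).
float IdentityFor(OpKind op) {
  return (op == OpKind::kMul || op == OpKind::kDiv) ? 1.0f : 0.0f;
}

// True when op(lhs, rhs) is guaranteed to equal lhs element-wise.
bool IsTrivial(OpKind op, const std::vector<float>& lhs,
               const std::vector<float>& rhs) {
  if (lhs.size() != rhs.size()) return false;  // analogue of the no-broadcast check
  const float identity = IdentityFor(op);
  for (float v : rhs)
    if (v != identity) return false;
  return true;
}

int main() {
  const std::vector<float> x = {1.5f, -2.0f, 3.0f};
  const std::vector<float> zeros(3, 0.0f), ones(3, 1.0f);
  assert(IsTrivial(OpKind::kAdd, x, zeros));   // x + 0 -> x
  assert(IsTrivial(OpKind::kDiv, x, ones));    // x / 1 -> x
  assert(!IsTrivial(OpKind::kAdd, x, ones));   // x + 1 is not trivial
  return 0;
}

The shape-equality requirement is why the DontRemoveTrivialAdd tests below keep their adds: when the result shape differs from the LHS shape (broadcasting, or a dynamic result), dropping the op would change the program.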
@@ -525,6 +567,10 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
  results.insert<AddV2OfNegLeft, AddV2OfNegRight>(context);
}

OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
  return TrivialArithmeticOpFolder<AddV2Op>(*this);
}

//===----------------------------------------------------------------------===//
// AllOp
//===----------------------------------------------------------------------===//
@@ -1271,6 +1317,10 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  results.insert<DivWithSqrtDivisor>(context);
}

OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  return TrivialArithmeticOpFolder<DivOp>(*this);
}

//===----------------------------------------------------------------------===//
// DynamicStitchOp
//===----------------------------------------------------------------------===//
@@ -1936,6 +1986,14 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  return success();
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  return TrivialArithmeticOpFolder<MulOp>(*this);
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
@@ -2904,6 +2962,10 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
  results.insert<SubOfNeg>(context);
}

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  return TrivialArithmeticOpFolder<SubOp>(*this);
}

//===----------------------------------------------------------------------===//
// SumOp
//===----------------------------------------------------------------------===//
@@ -251,3 +251,84 @@ func @testTensorListElementShape(%arg0: tensor<!tf.variant<tensor<2x4xf32>>>) ->
  // CHECK-NEXT: return [[cst]] : tensor<2xi32>
  return %0: tensor<2xi32>
}

func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<0.0> : tensor<2x2xf32>
  %0 = "tf.Add"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tf.Add"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: RemoveTrivialAdd
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}

func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<0.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tf.AddV2"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: RemoveTrivialAddV2
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}

func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<0.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tf.Sub"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: RemoveTrivialSub
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}

func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<1.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tf.Mul"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: RemoveTrivialMul
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}

func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<1.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  %1 = "tf.Div"(%0, %cst) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: RemoveTrivialDiv
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}

func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
  %cst = constant dense<0.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
  %1 = "tf.AddV2"(%0, %cst) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %1 : tensor<2x2xf32>

  // CHECK-LABEL: DontRemoveTrivialAdd
  // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: return %[[RESULT]] : tensor<2x2xf32>
}

func @DontRemoveTrivialAdd2(%arg0: tensor<?x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
  %cst = constant dense<0.0> : tensor<2x2xf32>
  %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
  %1 = "tf.AddV2"(%0, %cst) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>

  // CHECK-LABEL: DontRemoveTrivialAdd2
  // CHECK: %[[CONST:.*]] = constant dense<0.000000e+00> : tensor<2x2xf32>
  // CHECK: %[[add:.*]] = "tf.AddV2"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
  // CHECK: %[[RESULT:.*]] = "tf.AddV2"(%[[add]], %[[CONST]]) : (tensor<?x?xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x?xf32>
}