Improve TensorFlow arithmetic op folders

* Fix crash for non fp32 types by creating identity attribute of the appropriate type.
* Handle integer types
* Handle identity value as lhs for symmetric ops
* Utilize isSplat and getSplatValue helpers.

PiperOrigin-RevId: 310162612
Change-Id: I863a5bf5cb64c832dd938e22c7694d34236dcfc3
This commit is contained in:
Smit Hinsu 2020-05-06 09:04:48 -07:00 committed by TensorFlower Gardener
parent c9aeeea2a8
commit 5bb727ee34
2 changed files with 92 additions and 29 deletions

View File

@ -501,40 +501,49 @@ LogicalResult FoldOperandsPermutation(
//===----------------------------------------------------------------------===//
namespace {
// Utility methods that returns Identity value to use for selected ops.
APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
template <typename OP>
OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
DenseFPElementsAttr rhs_value;
auto constant_val = arithmetic_op.y();
if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
return {};
}
template <typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType();
auto lhs_type = arithmetic_op.x().getType();
if (!result_op_type.template isa<ShapedType>() ||
!lhs_type.template isa<ShapedType>() ||
!result_op_type.template cast<ShapedType>().hasStaticShape()) {
return {};
}
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
// We only handle non-broadcastable case.
if (result_op_type != lhs_type) {
return {};
}
auto identity_val = GetIdentity(arithmetic_op);
for (auto it = rhs_value.float_value_begin();
it != rhs_value.float_value_end(); ++it) {
if (*it != identity_val) return {};
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
if (auto ty = element_ty.template dyn_cast<FloatType>()) {
identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
} else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
} else {
return {};
}
return arithmetic_op.x();
if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
if (attr.isSplat() && attr.getSplatValue() == identity_attr)
return arithmetic_op.x();
}
bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
return arithmetic_op.y();
}
return {};
}
} // namespace
@ -570,7 +579,7 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
// Folds tf.AddV2 when either operand is a splat of the additive identity
// (zero) with the same static shape as the result (AddV2 is symmetric).
OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
}
//===----------------------------------------------------------------------===//
@ -1320,7 +1329,7 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
// Folds tf.Div when the RHS is a splat of the multiplicative identity (one)
// with the same static shape as the result. A splat-one LHS does not fold
// because Div is not symmetric.
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<DivOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
@ -2026,7 +2035,7 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
//===----------------------------------------------------------------------===//
// Folds tf.Mul when either operand is a splat of the multiplicative identity
// (one) with the same static shape as the result (Mul is symmetric).
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<MulOp>(*this, operands);
}
//===----------------------------------------------------------------------===//
@ -2998,7 +3007,7 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
}
// Folds tf.Sub when the RHS is a splat of the additive identity (zero) with
// the same static shape as the result. A splat-zero LHS does not fold because
// Sub is not symmetric.
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
  return IdentityArithmeticOpFolder<SubOp>(*this, operands);
}
//===----------------------------------------------------------------------===//

View File

@ -303,6 +303,24 @@ func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
// A splat-zero bf16 RHS of tf.Add folds to the other operand (non-f32 float
// element types are supported by the identity folder).
func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
  %cst = constant dense<0.0> : tensor<2x2xbf16>
  %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
  return %0 : tensor<2x2xbf16>

  // Use the full function name so the CHECK-LABEL match is unique; the bare
  // prefix "RemoveTrivialAdd" matches several functions in this file and
  // FileCheck rejects non-unique CHECK-LABEL matches.
  // CHECK-LABEL: RemoveTrivialAddBf16RHS
  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
}
// A splat-zero bf16 LHS also folds because Add is symmetric.
func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
  %cst = constant dense<0.0> : tensor<2x2xbf16>
  %0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
  return %0 : tensor<2x2xbf16>

  // Use the full function name so the CHECK-LABEL match is unique; the bare
  // prefix "RemoveTrivialAdd" matches several functions in this file and
  // FileCheck rejects non-unique CHECK-LABEL matches.
  // CHECK-LABEL: RemoveTrivialAddBf16LHS
  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
}
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -325,6 +343,15 @@ func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
// A splat-zero RHS of tf.Sub folds away for integer element types as well;
// only the RHS may fold since Sub is not symmetric.
func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
%cst = constant dense<0> : tensor<2x2xi8>
%0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
return %0 : tensor<2x2xi8>
// CHECK-LABEL: RemoveTrivialSubInt8
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
}
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<1.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -347,6 +374,33 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
}
// A splat-one bf16 RHS of tf.Div folds to the dividend.
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
  %cst = constant dense<1.0> : tensor<2x2xbf16>
  %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
  return %0 : tensor<2x2xbf16>

  // Use the full function name so the CHECK-LABEL match is unique; the bare
  // prefix "RemoveTrivialDiv" also matches @RemoveTrivialDiv and FileCheck
  // rejects non-unique CHECK-LABEL matches.
  // CHECK-LABEL: RemoveTrivialDivBf16RHS
  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
}
// A splat-one LHS of tf.Mul folds to the other operand (Mul is symmetric),
// and integer element types are handled.
func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
%cst = constant dense<1> : tensor<2x2xi8>
%0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
return %0 : tensor<2x2xi8>
// CHECK-LABEL: RemoveTrivialMulInt8
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
}
// Negative test: a splat-one LHS of tf.Div must NOT fold, since Div is not
// symmetric — the op is expected to remain in the output.
func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
%cst = constant dense<1.0> : tensor<2x2xbf16>
%0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
return %0 : tensor<2x2xbf16>
// CHECK-LABEL: DivBf16LHS
// CHECK: tf.Div
}
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
%cst = constant dense<0.0> : tensor<2x2xf32>
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>