Improve TensorFlow arithmetic op folders
* Fix crash for non-fp32 types by creating an identity attribute of the appropriate type.
* Handle integer types.
* Handle identity value as lhs for symmetric ops.
* Utilize isSplat and getSplatValue helpers.

PiperOrigin-RevId: 310162612
Change-Id: I863a5bf5cb64c832dd938e22c7694d34236dcfc3
This commit is contained in:
parent
c9aeeea2a8
commit
5bb727ee34
@ -501,40 +501,49 @@ LogicalResult FoldOperandsPermutation(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// Utility methods that return the identity value to use for selected ops.
|
||||
|
||||
APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
|
||||
APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
|
||||
APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
|
||||
APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
|
||||
|
||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||
// known to be the identity (e.g. X+0)
|
||||
template <typename OP>
|
||||
OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
|
||||
DenseFPElementsAttr rhs_value;
|
||||
auto constant_val = arithmetic_op.y();
|
||||
if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
|
||||
return {};
|
||||
}
|
||||
template <typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
|
||||
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto result_op_type = arithmetic_op.getResult().getType();
|
||||
auto lhs_type = arithmetic_op.x().getType();
|
||||
if (!result_op_type.template isa<ShapedType>() ||
|
||||
!lhs_type.template isa<ShapedType>() ||
|
||||
!result_op_type.template cast<ShapedType>().hasStaticShape()) {
|
||||
return {};
|
||||
}
|
||||
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
|
||||
if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
|
||||
|
||||
// We only handle the non-broadcastable case.
|
||||
if (result_op_type != lhs_type) {
|
||||
return {};
|
||||
}
|
||||
auto identity_val = GetIdentity(arithmetic_op);
|
||||
for (auto it = rhs_value.float_value_begin();
|
||||
it != rhs_value.float_value_end(); ++it) {
|
||||
if (*it != identity_val) return {};
|
||||
|
||||
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||
// value zero.
|
||||
int identity =
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
|
||||
|
||||
Type element_ty = lhs_type.getElementType();
|
||||
Attribute identity_attr;
|
||||
if (auto ty = element_ty.template dyn_cast<FloatType>()) {
|
||||
identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
|
||||
} else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
|
||||
identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
|
||||
return arithmetic_op.x();
|
||||
if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||
return arithmetic_op.x();
|
||||
}
|
||||
|
||||
bool is_symmetric =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||
return arithmetic_op.y();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -570,7 +579,7 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
}
|
||||
|
||||
OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
|
||||
return TrivialArithmeticOpFolder<AddV2Op>(*this);
|
||||
return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1320,7 +1329,7 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
}
|
||||
|
||||
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
||||
return TrivialArithmeticOpFolder<DivOp>(*this);
|
||||
return IdentityArithmeticOpFolder<DivOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2026,7 +2035,7 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
||||
return TrivialArithmeticOpFolder<MulOp>(*this);
|
||||
return IdentityArithmeticOpFolder<MulOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -2998,7 +3007,7 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
}
|
||||
|
||||
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||
return TrivialArithmeticOpFolder<SubOp>(*this);
|
||||
return IdentityArithmeticOpFolder<SubOp>(*this, operands);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -303,6 +303,24 @@ func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xbf16>
|
||||
%0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||
return %0 : tensor<2x2xbf16>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialAddBf16RHS
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xbf16>
|
||||
%0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||
return %0 : tensor<2x2xbf16>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialAddBf16LHS
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||
}
|
||||
|
||||
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -325,6 +343,15 @@ func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||
%cst = constant dense<0> : tensor<2x2xi8>
|
||||
%0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
|
||||
return %0 : tensor<2x2xi8>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialSubInt8
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
|
||||
}
|
||||
|
||||
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -347,6 +374,33 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xbf16>
|
||||
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||
return %0 : tensor<2x2xbf16>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialDivBf16RHS
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||
}
|
||||
|
||||
func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||
%cst = constant dense<1> : tensor<2x2xi8>
|
||||
%0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
|
||||
return %0 : tensor<2x2xi8>
|
||||
|
||||
// CHECK-LABEL: RemoveTrivialMulInt8
|
||||
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
|
||||
}
|
||||
|
||||
func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||
%cst = constant dense<1.0> : tensor<2x2xbf16>
|
||||
%0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||
return %0 : tensor<2x2xbf16>
|
||||
|
||||
// CHECK-LABEL: DivBf16LHS
|
||||
// CHECK: tf.Div
|
||||
}
|
||||
|
||||
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||
|
Loading…
Reference in New Issue
Block a user