Improve TensorFlow arithmatic op folders
* Fix crash for non fp32 types by creating identity attribute of the appropriate type. * Handle integer types * Handle identity value as lhs for symmetric ops * Utilize isSplat and getSplatValue helpers. PiperOrigin-RevId: 310162612 Change-Id: I863a5bf5cb64c832dd938e22c7694d34236dcfc3
This commit is contained in:
parent
c9aeeea2a8
commit
5bb727ee34
@ -501,40 +501,49 @@ LogicalResult FoldOperandsPermutation(
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Utility methods that returns Identity value to use for selected ops.
|
|
||||||
|
|
||||||
APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
|
|
||||||
APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
|
|
||||||
APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
|
|
||||||
APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
|
|
||||||
|
|
||||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||||
// known to be Identity (e.g X+0)
|
// known to be Identity (e.g X+0)
|
||||||
template <typename OP>
|
template <typename OpT,
|
||||||
OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
|
typename std::enable_if<llvm::is_one_of<
|
||||||
DenseFPElementsAttr rhs_value;
|
OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
|
||||||
auto constant_val = arithmetic_op.y();
|
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||||
if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
|
ArrayRef<Attribute> operands) {
|
||||||
return {};
|
|
||||||
}
|
|
||||||
auto result_op_type = arithmetic_op.getResult().getType();
|
auto result_op_type = arithmetic_op.getResult().getType();
|
||||||
auto lhs_type = arithmetic_op.x().getType();
|
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
|
||||||
if (!result_op_type.template isa<ShapedType>() ||
|
if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
|
||||||
!lhs_type.template isa<ShapedType>() ||
|
|
||||||
!result_op_type.template cast<ShapedType>().hasStaticShape()) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
// We only handle non-broadcastable case.
|
// We only handle non-broadcastable case.
|
||||||
if (result_op_type != lhs_type) {
|
if (result_op_type != lhs_type) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
auto identity_val = GetIdentity(arithmetic_op);
|
|
||||||
for (auto it = rhs_value.float_value_begin();
|
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||||
it != rhs_value.float_value_end(); ++it) {
|
// value zero.
|
||||||
if (*it != identity_val) return {};
|
int identity =
|
||||||
|
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
|
||||||
|
|
||||||
|
Type element_ty = lhs_type.getElementType();
|
||||||
|
Attribute identity_attr;
|
||||||
|
if (auto ty = element_ty.template dyn_cast<FloatType>()) {
|
||||||
|
identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
|
||||||
|
} else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
|
||||||
|
identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
|
||||||
|
} else {
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
return arithmetic_op.x();
|
if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||||
|
if (attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||||
|
return arithmetic_op.x();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_symmetric =
|
||||||
|
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||||
|
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||||
|
if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||||
|
return arithmetic_op.y();
|
||||||
|
}
|
||||||
|
return {};
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -570,7 +579,7 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
|
||||||
return TrivialArithmeticOpFolder<AddV2Op>(*this);
|
return IdentityArithmeticOpFolder<AddV2Op>(*this, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -1320,7 +1329,7 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return TrivialArithmeticOpFolder<DivOp>(*this);
|
return IdentityArithmeticOpFolder<DivOp>(*this, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -2026,7 +2035,7 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return TrivialArithmeticOpFolder<MulOp>(*this);
|
return IdentityArithmeticOpFolder<MulOp>(*this, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -2998,7 +3007,7 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return TrivialArithmeticOpFolder<SubOp>(*this);
|
return IdentityArithmeticOpFolder<SubOp>(*this, operands);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -303,6 +303,24 @@ func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
|||||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||||
|
%cst = constant dense<0.0> : tensor<2x2xbf16>
|
||||||
|
%0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||||
|
return %0 : tensor<2x2xbf16>
|
||||||
|
|
||||||
|
// CHECK-LABEL: RemoveTrivialAdd
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||||
|
%cst = constant dense<0.0> : tensor<2x2xbf16>
|
||||||
|
%0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||||
|
return %0 : tensor<2x2xbf16>
|
||||||
|
|
||||||
|
// CHECK-LABEL: RemoveTrivialAdd
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
@ -325,6 +343,15 @@ func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
|||||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||||
|
%cst = constant dense<0> : tensor<2x2xi8>
|
||||||
|
%0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
|
||||||
|
return %0 : tensor<2x2xi8>
|
||||||
|
|
||||||
|
// CHECK-LABEL: RemoveTrivialSubInt8
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
|
||||||
|
}
|
||||||
|
|
||||||
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
%cst = constant dense<1.0> : tensor<2x2xf32>
|
%cst = constant dense<1.0> : tensor<2x2xf32>
|
||||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
@ -347,6 +374,33 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor
|
|||||||
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
// CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||||
|
%cst = constant dense<1.0> : tensor<2x2xbf16>
|
||||||
|
%0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||||
|
return %0 : tensor<2x2xbf16>
|
||||||
|
|
||||||
|
// CHECK-LABEL: RemoveTrivialDiv
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
|
||||||
|
%cst = constant dense<1> : tensor<2x2xi8>
|
||||||
|
%0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
|
||||||
|
return %0 : tensor<2x2xi8>
|
||||||
|
|
||||||
|
// CHECK-LABEL: RemoveTrivialMulInt8
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<2x2xi8>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
|
||||||
|
%cst = constant dense<1.0> : tensor<2x2xbf16>
|
||||||
|
%0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
|
||||||
|
return %0 : tensor<2x2xbf16>
|
||||||
|
|
||||||
|
// CHECK-LABEL: DivBf16LHS
|
||||||
|
// CHECK: tf.Div
|
||||||
|
}
|
||||||
|
|
||||||
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
|
||||||
%cst = constant dense<0.0> : tensor<2x2xf32>
|
%cst = constant dense<0.0> : tensor<2x2xf32>
|
||||||
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
|
%0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
|
||||||
|
Loading…
Reference in New Issue
Block a user