diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 1efb194272b..00b52a9f9aa 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -501,40 +501,49 @@ LogicalResult FoldOperandsPermutation(
 //===----------------------------------------------------------------------===//
 
 namespace {
-// Utility methods that returns Identity value to use for selected ops.
-
-APFloat GetIdentity(AddV2Op op) { return APFloat(0.0f); }
-APFloat GetIdentity(SubOp op) { return APFloat(0.0f); }
-APFloat GetIdentity(MulOp op) { return APFloat(1.0f); }
-APFloat GetIdentity(DivOp op) { return APFloat(1.0f); }
-
 // Folder that returns LHS of an Arithmetic Op if the RHS is a constant
 // known to be Identity (e.g X+0)
-template <typename OP>
-OpFoldResult TrivialArithmeticOpFolder(OP arithmetic_op) {
-  DenseFPElementsAttr rhs_value;
-  auto constant_val = arithmetic_op.y();
-  if (!matchPattern(constant_val, m_Constant(&rhs_value))) {
-    return {};
-  }
-
+template <
+    typename OpT,
+    typename std::enable_if<llvm::is_one_of<
+        OpT, AddV2Op, SubOp, MulOp, DivOp>::value>::type * = nullptr>
+OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
+                                        ArrayRef<Attribute> operands) {
   auto result_op_type = arithmetic_op.getResult().getType();
-  auto lhs_type = arithmetic_op.x().getType();
-  if (!result_op_type.template isa<ShapedType>() ||
-      !lhs_type.template isa<ShapedType>() ||
-      !result_op_type.template cast<ShapedType>().hasStaticShape()) {
-    return {};
-  }
+  auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
+  if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
+
   // We only handle non-broadcastable case.
   if (result_op_type != lhs_type) {
     return {};
   }
-  auto identity_val = GetIdentity(arithmetic_op);
-  for (auto it = rhs_value.float_value_begin();
-       it != rhs_value.float_value_end(); ++it) {
-    if (*it != identity_val) return {};
+
+  // Mul and Div ops have identity value one while AddV2 and SubOp have
+  // identity value zero.
+  int identity =
+      (std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value);
+
+  Type element_ty = lhs_type.getElementType();
+  Attribute identity_attr;
+  if (auto ty = element_ty.template dyn_cast<FloatType>()) {
+    identity_attr = FloatAttr::get(ty, static_cast<double>(identity));
+  } else if (auto ty = element_ty.template dyn_cast<IntegerType>()) {
+    identity_attr = IntegerAttr::get(ty, static_cast<int64_t>(identity));
+  } else {
+    return {};
   }
-  return arithmetic_op.x();
+  if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
+    if (attr.isSplat() && attr.getSplatValue() == identity_attr)
+      return arithmetic_op.x();
+  }
+
+  bool is_symmetric =
+      (std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
+  if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
+    if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
+      return arithmetic_op.y();
+  }
+  return {};
 }
 
 }  // namespace
@@ -570,7 +579,7 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
 }
 
 OpFoldResult AddV2Op::fold(ArrayRef<Attribute> operands) {
-  return TrivialArithmeticOpFolder(*this);
+  return IdentityArithmeticOpFolder(*this, operands);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1320,7 +1329,7 @@ void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 }
 
 OpFoldResult DivOp::fold(ArrayRef<Attribute> operands) {
-  return TrivialArithmeticOpFolder(*this);
+  return IdentityArithmeticOpFolder(*this, operands);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2026,7 +2035,7 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
-  return TrivialArithmeticOpFolder(*this);
+  return IdentityArithmeticOpFolder(*this, operands);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2998,7 +3007,7 @@ void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 }
 
 OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
-  return TrivialArithmeticOpFolder(*this);
+  return IdentityArithmeticOpFolder(*this, operands);
 }
 
 //===----------------------------------------------------------------------===//
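The heart of the tf_ops.cc change is that both folding decisions are compile-time properties of the op type: which splat constant counts as the identity (one for Mul/Div, zero for AddV2/Sub), and whether the constant may sit on either side of the op. A minimal standalone sketch of that trait logic, using stand-in struct tags rather than the generated TensorFlow op classes (tag and helper names here are illustrative assumptions, not part of the patch):

  #include <type_traits>

  // Stand-in tags, for illustration only; the patch uses the generated
  // TensorFlow MLIR op classes of the same names.
  struct AddV2Op {};
  struct SubOp {};
  struct MulOp {};
  struct DivOp {};

  // Mirrors `int identity = ...` above: Mul and Div fold against a splat
  // one, AddV2 and Sub against a splat zero.
  template <typename OpT>
  constexpr int Identity() {
    return std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value;
  }

  // Mirrors `bool is_symmetric = ...`: x + 0 and x * 1 fold with the
  // constant on either side, but x - 0 and x / 1 fold only with the
  // constant on the right, since 0 - x and 1 / x are not x in general.
  template <typename OpT>
  constexpr bool IsSymmetric() {
    return std::is_same<OpT, AddV2Op>::value ||
           std::is_same<OpT, MulOp>::value;
  }

  static_assert(Identity<MulOp>() == 1 && Identity<SubOp>() == 0, "");
  static_assert(IsSymmetric<AddV2Op>() && !IsSymmetric<DivOp>(), "");

The symmetric case is also what the rewrite gains over the old TrivialArithmeticOpFolder, which only ever inspected y(): with the new operands[0] branch, 0 + X and 1 * X now fold as well, not just X + 0 and X * 1.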
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
index ed420799d75..1b581369de0 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir
@@ -303,6 +303,24 @@ func @RemoveTrivialAdd(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tenso
   // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
 }
 
+func @RemoveTrivialAddBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
+  %cst = constant dense<0.0> : tensor<2x2xbf16>
+  %0 = "tf.Add"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
+  return %0 : tensor<2x2xbf16>
+
+  // CHECK-LABEL: RemoveTrivialAddBf16RHS
+  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
+}
+
+func @RemoveTrivialAddBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
+  %cst = constant dense<0.0> : tensor<2x2xbf16>
+  %0 = "tf.Add"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
+  return %0 : tensor<2x2xbf16>
+
+  // CHECK-LABEL: RemoveTrivialAddBf16LHS
+  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
+}
+
 func @RemoveTrivialAddV2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
   %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -325,6 +343,15 @@ func @RemoveTrivialSub(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tenso
   // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
 }
 
+func @RemoveTrivialSubInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
+  %cst = constant dense<0> : tensor<2x2xi8>
+  %0 = "tf.Sub"(%arg0, %cst) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
+  return %0 : tensor<2x2xi8>
+
+  // CHECK-LABEL: RemoveTrivialSubInt8
+  // CHECK-NEXT: return %arg0 : tensor<2x2xi8>
+}
+
 func @RemoveTrivialMul(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<1.0> : tensor<2x2xf32>
   %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@@ -347,6 +374,33 @@ func @RemoveTrivialDiv(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tenso
   // CHECK-NEXT: return %[[RESULT]] : tensor<2x2xf32>
 }
 
+func @RemoveTrivialDivBf16RHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
+  %cst = constant dense<1.0> : tensor<2x2xbf16>
+  %0 = "tf.Div"(%arg0, %cst) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
+  return %0 : tensor<2x2xbf16>
+
+  // CHECK-LABEL: RemoveTrivialDivBf16RHS
+  // CHECK-NEXT: return %arg0 : tensor<2x2xbf16>
+}
+
+func @RemoveTrivialMulInt8(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> {
+  %cst = constant dense<1> : tensor<2x2xi8>
+  %0 = "tf.Mul"(%cst, %arg0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8>
+  return %0 : tensor<2x2xi8>
+
+  // CHECK-LABEL: RemoveTrivialMulInt8
+  // CHECK-NEXT: return %arg0 : tensor<2x2xi8>
+}
+
+func @DivBf16LHS(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> {
+  %cst = constant dense<1.0> : tensor<2x2xbf16>
+  %0 = "tf.Div"(%cst, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16>
+  return %0 : tensor<2x2xbf16>
+
+  // CHECK-LABEL: DivBf16LHS
+  // CHECK: tf.Div
+}
+
 func @DontRemoveTrivialAdd(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2xf32> {
   %cst = constant dense<0.0> : tensor<2x2xf32>
   %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32>
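Among the added tests, @DivBf16LHS is the deliberate negative case: a splat-one LHS of tf.Div must survive folding, because division has only a right identity; this is exactly what the is_symmetric guard in IdentityArithmeticOpFolder enforces. A plain C++ sanity check of the arithmetic fact the test pins down (illustrative only, not part of the patch):

  #include <cassert>

  int main() {
    float x = 4.0f;
    assert(x / 1.0f == x);  // Right identity: tf.Div(x, 1) may fold to x.
    assert(1.0f / x != x);  // No left identity: tf.Div(1, x) must be kept.
    return 0;
  }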