diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index e017db0afc6..1dde5739bb1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -2068,6 +2068,62 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
+  let summary = [{
+Compute the cumulative product of the tensor `x` along `axis`.
+  }];
+
+  let description = [{
+By default, this op performs an inclusive cumprod, which means that the first
+element of the input is identical to the first element of the output:
+
+```python
+tf.cumprod([a, b, c]) # => [a, a * b, a * b * c]
+```
+
+By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
+performed instead:
+
+```python
+tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b]
+```
+
+By setting the `reverse` kwarg to `True`, the cumprod is performed in the
+opposite direction:
+
+```python
+tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c]
+```
+
+This is more efficient than using separate `tf.reverse` ops.
+
+The `reverse` and `exclusive` kwargs can also be combined:
+
+```python
+tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1]
+```
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
+    TF_I32OrI64Tensor:$axis,
+
+    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
+    DefaultValuedAttr<BoolAttr, "false">:$reverse
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
   let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
 
@@ -2116,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0]
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 3e9ed6f2941..cf9ae2c2174 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -774,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// CumsumOp and CumprodOp
+//===----------------------------------------------------------------------===//
+
+template <typename OpT, typename std::enable_if<llvm::is_one_of<
+                            OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
+  if (!IsOfRankOrUnranked(op.axis(), 0))
+    return op.emitOpError("requires scalar axis operand");
+
+  DenseIntElementsAttr axis_attr;
+  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
+    auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
+    if (input_ty) {
+      int64_t rank = input_ty.getRank();
+      assert(axis_attr.getNumElements() == 1 &&
+             "scalar attribute should have exactly one element");
+      int64_t axis = (*axis_attr.begin()).getSExtValue();
+      if (axis < -rank || axis >= rank) {
+        return op.emitError()
+               << "axis operand should be within range [" << -rank << ", "
+               << rank << "); actual value: " << axis;
+      }
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 365007f75e4..91d45395a46 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -3373,3 +3373,28 @@ func @testCaseRegionMismatchedResultTypes(%arg0: tensor<i32>, %arg1: tensor<f32>
   }) {is_stateless = false} : (tensor<i32>) -> tensor<i1>
   return
 }
+
+// -----
+
+// Test valid tf.Cumsum
+func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
+  %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{requires scalar axis operand}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %axis = constant dense<-3> : tensor<i32>
+  // expected-error @+1 {{axis operand should be within range [-2, 2)}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 9b32fb97260..dd9ec6e3d8b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -4668,6 +4668,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// Cumprod op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @cumprod
+func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
+  // CHECK: mhlo.mul
+  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Qr op legalization
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 878feb85f75..7601a54088e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -5092,17 +5092,19 @@ class ConvertXlaDynamicUpdateSliceOp
   }
 };
 
-/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
-/// appropriate window dimensions, with 'add' as the reduction function. The
-/// input tensor needs to have a static shape, and 'axis' must be const. The
-/// TableGen pattern is not used for this rewrite because it involves regions.
-class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
-  using OpRewritePattern<TF::CumsumOp>::OpRewritePattern;
+// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
+// setting appropriate window dimensions, with the given aggregation op as the
+// reduction function. The input tensor needs to have a static shape, and 'axis'
+// must be const. The TableGen pattern is not used for this rewrite because it
+// involves regions.
+template <typename OpT, typename AggregationOp>
+class ConvertCumOp : public OpRewritePattern<OpT> {
+  using OpRewritePattern<OpT>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::CumsumOp op,
+  LogicalResult matchAndRewrite(OpT op,
                                 PatternRewriter &rewriter) const override {
     auto input = op.x();
-    auto input_type = input.getType().dyn_cast<ShapedType>();
+    auto input_type = input.getType().template dyn_cast<ShapedType>();
     if (!input_type || !input_type.hasStaticShape()) {
       return failure();
     }
@@ -5135,6 +5137,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
     // Convert if we need to enlarge the element type's bitwidth to avoid
     // precision loss.
     Type input_element_type = input_type.getElementType();
+
+    // TODO(hinsu): Handle complex element types.
+    if (!input_element_type.isIntOrFloat()) return failure();
+
     Type sum_element_type = GetSumAccumulationType(input_element_type);
     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
 
@@ -5148,8 +5154,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
         paddings);
 
-    Value init =
-        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
+    int64_t init_value = (std::is_same<AggregationOp, AddOp>::value) ? 0 : 1;
+    Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
+                                      &rewriter);
 
     auto reduce = rewriter.create<ReduceWindowOp>(
         op.getLoc(), input_type, input, init,
@@ -5157,7 +5164,7 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
         /*base_dilations=*/DenseIntElementsAttr(),
         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
-    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
+    BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
     Value result = reduce.getResult();
 
     if (op.exclusive()) {
@@ -5193,6 +5200,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
   }
 };
 
+using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
+using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
+
 // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
 // dialect lowerings. This involves extracting the shape type, extracting and
 // converting each dimension to a known integer type, and repacking into a final
@@ -5857,7 +5867,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
       ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
       ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
       ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
-      ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
+      ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
       ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
       ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 7f099540f39..10b7d88e0d4 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1165,6 +1165,7 @@ tf_xla_py_test(
     name = "scan_ops_test",
     size = "medium",
     srcs = ["scan_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 7c36f8b13ca..440b7672d98 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase):
       for axis in range(-6, 6, 3):
         self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():
@@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
       for axis in range(-6, 6, 3):
         self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
  def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():
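Note on the lowering above: `ConvertCumOp` implements both scans with a single `mhlo.reduce_window`. The scanned axis is padded on the low side with `axis_size - 1` copies of the reduction's identity element (which is why `init_value` is 0 for `AddOp` and 1 for `MulOp`), and a window of length `axis_size` is reduced with stride 1 at every position; that window/padding setup lives in the unchanged part of the pattern and is not shown in the hunks above. The NumPy sketch below is illustrative only and is not part of the patch: `cum_scan` is a made-up helper that models the semantics for a 1-D input, and the `exclusive`/`reverse` flags are modeled by shifting and flipping the data, which is semantically equivalent to (but not literally the same ops as) the pad/slice and reverse handling in the pattern.

```python
import numpy as np


# Illustrative reference model only; not part of the patch.
def cum_scan(x, op=np.multiply, identity=1.0, exclusive=False, reverse=False):
  """Models the Cumsum/Cumprod-to-reduce_window lowering for a 1-D input.

  Cumsum corresponds to op=np.add, identity=0.0; Cumprod to op=np.multiply,
  identity=1.0. The scanned axis is padded on the low side with n - 1 identity
  elements and a length-n window is reduced with stride 1, mirroring the
  generated mhlo.reduce_window.
  """
  x = np.asarray(x, dtype=np.float64)
  n = x.shape[0]
  if reverse:
    x = x[::-1]
  if exclusive:
    # Exclusive scan: shift right by one element, dropping the last one.
    x = np.concatenate(([identity], x[:-1]))
  padded = np.concatenate((np.full(n - 1, identity), x))
  out = np.array([op.reduce(padded[i:i + n]) for i in range(n)])
  return out[::-1] if reverse else out


# Matches the documented tf.cumprod semantics:
# cum_scan([2., 3., 4.], exclusive=True, reverse=True) -> [12., 4., 1.]
print(cum_scan([2., 3., 4.], np.multiply, 1.0, exclusive=True, reverse=True))
# And an inclusive cumsum: cum_scan([1., 2., 3.], np.add, 0.0) -> [1., 3., 6.]
print(cum_scan([1., 2., 3.], np.add, 0.0))
```

Because the padding elements participate in every window, they must be the identity of the aggregation op; that is exactly why the patch replaces the hard-coded init of 0 with `init_value`, chosen per `AggregationOp`.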