From a3d2b5a910506a2a4677f71faae780c986dbe1b6 Mon Sep 17 00:00:00 2001
From: Smit Hinsu
Date: Fri, 14 Aug 2020 22:54:05 -0700
Subject: [PATCH] Legalize TensorFlow Cumprod op to HLO

Also, add verifier for Cumsum op to reject illegal axis value.

GetScalarConstOfType doesn't support non-int and non-float element types, so
reject ops with other element types.

PiperOrigin-RevId: 326787027
Change-Id: I54afe4e494d711fa873b6329391603fbd8958c88
---
 .../mlir/tensorflow/ir/tf_generated_ops.td    | 60 +++++++++++++++++++
 .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 29 +++++++++
 .../mlir/tensorflow/tests/tf-ops.mlir         | 25 ++++++++
 .../compiler/mlir/xla/tests/legalize-tf.mlir  | 14 +++++
 .../mlir/xla/transforms/legalize_tf.cc        | 34 +++++++----
 tensorflow/compiler/tests/BUILD               |  1 +
 tensorflow/compiler/tests/scan_ops_test.py    |  3 +
 7 files changed, 154 insertions(+), 12 deletions(-)
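For context, the user-visible effect of this change: `tf.math.cumprod` (including its `exclusive` and `reverse` variants) can now be compiled through the XLA MLIR bridge instead of failing legalization. A hedged usage sketch, assuming a TF build with XLA support (`jit_compile` is the TF 2.5+ spelling; earlier releases call it `experimental_compile`):

```python
import tensorflow as tf

@tf.function(jit_compile=True)  # force XLA compilation of the scan
def cumprods(x):
  return (tf.math.cumprod(x),                  # [2., 6., 24.]
          tf.math.cumprod(x, exclusive=True),  # [1., 2., 6.]
          tf.math.cumprod(x, reverse=True))    # [24., 12., 4.]

print(cumprods(tf.constant([2., 3., 4.])))
```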
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index e017db0afc6..1dde5739bb1 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -2068,6 +2068,62 @@ and `B, D, F, H` as group 1. Thus we get the outputs:
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
 }
 
+def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
+  let summary = [{
+Compute the cumulative product of the tensor `x` along `axis`.
+  }];
+
+  let description = [{
+By default, this op performs an inclusive cumprod, which means that the first
+element of the input is identical to the first element of the output:
+
+```python
+tf.cumprod([a, b, c])  # => [a, a * b, a * b * c]
+```
+
+By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
+performed instead:
+
+```python
+tf.cumprod([a, b, c], exclusive=True)  # => [1, a, a * b]
+```
+
+By setting the `reverse` kwarg to `True`, the cumprod is performed in the
+opposite direction:
+
+```python
+tf.cumprod([a, b, c], reverse=True)  # => [a * b * c, b * c, c]
+```
+
+This is more efficient than using separate `tf.reverse` ops.
+
+The `reverse` and `exclusive` kwargs can also be combined:
+
+```python
+tf.cumprod([a, b, c], exclusive=True, reverse=True)  # => [b * c, c, 1]
+```
+  }];
+
+  let arguments = (ins
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
+    TF_I32OrI64Tensor:$axis,
+
+    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
+    DefaultValuedAttr<BoolAttr, "false">:$reverse
+  );
+
+  let results = (outs
+    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out
+  );
+
+  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
+}
+
 def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> {
   let summary = "Compute the cumulative sum of the tensor `x` along `axis`.";
 
@@ -2116,6 +2172,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True)  # => [b + c, c, 0]
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let verifier = [{
+    return Verify(*this);
+  }];
 }
 
 def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> {
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
index 3e9ed6f2941..cf9ae2c2174 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc
@@ -774,6 +774,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              context);
 }
 
+//===----------------------------------------------------------------------===//
+// CumsumOp and CumprodOp
+//===----------------------------------------------------------------------===//
+
+template <typename OpT, typename std::enable_if<llvm::is_one_of<
+                            OpT, CumsumOp, CumprodOp>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
+  if (!IsOfRankOrUnranked(op.axis(), 0))
+    return op.emitOpError("requires scalar axis operand");
+
+  DenseIntElementsAttr axis_attr;
+  if (matchPattern(op.axis(), m_Constant(&axis_attr))) {
+    auto input_ty = op.x().getType().template dyn_cast<RankedTensorType>();
+    if (input_ty) {
+      int64_t rank = input_ty.getRank();
+      assert(axis_attr.getNumElements() == 1 &&
+             "scalar attribute should have exactly one element");
+      int64_t axis = (*axis_attr.begin()).getSExtValue();
+      if (axis < -rank || axis >= rank) {
+        return op.emitError()
+               << "axis operand should be within range [" << -rank << ", "
+               << rank << "); actual value: " << axis;
+      }
+    }
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConcatOffsetOp
 //===----------------------------------------------------------------------===//
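The verifier above encodes the usual scan-axis convention: a scalar `axis` is valid only if it lies in `[-rank, rank)`, with negative values counting back from the last dimension. A minimal Python sketch of the same rule (`validate_scan_axis` is a hypothetical helper for illustration, not part of this change):

```python
def validate_scan_axis(shape, axis):
  """Mirrors the MLIR verifier: a scalar axis must lie in [-rank, rank)."""
  rank = len(shape)
  if not -rank <= axis < rank:
    raise ValueError("axis operand should be within range [%d, %d); "
                     "actual value: %d" % (-rank, rank, axis))
  return axis + rank if axis < 0 else axis  # normalize to [0, rank)

validate_scan_axis((8, 16), 1)   # ok: axis 1
validate_scan_axis((8, 16), -2)  # ok: normalizes to axis 0
# validate_scan_axis((8, 16), -3) raises, matching the tf-ops.mlir test below.
```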
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
index 365007f75e4..91d45395a46 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir
@@ -3373,3 +3373,28 @@ func @testCaseRegionMismatchedResultTypes(%arg0: tensor, %arg1: tensor
   }) {is_stateless = false} : (tensor) -> tensor
   return
 }
+
+// -----
+
+// Test valid tf.Cumsum
+func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor<i32>) -> tensor<8x16xf32> {
+  %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> {
+  // expected-error @+1 {{requires scalar axis operand}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
+
+// -----
+
+func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> {
+  %axis = constant dense<-3> : tensor<i32>
+  // expected-error @+1 {{axis operand should be within range [-2, 2)}}
+  %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<i32>) -> tensor<8x16xf32>
+  return %0 : tensor<8x16xf32>
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 9b32fb97260..dd9ec6e3d8b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -4668,6 +4668,20 @@ func @cumsum_dynamic(%arg0: tensor<?xf32>, %arg1: tensor<i32>) -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// Cumprod op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @cumprod
+func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
+  // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( {
+  // CHECK: mhlo.mul
+  %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor<i32>) -> tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // Qr op legalization
 //===----------------------------------------------------------------------===//
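The CHECK lines above capture the shape of the lowering: a cumulative scan becomes a `reduce_window` whose window spans the whole scan axis, with the low side padded by identity elements (1 for product, 0 for sum). A NumPy sketch of the same idea on a 1-D input (illustrative only; `cum_via_reduce_window` is a made-up name, not compiler code):

```python
import numpy as np

def cum_via_reduce_window(x, op, identity):
  """Inclusive scan as a reduce-window: window size n, stride 1,
  and n - 1 identity elements of low padding."""
  n = len(x)
  padded = np.concatenate([np.full(n - 1, identity, dtype=x.dtype), x])
  # Output element i reduces the length-n window ending at input position i.
  return np.array([op.reduce(padded[i:i + n]) for i in range(n)])

x = np.array([2., 3., 4.])
print(cum_via_reduce_window(x, np.add, 0.))       # [ 2.  5.  9.], np.cumsum(x)
print(cum_via_reduce_window(x, np.multiply, 1.))  # [ 2.  6. 24.], np.cumprod(x)
```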
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 878feb85f75..7601a54088e 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -5092,17 +5092,19 @@ class ConvertXlaDynamicUpdateSliceOp
   }
 };
 
-/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting
-/// appropriate window dimensions, with 'add' as the reduction function. The
-/// input tensor needs to have a static shape, and 'axis' must be const. The
-/// TableGen pattern is not used for this rewrite because it involves regions.
-class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
-  using OpRewritePattern::OpRewritePattern;
+// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by
+// setting appropriate window dimensions, with the given aggregation op as the
+// reduction function. The input tensor needs to have a static shape, and 'axis'
+// must be const. The TableGen pattern is not used for this rewrite because it
+// involves regions.
+template <typename OpT, typename AggregationOp>
+class ConvertCumOp : public OpRewritePattern<OpT> {
+  using OpRewritePattern<OpT>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TF::CumsumOp op,
+  LogicalResult matchAndRewrite(OpT op,
                                 PatternRewriter &rewriter) const override {
     auto input = op.x();
-    auto input_type = input.getType().dyn_cast<RankedTensorType>();
+    auto input_type = input.getType().template dyn_cast<RankedTensorType>();
     if (!input_type || !input_type.hasStaticShape()) {
       return failure();
     }
@@ -5135,6 +5137,10 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
     // Convert if we need to enlarge the element type's bitwidth to avoid
     // precision loss.
     Type input_element_type = input_type.getElementType();
+
+    // TODO(hinsu): Handle complex element types.
+    if (!input_element_type.isIntOrFloat()) return failure();
+
     Type sum_element_type = GetSumAccumulationType(input_element_type);
     input = rewriter.create<ConvertOp>(op.getLoc(), input, sum_element_type);
 
@@ -5148,8 +5154,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)),
         paddings);
 
-    Value init =
-        GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter);
+    int64_t init_value = (std::is_same<OpT, TF::CumsumOp>::value) ? 0 : 1;
+    Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value,
+                                      &rewriter);
 
     auto reduce = rewriter.create<ReduceWindowOp>(
         op.getLoc(), input_type, input, init,
@@ -5157,8 +5164,8 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
         GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)),
         /*base_dilations=*/DenseIntElementsAttr(),
         /*window_dilations=*/DenseIntElementsAttr(), paddings_attr);
-    BuildReduceBody<AddOp>(sum_element_type, &reduce.body(), &rewriter);
+    BuildReduceBody<AggregationOp>(sum_element_type, &reduce.body(), &rewriter);
 
     Value result = reduce.getResult();
 
     if (op.exclusive()) {
@@ -5193,6 +5200,9 @@ class ConvertCumsumOp : public OpRewritePattern<TF::CumsumOp> {
   }
 };
 
+using ConvertCumsumOp = ConvertCumOp<TF::CumsumOp, AddOp>;
+using ConvertCumprodOp = ConvertCumOp<TF::CumprodOp, MulOp>;
+
 // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard
 // dialect lowerings. This involves extracting the shape type, extracting and
 // converting each dimension to a known integer type, and repacking into a final
@@ -5857,7 +5867,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context,
       ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp,
       ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp,
       ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp,
-      ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
+      ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp,
      ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op,
       ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op,
       ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp,
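Note how little separates the two instantiations: the reduction body (`AddOp` vs. `MulOp`) and the init value, which must be the identity of the aggregation (0 for sum, 1 for product) so the padded window positions leave the scan unchanged. For the `exclusive` and `reverse` attributes, whose handling is mostly elided from the hunks above, the intended semantics can be stated as a short NumPy reference (in the same spirit as the reference that `scan_ops_test.py` below compares against; `reference_cumprod` is a hypothetical name):

```python
import numpy as np

def reference_cumprod(x, axis=0, exclusive=False, reverse=False):
  """Reference semantics for tf.Cumprod with the exclusive/reverse attrs."""
  if reverse:
    x = np.flip(x, axis)
  out = np.cumprod(x, axis)
  if exclusive:
    # Shift by one along `axis` and seed with the multiplicative identity.
    out = np.roll(out, 1, axis)
    index = [slice(None)] * out.ndim
    index[axis] = 0
    out[tuple(index)] = 1
  if reverse:
    out = np.flip(out, axis)
  return out

print(reference_cumprod(np.array([2., 3., 4.]), exclusive=True, reverse=True))
# => [12.  4.  1.], i.e. [b * c, c, 1] from the op description above.
```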
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 7f099540f39..10b7d88e0d4 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1165,6 +1165,7 @@ tf_xla_py_test(
     name = "scan_ops_test",
     size = "medium",
     srcs = ["scan_ops_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 7c36f8b13ca..440b7672d98 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
@@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase):
       for axis in range(-6, 6, 3):
         self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():
@@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase):
       for axis in range(-6, 6, 3):
         self._compareAll(x, axis)
 
+  @test_util.disable_mlir_bridge("Error handling")
   def testInvalidAxis(self):
     x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
     with self.session(), self.test_scope():