diff --git a/RELEASE.md b/RELEASE.md index a05ad11779a..921106c46e1 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -297,6 +297,7 @@ `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`. * `DynamicBuffer::AddJoinedString()` will now add a separator if the first string to be joined is empty. + * Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion. * * `tf.random`: diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index f7ee323957d..21cbf518967 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [ let hasFolder = 1; } +def TFL_CumsumOp: TFL_Op<"cumsum", [ + NoSideEffect, + PredOpTrait<"input and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + NoQuantizableResult, + TFL_OperandHasRank<1, 0>]> { + let summary = "Cumsum operator"; + + let description = [{ + Compute the cumulative sum of the tensor x along axis. + }]; + + let arguments = ( + ins TFL_TensorOf<[F32, I32, I64]>:$input, + TFL_I32Tensor:$axis, + DefaultValuedAttr:$exclusive, + DefaultValuedAttr:$reverse + ); + + let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output); + + let hasOptions = 1; +} + def TFL_DepthwiseConv2DOp : TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> { let arguments = ( diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 4de278ee324..07d3754a00e 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -1594,3 +1594,17 @@ func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf // CHECK: "tfl.transpose" } +func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> { + %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK-LABEL: cumsum + // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> +} + +func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<3x3xf32> { + %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> + // CHECK-LABEL: cumsum_invalid + // CHECK-NOT: "tfl.cumsum" +} + diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 322da815a47..0260ed216b5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -444,3 +444,7 @@ def LegalizeMatrixSetDiag : Pat< def LegalizeScatterNd : Pat< (TF_ScatterNdOp I32Tensor:$indices, $updates, $shape), (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>; + +def LegalizeCumsum : Pat< + (TF_CumsumOp $input, $axis, $exclusive, $reverse), + (TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;