[tf.lite] Add cumsum mlir conversion & e2e tests.

PiperOrigin-RevId: 337985673
Change-Id: Ib5f6267a8d644a6caffa312924b687fd9863307f
This commit is contained in:
Renjie Liu 2020-10-19 20:11:18 -07:00 committed by TensorFlower Gardener
parent 1e2389b4ba
commit ec494399c9
4 changed files with 43 additions and 0 deletions

View File

@ -297,6 +297,7 @@
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
* <ADD RELEASE NOTES HERE>
* `tf.random`:

View File

@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [
let hasFolder = 1;
}
def TFL_CumsumOp: TFL_Op<"cumsum", [
NoSideEffect,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoQuantizableResult,
TFL_OperandHasRank<1, 0>]> {
let summary = "Cumsum operator";
let description = [{
Compute the cumulative sum of the tensor x along axis.
}];
let arguments = (
ins TFL_TensorOf<[F32, I32, I64]>:$input,
TFL_I32Tensor:$axis,
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
DefaultValuedAttr<BoolAttr, "false">:$reverse
);
let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output);
let hasOptions = 1;
}
def TFL_DepthwiseConv2DOp :
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
let arguments = (

View File

@ -1594,3 +1594,17 @@ func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf
// CHECK: "tfl.transpose"
}
func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
}
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
// CHECK-LABEL: cumsum_invalid
// CHECK-NOT: "tfl.cumsum"
}

View File

@ -444,3 +444,7 @@ def LegalizeMatrixSetDiag : Pat<
def LegalizeScatterNd : Pat<
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
def LegalizeCumsum : Pat<
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;