[tf.lite] Add cumsum mlir conversion & e2e tests.
PiperOrigin-RevId: 337985673 Change-Id: Ib5f6267a8d644a6caffa312924b687fd9863307f
This commit is contained in:
parent
1e2389b4ba
commit
ec494399c9
@ -297,6 +297,7 @@
|
||||
`TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
|
||||
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
|
||||
string to be joined is empty.
|
||||
* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
|
||||
* <ADD RELEASE NOTES HERE>
|
||||
|
||||
* `tf.random`:
|
||||
|
@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def TFL_CumsumOp: TFL_Op<"cumsum", [
|
||||
NoSideEffect,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 0>>,
|
||||
NoQuantizableResult,
|
||||
TFL_OperandHasRank<1, 0>]> {
|
||||
let summary = "Cumsum operator";
|
||||
|
||||
let description = [{
|
||||
Compute the cumulative sum of the tensor x along axis.
|
||||
}];
|
||||
|
||||
let arguments = (
|
||||
ins TFL_TensorOf<[F32, I32, I64]>:$input,
|
||||
TFL_I32Tensor:$axis,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$exclusive,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$reverse
|
||||
);
|
||||
|
||||
let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output);
|
||||
|
||||
let hasOptions = 1;
|
||||
}
|
||||
|
||||
def TFL_DepthwiseConv2DOp :
|
||||
TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
|
||||
let arguments = (
|
||||
|
@ -1594,3 +1594,17 @@ func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf
|
||||
// CHECK: "tfl.transpose"
|
||||
}
|
||||
|
||||
func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
// CHECK-LABEL: cumsum
|
||||
// CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
|
||||
}
|
||||
|
||||
func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
|
||||
%0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
|
||||
return %0 : tensor<3x3xf32>
|
||||
// CHECK-LABEL: cumsum_invalid
|
||||
// CHECK-NOT: "tfl.cumsum"
|
||||
}
|
||||
|
||||
|
@ -444,3 +444,7 @@ def LegalizeMatrixSetDiag : Pat<
|
||||
def LegalizeScatterNd : Pat<
|
||||
(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
|
||||
(TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
|
||||
|
||||
def LegalizeCumsum : Pat<
|
||||
(TF_CumsumOp $input, $axis, $exclusive, $reverse),
|
||||
(TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;
|
||||
|
Loading…
x
Reference in New Issue
Block a user