[tf.lite] Add cumsum mlir conversion & e2e tests.
PiperOrigin-RevId: 337985673
Change-Id: Ib5f6267a8d644a6caffa312924b687fd9863307f

parent 1e2389b4ba
commit ec494399c9
@@ -297,6 +297,7 @@
   `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
 * `DynamicBuffer::AddJoinedString()` will now add a separator if the first
   string to be joined is empty.
+* Added support for cumulative sum (cumsum), both as builtin op and MLIR conversion.
 * <ADD RELEASE NOTES HERE>
 
 * `tf.random`:
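For context on the new attributes, the `exclusive` and `reverse` flags follow the existing `tf.cumsum` semantics. A minimal illustration from the Python side (illustrative values only):

```python
import tensorflow as tf

x = tf.constant([1., 2., 3.])

tf.cumsum(x)                                # [1., 3., 6.]
tf.cumsum(x, exclusive=True)                # [0., 1., 3.]  (shifted by one, last element dropped)
tf.cumsum(x, reverse=True)                  # [6., 5., 3.]  (sums from the end)
tf.cumsum(x, exclusive=True, reverse=True)  # [5., 3., 0.]
```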
@@ -876,6 +876,30 @@ def TFL_CosOp: TFL_Op<"cos", [
   let hasFolder = 1;
 }
 
+def TFL_CumsumOp: TFL_Op<"cumsum", [
+    NoSideEffect,
+    PredOpTrait<"input and output must have same element type",
+        TFL_TCresVTEtIsSameAsOp<0, 0>>,
+    NoQuantizableResult,
+    TFL_OperandHasRank<1, 0>]> {
+  let summary = "Cumsum operator";
+
+  let description = [{
+    Compute the cumulative sum of the tensor x along axis.
+  }];
+
+  let arguments = (
+    ins TFL_TensorOf<[F32, I32, I64]>:$input,
+    TFL_I32Tensor:$axis,
+    DefaultValuedAttr<BoolAttr, "false">:$exclusive,
+    DefaultValuedAttr<BoolAttr, "false">:$reverse
+  );
+
+  let results = (outs TFL_TensorOf<[F32, I32, I64]>:$output);
+
+  let hasOptions = 1;
+}
+
 def TFL_DepthwiseConv2DOp :
     TFL_ConvOp<"depthwise_conv_2d", "Depthwise-separable convolution", 3> {
   let arguments = (
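The e2e path mentioned in the commit title can be exercised from Python roughly as follows. This is a sketch, not the test added by this change: it assumes a TensorFlow build that already contains this conversion and the corresponding builtin kernel, and uses only the standard public converter and interpreter APIs.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([3, 3], tf.float32)])
def model(x):
  # tf.Cumsum is expected to legalize to the builtin tfl.cumsum op.
  return tf.cumsum(x, axis=1)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [model.get_concrete_function()])
tflite_model = converter.convert()

# Run the converted model through the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
interpreter.set_tensor(inp["index"], tf.ones([3, 3], tf.float32).numpy())
interpreter.invoke()
print(interpreter.get_tensor(out["index"]))  # Each row should be [1., 2., 3.]
```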
@@ -1594,3 +1594,17 @@ func @tranpose_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor<3x2xf32>
   // CHECK: "tfl.transpose"
 }
 
+func @cumsum(%arg0: tensor<3x3xf32>, %arg1: tensor<i32>) -> tensor<3x3xf32> {
+  %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
+  return %0 : tensor<3x3xf32>
+  // CHECK-LABEL: cumsum
+  // CHECK: "tfl.cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i32>) -> tensor<3x3xf32>
+}
+
+func @cumsum_invalid(%arg0: tensor<3x3xf32>, %arg1: tensor<i64>) -> tensor<3x3xf32> {
+  %0 = "tf.Cumsum"(%arg0, %arg1) {exclusive = false, reverse = false} : (tensor<3x3xf32>, tensor<i64>) -> tensor<3x3xf32>
+  return %0 : tensor<3x3xf32>
+  // CHECK-LABEL: cumsum_invalid
+  // CHECK-NOT: "tfl.cumsum"
+}
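The `cumsum_invalid` case above covers the axis-type constraint: the TFLite op only accepts an `i32` axis (`TFL_I32Tensor:$axis` in the op definition), so a `tf.Cumsum` with an `i64` axis is intentionally left unconverted. From Python, the distinction looks roughly like this (a sketch; whether the unconverted op is rejected or kept as a TF op depends on the converter settings):

```python
import tensorflow as tf

x = tf.constant([[1., 2.], [3., 4.]])

# i32 axis: matches TFL_I32Tensor:$axis, convertible to tfl.cumsum.
tf.cumsum(x, axis=tf.constant(1, dtype=tf.int32))

# i64 axis: does not satisfy the tfl.cumsum axis constraint, so the
# legalization pattern is expected not to fire (see CHECK-NOT above).
tf.cumsum(x, axis=tf.constant(1, dtype=tf.int64))
```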
@@ -444,3 +444,7 @@ def LegalizeMatrixSetDiag : Pat<
 def LegalizeScatterNd : Pat<
   (TF_ScatterNdOp I32Tensor:$indices, $updates, $shape),
   (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>;
+
+def LegalizeCumsum : Pat<
+  (TF_CumsumOp $input, $axis, $exclusive, $reverse),
+  (TFL_CumsumOp $input, $axis, $exclusive, $reverse)>;