diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 9ae152f9979..7adc6ef0b5a 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -1,4 +1,6 @@ -// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s +// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s + +// ----- // CHECK-LABEL: tensorlistConst func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> { @@ -431,3 +433,29 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten // CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor, tensor) -> tensor // CHECK: return [[RESULT]] : tensor // CHECK: } + +// ----- + +// CHECK-LABEL: tensorlistReserveWithDynamicShape +func @tensorlistReserveWithDynamicShape(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> + %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor) -> tensor + return %1 : tensor + +// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> +// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor) -> tensor +// CHECK: return %1 : tensor +} + +// ----- + +// CHECK-LABEL: tensorlistConcat +func @tensorlistConcat(%arg0: tensor, %element_shape: tensor<0xi32>, %lead: tensor) -> (tensor, tensor<0xi64>) { + %list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor, tensor<0xi32>) -> tensor>> + %t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor>>, tensor<0xi32>, tensor) -> (tensor, tensor<0xi64>) + return %t#0, %t#1 : tensor, tensor<0xi64> + +// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor, tensor<0xi32>) -> tensor>> +// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor>>, tensor<0xi32>, tensor) -> (tensor, tensor<0xi64>) +// CHECK: return %tensor, %lengths : tensor, tensor<0xi64> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 400df2af946..7194e73ca28 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/core/framework/tensor.h" @@ -72,20 +73,21 @@ limitations under the License. namespace mlir { namespace { -class TensorListPatternRewriter : public PatternRewriter { - public: - explicit TensorListPatternRewriter(FuncOp fn) - : PatternRewriter(fn.getContext()) {} -}; - /// Lower TensorList ops in functions for subsequent legalization. struct LowerStaticTensorListPass : public PassWrapper> { + LowerStaticTensorListPass() = default; + LowerStaticTensorListPass(const LowerStaticTensorListPass &) {} + void runOnOperation() override; - // Apply type and op changes within a function. - LogicalResult RewriteFunction(FuncOp func, - TensorListPatternRewriter *rewriter); + Option allow_tensorlist_pass_through{ + *this, "allow-tensorlist-pass-through", + llvm::cl::desc( + "When specified to true, if the tensorlist ops can't be properly " + "legalized by this pass, then the IR won't be changed so that " + "tensorlist ops can pass through (default false)"), + llvm::cl::init(false)}; }; Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter, @@ -335,7 +337,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern { if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || dtype.isInteger(32) || dtype.isInteger(64))) { - op.emitError( + rewriter.notifyMatchFailure( + op, "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " "transformation pass"); @@ -393,7 +396,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern { if (element_shape_acquired) break; } if (!element_shape_acquired) { - op.emitError( + rewriter.notifyMatchFailure( + op, "requires element_shape to be 1D tensor during TF Lite " "transformation pass"); return failure(); @@ -972,8 +976,7 @@ struct ConvertWhileRegion : public OpConversionPattern { #include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc" -LogicalResult LowerStaticTensorListPass::RewriteFunction( - FuncOp func, TensorListPatternRewriter *rewriter) { +void LowerStaticTensorListPass::runOnOperation() { auto *context = &getContext(); // TensorFlow operations that doesn't have operands and results of type @@ -996,7 +999,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( TF::TensorListGetItemOp, TF::TensorListLengthOp, TF::TensorListPushBackOp, TF::TensorListReserveOp, TF::TensorListSetItemOp, TF::TensorListStackOp, - TF::TensorListResizeOp>(); + TF::TensorListResizeOp, TF::TensorListConcatV2Op>(); // TODO(hinsu): Use TFLite constant op for constants. target.addLegalOp(); target.addLegalOp(); @@ -1016,29 +1019,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListSetItem, ConvertTensorListStack, ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>( context); - return applyPartialConversion(func, target, std::move(patterns)); -} - -void LowerStaticTensorListPass::runOnOperation() { - // TODO(haoliang): currently we process the `main` function first, and the - // remaining functions may be processed in arbitrary order. However, this will - // have a potential issue when one function taking a `DT_VARIANT` is processed - // before the function that produces the `DT_VARIANT`. We need to carefully - // order the functions to be processed. - std::vector funcs_in_module; - for (auto func : getOperation().getOps()) { - // Always place the main function to be the first in the list. - if (func.getName() == "main") { - funcs_in_module.insert(funcs_in_module.begin(), func); - } else { - funcs_in_module.push_back(func); - } - } - for (auto func : funcs_in_module) { - TensorListPatternRewriter rewriter(func); - if (failed(RewriteFunction(func, &rewriter))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + if (!allow_tensorlist_pass_through) { signalPassFailure(); - return; } } }