diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
index f1eb0ca1b14..9ae152f9979 100644
--- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
+++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir
@@ -1,6 +1,4 @@
-// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
-
-// -----
+// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
 
 // CHECK-LABEL: tensorlistConst
 func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@@ -433,27 +431,3 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
 // CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
 // CHECK: return [[RESULT]] : tensor<?x10xf32>
 // CHECK: }
-
-// -----
-
-func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
-  %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
-  %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
-
-// CHECK-LABEL: tensorlistReserveWithDynamicShape
-// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
-// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
-// CHECK: return %1 : tensor<?x?x?xf32>
-}
-
-func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
-  %list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
-  %t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
-  return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
-
-// CHECK-LABEL: tensorlistConcat
-// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
-// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
-// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
-}
diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
index 2a35e9731a1..400df2af946 100644
--- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc
@@ -36,7 +36,6 @@ limitations under the License.
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/Block.h"  // from @llvm-project
-#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
@@ -82,21 +81,11 @@ class TensorListPatternRewriter : public PatternRewriter {
 /// Lower TensorList ops in functions for subsequent legalization.
 struct LowerStaticTensorListPass
     : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
-  LowerStaticTensorListPass() = default;
-  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
-
   void runOnOperation() override;
 
   // Apply type and op changes within a function.
   LogicalResult RewriteFunction(FuncOp func,
                                 TensorListPatternRewriter *rewriter);
-
-  Option<bool> allow_tensorlist_pass_through{
-      *this, "allow-tensorlist-pass-through",
-      llvm::cl::desc(
-          "If specified to true, then the tensorlist ops may pass "
-          "through if it can't be handled by this pass (default false)"),
-      llvm::cl::init(false)};
 };
 
 Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@@ -346,8 +335,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
     if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
           dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
           dtype.isInteger(32) || dtype.isInteger(64))) {
-      rewriter.notifyMatchFailure(
-          op,
+      op.emitError(
           "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
           "integer or 16-bit/32-bit/64-bit float type during TF Lite "
           "transformation pass");
@@ -405,8 +393,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
       if (element_shape_acquired) break;
     }
     if (!element_shape_acquired) {
-      rewriter.notifyMatchFailure(
-          op,
+      op.emitError(
           "requires element_shape to be 1D tensor during TF Lite "
           "transformation pass");
       return failure();
@@ -1047,23 +1034,10 @@ void LowerStaticTensorListPass::runOnOperation() {
       funcs_in_module.push_back(func);
     }
   }
-  auto cloned_module = getOperation().clone();
   for (auto func : funcs_in_module) {
     TensorListPatternRewriter rewriter(func);
     if (failed(RewriteFunction(func, &rewriter))) {
-      if (allow_tensorlist_pass_through) {
-        // If the current pass allows unsupported tensorlist ops to pass
-        // through, in terms of failure we should roll back all the changes done
-        // so far. The reason that we can't rely on dialect conversion to
-        // automatically roll back the changes is that, the dialect conversion
-        // is currently applied on function level.
-        BlockAndValueMapping mapping;
-        getOperation().body().getBlocks().clear();
-        cloned_module.body().cloneInto(&getOperation().body(), mapping);
-        break;
-      } else {
-        signalPassFailure();
-      }
+      signalPassFailure();
       return;
     }
   }