In lower static tensorlist pass, provide a flag allow_tensorlist_pass_through to ensure that unsupported tensorlist ops can pass through. When the flag is set to true, and if the function contains any unsupported ops then that op will pass through instead of throwing out an error.

PiperOrigin-RevId: 354339397
Change-Id: I7b6acd1611c0557fd4b4faa45ffef97983caf02e
This commit is contained in:
Bixia Zheng 2021-01-28 10:22:08 -08:00 committed by TensorFlower Gardener
parent 1b5ceacc3e
commit 852c7a1716
2 changed files with 4 additions and 56 deletions

View File

@ -1,6 +1,4 @@
// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
// -----
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@ -433,27 +431,3 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
// CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32>
// CHECK: }
// -----
func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
// CHECK-LABEL: tensorlistReserveWithDynamicShape
// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return %1 : tensor<?x?x?xf32>
}
func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
%list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
%t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
// CHECK-LABEL: tensorlistConcat
// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
}

View File

@ -36,7 +36,6 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
@ -82,21 +81,11 @@ class TensorListPatternRewriter : public PatternRewriter {
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
LowerStaticTensorListPass() = default;
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
void runOnOperation() override;
// Apply type and op changes within a function.
LogicalResult RewriteFunction(FuncOp func,
TensorListPatternRewriter *rewriter);
Option<bool> allow_tensorlist_pass_through{
*this, "allow-tensorlist-pass-through",
llvm::cl::desc(
"If specified to true, then the tensorlist ops may pass "
"through if it can't be handled by this pass (default false)"),
llvm::cl::init(false)};
};
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@ -346,8 +335,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
rewriter.notifyMatchFailure(
op,
op.emitError(
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
@ -405,8 +393,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (element_shape_acquired) break;
}
if (!element_shape_acquired) {
rewriter.notifyMatchFailure(
op,
op.emitError(
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
return failure();
@ -1047,23 +1034,10 @@ void LowerStaticTensorListPass::runOnOperation() {
funcs_in_module.push_back(func);
}
}
auto cloned_module = getOperation().clone();
for (auto func : funcs_in_module) {
TensorListPatternRewriter rewriter(func);
if (failed(RewriteFunction(func, &rewriter))) {
if (allow_tensorlist_pass_through) {
// If the current pass allows unsupported tensorlist ops to pass
// through, in terms of failure we should roll back all the changes done
// so far. The reason that we can't rely on dialect conversion to
// automatically roll back the changes is that, the dialect conversion
// is currently applied on function level.
BlockAndValueMapping mapping;
getOperation().body().getBlocks().clear();
cloned_module.body().cloneInto(&getOperation().body(), mapping);
break;
} else {
signalPassFailure();
}
signalPassFailure();
return;
}
}