In the lower static tensorlist pass, provide a flag allow_tensorlist_pass_through to let unsupported tensorlist ops pass through. When the flag is set to true and a function contains an unsupported tensorlist op, that op passes through unchanged instead of raising an error.
PiperOrigin-RevId: 354339397 Change-Id: I7b6acd1611c0557fd4b4faa45ffef97983caf02e
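For illustration, mirroring the RUN lines in the test diff below, the two modes look like this on the tf-opt command line (input.mlir is a placeholder file):

tf-opt -tfl-lower-static-tensor-list input.mlir
    (default: the pass fails on unsupported tensorlist ops)
tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through input.mlir
    (pass-through: unsupported tensorlist ops are left untouched)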
@@ -1,6 +1,4 @@
// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s

// -----
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s

// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@@ -433,27 +431,3 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
// CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32>
// CHECK: }

// -----

func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
  %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
  %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
  return %1 : tensor<?x?x?xf32>

// CHECK-LABEL: tensorlistReserveWithDynamicShape
// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return %1 : tensor<?x?x?xf32>
}

func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
  %list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
  %t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
  return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>

// CHECK-LABEL: tensorlistConcat
// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
}
@@ -36,7 +36,6 @@ limitations under the License.
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
@@ -82,21 +81,11 @@ class TensorListPatternRewriter : public PatternRewriter {
/// Lower TensorList ops in functions for subsequent legalization.
struct LowerStaticTensorListPass
    : public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
  LowerStaticTensorListPass() = default;
  LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}

  void runOnOperation() override;

  // Apply type and op changes within a function.
  LogicalResult RewriteFunction(FuncOp func,
                                TensorListPatternRewriter *rewriter);

  Option<bool> allow_tensorlist_pass_through{
      *this, "allow-tensorlist-pass-through",
      llvm::cl::desc(
          "If specified to true, then the tensorlist ops may pass "
          "through if they can't be handled by this pass (default false)"),
      llvm::cl::init(false)};
};

Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@@ -346,8 +335,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
    if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
          dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
          dtype.isInteger(32) || dtype.isInteger(64))) {
      rewriter.notifyMatchFailure(
          op,
      op.emitError(
          "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
          "integer or 16-bit/32-bit/64-bit float type during TF Lite "
          "transformation pass");
@@ -405,8 +393,7 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
        if (element_shape_acquired) break;
      }
      if (!element_shape_acquired) {
        rewriter.notifyMatchFailure(
            op,
        op.emitError(
            "requires element_shape to be 1D tensor during TF Lite "
            "transformation pass");
        return failure();
@@ -1047,23 +1034,10 @@ void LowerStaticTensorListPass::runOnOperation() {
      funcs_in_module.push_back(func);
    }
  }
  auto cloned_module = getOperation().clone();
  for (auto func : funcs_in_module) {
    TensorListPatternRewriter rewriter(func);
    if (failed(RewriteFunction(func, &rewriter))) {
      if (allow_tensorlist_pass_through) {
        // If the pass allows unsupported tensorlist ops to pass through,
        // then on failure we should roll back all the changes made so far.
        // We can't rely on the dialect conversion framework to roll back
        // automatically, because the conversion is currently applied at the
        // function level.
        BlockAndValueMapping mapping;
        getOperation().body().getBlocks().clear();
        cloned_module.body().cloneInto(&getOperation().body(), mapping);
        break;
      } else {
        signalPassFailure();
      }
      signalPassFailure();
      return;
    }
  }
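The rollback in the last hunk follows a clone-and-restore pattern: snapshot the module up front, and if any per-function rewrite fails while pass-through is allowed, discard the mutated body and clone the snapshot back in. A minimal sketch of that pattern, using only the MLIR calls visible in this diff; ExamplePass and RunFallibleRewrites are hypothetical stand-ins for the real pass and its per-function rewrite loop:

void ExamplePass::runOnOperation() {
  ModuleOp module = getOperation();
  // Snapshot the whole module before any destructive rewriting; the dialect
  // conversion runs per function here, so it can't undo earlier functions'
  // changes on its own.
  ModuleOp snapshot = module.clone();
  if (failed(RunFallibleRewrites(module))) {  // hypothetical rewrite driver
    // Restore: drop the mutated blocks and clone the snapshot's body back in.
    BlockAndValueMapping mapping;
    module.body().getBlocks().clear();
    snapshot.body().cloneInto(&module.body(), mapping);
  }
  // The snapshot is detached from any parent, so erase it explicitly.
  snapshot.getOperation()->erase();
}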