Update lower static tensor list pass with the following changes:

* Don't reorder FuncOps manually (dialect conversion can automatically handle dependencies between SSA values).
* Apply conversion on the whole ModuleOp
* Rely on dialect conversion to automatically roll back changes to IR in case of legalization failure.

PiperOrigin-RevId: 354364707
Change-Id: Iac586eba16fde8c25d0b7da9484f747a1a3874e3
This commit is contained in:
Haoliang Zhang 2021-01-28 12:08:23 -08:00 committed by TensorFlower Gardener
parent 791dc47439
commit c400143fd1
2 changed files with 49 additions and 37 deletions

View File

@@ -1,4 +1,6 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
// -----
// CHECK-LABEL: tensorlistConst
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
@@ -431,3 +433,29 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
// CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
// CHECK: return [[RESULT]] : tensor<?x10xf32>
// CHECK: }
// -----
// CHECK-LABEL: tensorlistReserveWithDynamicShape
// Verifies pass-through behavior: because the element shape of the reserved
// list is fully dynamic (tensor<?x?x?xf32>), the pass cannot legalize these
// ops, so with allow-tensorlist-pass-through the IR must be left unchanged
// (the CHECK lines expect the original tf.TensorListReserve/GetItem back).
func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
// CHECK: return %1 : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: tensorlistConcat
// Verifies pass-through behavior for tf.TensorListConcatV2 (newly marked
// illegal by this change): the pass has no legalization pattern for it, so
// with allow-tensorlist-pass-through the ops must survive untouched.
func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
%list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
%t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
}

View File

@@ -58,6 +58,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/framework/tensor.h"
@@ -72,20 +73,21 @@ limitations under the License.
namespace mlir {
namespace {
// Trivial PatternRewriter bound to the context of a given function.
// NOTE(review): this class is deleted by this commit — the pass now runs
// applyPartialConversion on the whole ModuleOp, and dialect conversion
// supplies its own rewriter, so a hand-constructed one is no longer needed.
class TensorListPatternRewriter : public PatternRewriter {
public:
explicit TensorListPatternRewriter(FuncOp fn)
: PatternRewriter(fn.getContext()) {}
};
/// Lower TensorList ops in functions for subsequent legalization.
/// Operates on the whole ModuleOp so dialect conversion can handle
/// cross-function SSA dependencies and roll back the IR on failure.
struct LowerStaticTensorListPass
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
LowerStaticTensorListPass() = default;
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
void runOnOperation() override;
// Apply type and op changes within a function.
// NOTE(review): this per-function entry point is removed by this commit;
// its body is folded into runOnOperation().
LogicalResult RewriteFunction(FuncOp func,
TensorListPatternRewriter *rewriter);
// When true, a legalization failure leaves the IR unchanged (dialect
// conversion rolls back) instead of failing the pass outright.
Option<bool> allow_tensorlist_pass_through{
*this, "allow-tensorlist-pass-through",
llvm::cl::desc(
"When specified to true, if the tensorlist ops can't be properly "
"legalized by this pass, then the IR won't be changed so that "
"tensorlist ops can pass through (default false)"),
llvm::cl::init(false)};
};
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
@@ -335,7 +337,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
op.emitError(
rewriter.notifyMatchFailure(
op,
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
@@ -393,7 +396,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
if (element_shape_acquired) break;
}
if (!element_shape_acquired) {
op.emitError(
rewriter.notifyMatchFailure(
op,
"requires element_shape to be 1D tensor during TF Lite "
"transformation pass");
return failure();
@@ -972,8 +976,7 @@ struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
LogicalResult LowerStaticTensorListPass::RewriteFunction(
FuncOp func, TensorListPatternRewriter *rewriter) {
void LowerStaticTensorListPass::runOnOperation() {
auto *context = &getContext();
// TensorFlow operations that doesn't have operands and results of type
@@ -996,7 +999,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
TF::TensorListGetItemOp, TF::TensorListLengthOp,
TF::TensorListPushBackOp, TF::TensorListReserveOp,
TF::TensorListSetItemOp, TF::TensorListStackOp,
TF::TensorListResizeOp>();
TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
// TODO(hinsu): Use TFLite constant op for constants.
target.addLegalOp<ConstantOp>();
target.addLegalOp<FuncOp>();
@@ -1016,29 +1019,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
context);
return applyPartialConversion(func, target, std::move(patterns));
}
void LowerStaticTensorListPass::runOnOperation() {
// TODO(haoliang): currently we process the `main` function first, and the
// remaining functions may be processed in arbitrary order. However, this will
// have a potential issue when one function taking a `DT_VARIANT` is processed
// before the function that produces the `DT_VARIANT`. We need to carefully
// order the functions to be processed.
std::vector<FuncOp> funcs_in_module;
for (auto func : getOperation().getOps<FuncOp>()) {
// Always place the main function to be the first in the list.
if (func.getName() == "main") {
funcs_in_module.insert(funcs_in_module.begin(), func);
} else {
funcs_in_module.push_back(func);
}
}
for (auto func : funcs_in_module) {
TensorListPatternRewriter rewriter(func);
if (failed(RewriteFunction(func, &rewriter))) {
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
if (!allow_tensorlist_pass_through) {
signalPassFailure();
return;
}
}
}