Update lower static tensor list pass with the following changes:
* Don't reorder FuncOps manually (dialect conversion can automatically handle dependencies between SSA values). * Apply conversion on the whole ModuleOp * Rely on dialect conversion to automatically roll back changes to IR in case of legalization failure. PiperOrigin-RevId: 354364707 Change-Id: Iac586eba16fde8c25d0b7da9484f747a1a3874e3
This commit is contained in:
parent
791dc47439
commit
c400143fd1
@ -1,4 +1,6 @@
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list=allow-tensorlist-pass-through -split-input-file %s | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensorlistConst
|
||||
func @tensorlistConst(%arg0 : tensor<1xi32>) -> tensor<2x3xi32> {
|
||||
@ -431,3 +433,29 @@ func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: ten
|
||||
// CHECK: [[RESULT:%.*]] = "tf.Slice"([[INPUT]], [[SLICE_BEGIN]], [[SLICE_SIZE]]) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?x10xf32>
|
||||
// CHECK: return [[RESULT]] : tensor<?x10xf32>
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensorlistReserveWithDynamicShape
|
||||
func @tensorlistReserveWithDynamicShape(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
|
||||
return %1 : tensor<?x?x?xf32>
|
||||
|
||||
// CHECK: %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
// CHECK: %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32>
|
||||
// CHECK: return %1 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: tensorlistConcat
|
||||
func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lead: tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>) {
|
||||
%list = "tf.TensorListFromTensor"(%arg0, %element_shape) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
%t:2 = "tf.TensorListConcatV2"(%list, %element_shape, %lead) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
|
||||
return %t#0, %t#1 : tensor<?xf32>, tensor<0xi64>
|
||||
|
||||
// CHECK: %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<?xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
|
||||
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
|
||||
}
|
||||
|
@ -58,6 +58,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -72,20 +73,21 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
class TensorListPatternRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit TensorListPatternRewriter(FuncOp fn)
|
||||
: PatternRewriter(fn.getContext()) {}
|
||||
};
|
||||
|
||||
/// Lower TensorList ops in functions for subsequent legalization.
|
||||
struct LowerStaticTensorListPass
|
||||
: public PassWrapper<LowerStaticTensorListPass, OperationPass<ModuleOp>> {
|
||||
LowerStaticTensorListPass() = default;
|
||||
LowerStaticTensorListPass(const LowerStaticTensorListPass &) {}
|
||||
|
||||
void runOnOperation() override;
|
||||
|
||||
// Apply type and op changes within a function.
|
||||
LogicalResult RewriteFunction(FuncOp func,
|
||||
TensorListPatternRewriter *rewriter);
|
||||
Option<bool> allow_tensorlist_pass_through{
|
||||
*this, "allow-tensorlist-pass-through",
|
||||
llvm::cl::desc(
|
||||
"When specified to true, if the tensorlist ops can't be properly "
|
||||
"legalized by this pass, then the IR won't be changed so that "
|
||||
"tensorlist ops can pass through (default false)"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
Value CreateI32SplatConst(Location loc, PatternRewriter *rewriter,
|
||||
@ -335,7 +337,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
||||
op.emitError(
|
||||
rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
"transformation pass");
|
||||
@ -393,7 +396,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
if (element_shape_acquired) break;
|
||||
}
|
||||
if (!element_shape_acquired) {
|
||||
op.emitError(
|
||||
rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"requires element_shape to be 1D tensor during TF Lite "
|
||||
"transformation pass");
|
||||
return failure();
|
||||
@ -972,8 +976,7 @@ struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
FuncOp func, TensorListPatternRewriter *rewriter) {
|
||||
void LowerStaticTensorListPass::runOnOperation() {
|
||||
auto *context = &getContext();
|
||||
|
||||
// TensorFlow operations that doesn't have operands and results of type
|
||||
@ -996,7 +999,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
TF::TensorListGetItemOp, TF::TensorListLengthOp,
|
||||
TF::TensorListPushBackOp, TF::TensorListReserveOp,
|
||||
TF::TensorListSetItemOp, TF::TensorListStackOp,
|
||||
TF::TensorListResizeOp>();
|
||||
TF::TensorListResizeOp, TF::TensorListConcatV2Op>();
|
||||
// TODO(hinsu): Use TFLite constant op for constants.
|
||||
target.addLegalOp<ConstantOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
@ -1016,29 +1019,10 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
ConvertTensorListSetItem, ConvertTensorListStack,
|
||||
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
|
||||
context);
|
||||
return applyPartialConversion(func, target, std::move(patterns));
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnOperation() {
|
||||
// TODO(haoliang): currently we process the `main` function first, and the
|
||||
// remaining functions may be processed in arbitrary order. However, this will
|
||||
// have a potential issue when one function taking a `DT_VARIANT` is processed
|
||||
// before the function that produces the `DT_VARIANT`. We need to carefully
|
||||
// order the functions to be processed.
|
||||
std::vector<FuncOp> funcs_in_module;
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
// Always place the main function to be the first in the list.
|
||||
if (func.getName() == "main") {
|
||||
funcs_in_module.insert(funcs_in_module.begin(), func);
|
||||
} else {
|
||||
funcs_in_module.push_back(func);
|
||||
}
|
||||
}
|
||||
for (auto func : funcs_in_module) {
|
||||
TensorListPatternRewriter rewriter(func);
|
||||
if (failed(RewriteFunction(func, &rewriter))) {
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
if (!allow_tensorlist_pass_through) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user