From 05a4e4bf8cd7e2b94b892d0aa35945507aad07d2 Mon Sep 17 00:00:00 2001 From: Haoliang Zhang Date: Tue, 2 Jul 2019 17:28:29 -0700 Subject: [PATCH] Rewrite graph with TensorList ops in support for dynamic RNN use case. PiperOrigin-RevId: 256272543 --- tensorflow/compiler/mlir/lite/BUILD | 1 + .../lite/tests/lower-static-tensor-list.mlir | 113 ++++++--- .../transforms/lower_static_tensor_list.cc | 239 ++++++++++++------ .../compiler/mlir/lite/transforms/passes.h | 3 +- .../lite/transforms/tensorlist_patterns.td | 15 +- 5 files changed, 250 insertions(+), 121 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 865c206a944..a9000158629 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -232,6 +232,7 @@ cc_library( "@local_config_mlir//:QuantOps", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", + "@local_config_mlir//:TypeUtilities", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 6e46c124f76..32008d4c851 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -1,5 +1,4 @@ -// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s - +// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor) -> (tensor<10xf32>, tensor<3x10xf32>) { ^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor): %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant> @@ -32,25 +31,28 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor, tensor<10x return %2 : tensor<3x10xf32> // CHECK-LABEL: tensorlistSetItem -// CHECK: %cst = constant dense<1> : tensor<1xi32> -// CHECK: %cst_0 = constant dense<0> : tensor -// CHECK: %cst_1 = constant dense<-1> : tensor // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor // CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor -// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst = constant dense<0> : tensor +// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_0 = constant dense<0> : tensor // CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor, tensor<1xi32>) -> tensor<1xi32> -// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor, tensor) -> tensor<1xi32> -// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor -// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor, tensor) -> tensor<1xi32> -// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor -// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %cst_1 = constant dense<1> : tensor<1xi32> +// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor, tensor<1xi32>) -> tensor<*xi32> +// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_2 = constant dense<0> : tensor +// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor, tensor<*xi32>, tensor) -> tensor +// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_3 = constant dense<-1> : tensor +// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor +// CHECK: %cst_4 = constant dense<-1> : tensor +// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor) -> tensor // CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor, tensor) -> tensor<*xf32> // CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor, tensor) -> tensor<*xf32> -// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor) -> tensor<*xf32> -// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32> +// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<10xf32>, tensor) -> tensor<*xf32> +// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32> // CHECK: return %15 : tensor<3x10xf32> } @@ -62,25 +64,28 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor // CHECK-LABEL: tensorlistSetItemWithScalarElements -// CHECK: %cst = constant dense<1> : tensor<1xi32> -// CHECK: %cst_0 = constant dense<0> : tensor -// CHECK: %cst_1 = constant dense<-1> : tensor // CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor // CHECK: %1 = "tf.Rank"(%arg3) : (tensor) -> tensor -// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst = constant dense<0> : tensor +// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_0 = constant dense<0> : tensor // CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor, tensor<1xi32>) -> tensor<1xi32> -// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor, tensor) -> tensor<1xi32> -// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor -// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor, tensor) -> tensor<1xi32> -// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor) -> tensor -// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor -// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %cst_1 = constant dense<1> : tensor<1xi32> +// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor, tensor<1xi32>) -> tensor<*xi32> +// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_2 = constant dense<0> : tensor +// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor, tensor<*xi32>, tensor) -> tensor +// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor, tensor) -> tensor<1xi32> +// CHECK: %cst_3 = constant dense<-1> : tensor +// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor) -> tensor +// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor) -> tensor +// CHECK: %cst_4 = constant dense<-1> : tensor +// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor) -> tensor // CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor, tensor) -> tensor<*xf32> // CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor, tensor) -> tensor<*xf32> -// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor, tensor) -> tensor<*xf32> -// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32> +// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor, tensor) -> tensor<*xf32> +// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32> // CHECK: return %15 : tensor<5xf32> } @@ -92,10 +97,56 @@ func @tensorlistReserve(tensor<3xi32>, tensor, tensor) -> tensor<3xf32 // CHECK-LABEL: tensorlistReserve // CHECK: %cst = constant dense<0> : tensor -// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor // CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor, tensor) -> tensor<1xi32> // CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> +// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor // CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor) -> tensor<*xf32> // CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor) -> tensor<3xf32> // CHECK: return %3 : tensor<3xf32> } + +func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> { +^bb0(%arg0: tensor<2x3xf32>): + %cst = constant dense<3> : tensor<1xi32> + %cst_0 = constant dense<0> : tensor + %cst_1 = constant dense<-1> : tensor + %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor + %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor, tensor) -> (tensor, tensor) + %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor, tensor) -> tensor<*xf32> + return %2 : tensor<*xf32> + +// make sure the variant types in input/output have been updated, and `T` attribute +// is removed. +// CHECK-LABEL: func @tensorlistWhileLoop +// CHECK-NOT: "tf.While"{{.*}}T = +// CHECK: "tf.While" +// CHECK-SAME: (tensor, tensor<2x3xf32>) -> (tensor, tensor<*xf32>) +// CHECK: return %0#1 : tensor<*xf32> +} + +func @tensorlistWhileBody(tensor<*xi32>, tensor<*x!tf.variant>) -> (tensor<*xi32>, tensor<*x!tf.variant>) { +^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>): + %cst = constant dense<1> : tensor + %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi32> + %1 = "tf.Identity"(%arg1) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant> + return %0, %1 : tensor<*xi32>, tensor<*x!tf.variant> + +// verify `body` function's signature. +// CHECK: func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) +// CHECK: %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi32> +// CHECK-NOT: tensor<*x!tf.variant> +// CHECK: %1 = "tf.Identity"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32> +} + +func @tensorlistWhileCond(tensor<*xi32>, tensor<*x!tf.variant>) -> tensor<*xi1> { +^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>): + %cst = constant dense<2> : tensor + %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi1> + return %0 : tensor<*xi1> + +// verify `cond` function's signature. +// CHECK: func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xi1> +// CHECK: %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi1> +// CHECK: return %0 : tensor<*xi1> +} diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index b965b067440..c4d8464d3d8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Block.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir @@ -42,11 +43,13 @@ limitations under the License. #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -56,11 +59,30 @@ limitations under the License. namespace mlir { namespace { -// Lower TensorList ops in functions for subsequent legalization. +class TensorListPatternRewriter : public PatternRewriter { + public: + explicit TensorListPatternRewriter(Function fn) + : PatternRewriter(fn.getBody()) {} + + Operation *createOperation(const OperationState &state) override { + return OpBuilder::createOperation(state); + } +}; + +/// Lower TensorList ops in functions for subsequent legalization. +// TODO(haoliang): Use DialectConversion infra to simplify the rewriting +// process. struct LowerStaticTensorListPass - : public FunctionPass { - void runOnFunction() override; - LogicalResult ModifyTensorList(); + : public ModulePass { + void runOnModule() override; + + // Apply type and op changes within a function. + LogicalResult RewriteFunction(Function func, + TensorListPatternRewriter *rewriter); + + // Changes the function type of `cond_func` and `body_func`, and the result + // type of the `WhileOp`. + LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op); }; Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter, @@ -121,9 +143,9 @@ struct ConvertTFTensorListSetItem : public RewritePattern { // Calculate the first dimension, which is index + 1. auto index = tf_op.index(); auto vector_type = rewriter.getTensorType({1}, shape_dtype); - auto begin = - rewriter.create(op->getLoc(), vector_type, index, - CreateI32SplatConst(op, &rewriter, {1}, 1)); + auto begin = rewriter.create( + op->getLoc(), rewriter.getTensorType(shape_dtype), index, + CreateI32SplatConst(op, &rewriter, {1}, 1)); // Followed by the first dimension `begin`, are `item_rank` of 0s. auto item_position_shape = rewriter.create( @@ -198,7 +220,7 @@ struct ConvertTFTensorListReserve : public RewritePattern { if (auto type = element_shape->getType().dyn_cast()) { // Note that the first item of the shape array is the element's rank, add // it by 1 to get the input's rank. - if (type.hasStaticShape()) { + if (type.hasStaticShape() && type.getRank() != 0) { input_rank = type.getShape()[0] + 1; } } @@ -236,98 +258,149 @@ namespace { } // namespace } // namespace TFL -LogicalResult LowerStaticTensorListPass::ModifyTensorList() { - // In `runOnFunction`, there is no guarantee about - // in which order those patterns will be applied. Our transformation requires - // that at runtime each `TensorListSetItem` op takes in a normal tensor type - // rather than a `DT_VARIANT` tensor. So here we need to manually walk-through - // the IR and change the argument/return types of each `TensorListSetItemOp`. - // TODO(haoliang): 1) support modifying more `TensorList` ops that consumes/ - // produces `DT_VARIANT` tensor. 2) More robust support for handling multiple - // different tensorlist types. For example, consider the case like: - // l1 = list_ops.tensor_list_from_tensor(t, element_shape1) - // l2 = list_ops.tensor_list_from_tensor(t, element_shape2) - // l1 = list_ops.tensor_list_set_item(l1, 0, item1) - // l2 = list_ops.tensor_list_set_item(l2, 0, item2) - // 3) Handle the case where a tensorlist output is passed to multiple - // functions. - for (Block &block : getFunction()) { - Type tensor_type; +LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType( + TF::WhileOp *while_op) { + SmallVector unranked_argument_types; + for (const auto &operand : while_op->getOperands()) { + unranked_argument_types.push_back( + UnrankedTensorType::get(getElementTypeOrSelf(operand->getType()))); + } + + auto *context = &getContext(); + auto module = getModule(); + Function cond_func = module.getNamedFunction(while_op->getCond()); + Function body_func = module.getNamedFunction(while_op->getBody()); + + if (cond_func) { + // Change `cond_func`'s argument types to `unranked_argument_types`. + cond_func.setType(FunctionType::get( + unranked_argument_types, cond_func.getType().getResults(), context)); + // Change the argument type for the first block. + Block &cond_first_bb = cond_func.front(); + for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) { + cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]); + } + } + + if (body_func) { + SmallVector updated_result_types; + for (int i = 0; i < body_func.getType().getNumResults(); ++i) { + auto result_type = body_func.getType().getResult(i); + if (getElementTypeOrSelf(result_type).isa()) { + // For variant type, use the corresponding unranked type. + updated_result_types.push_back(unranked_argument_types[i]); + } else { + updated_result_types.push_back(result_type); + } + } + // Change `body_func`'s argument type to `unranked_argument_types`. If it + // return types contain a `DT_VARIANT`, change it to the unranked type + // derived from the corresponding argument. + body_func.setType(FunctionType::get(unranked_argument_types, + updated_result_types, context)); + // Change the argument type for the first block. + Block &body_first_bb = body_func.front(); + for (int i = 0; i < body_first_bb.getNumArguments(); ++i) { + body_first_bb.getArgument(i)->setType(unranked_argument_types[i]); + } + } + + for (int i = 0; i < while_op->getNumOperands(); ++i) { + auto operand = while_op->getOperand(i); + auto result = while_op->getResult(i); + if (getElementTypeOrSelf(result->getType()).isa()) { + // If we notice the result type is a DT_VARIANT, we change the + // corresponding result type to unranked tensor type. + result->setType( + UnrankedTensorType::get(getElementTypeOrSelf(operand->getType()))); + } + } + return success(); +} + +LogicalResult LowerStaticTensorListPass::RewriteFunction( + Function func, TensorListPatternRewriter *rewriter) { + auto *context = &getContext(); + + for (Block &block : func) { + // Buffer the op pointers inside the current block into a vector, since + // the block iterator might be invalidated if we rewrite ops during looping. + std::vector ops_in_block; for (Operation &op : block) { + ops_in_block.push_back(&op); + } + + for (Operation *op : ops_in_block) { if (auto tf_op = llvm::dyn_cast(op)) { - tensor_type = tf_op.tensor()->getType(); + auto c = TFL::ConvertTFTensorListFromTensor(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); } else if (auto tf_op = llvm::dyn_cast(op)) { if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() || tf_op.element_dtype().isF64() || - tf_op.element_dtype().isa())) { + tf_op.element_dtype().isInteger(8) || + tf_op.element_dtype().isInteger(16) || + tf_op.element_dtype().isInteger(32) || + tf_op.element_dtype().isInteger(64))) { return tf_op.emitError( - "requires element_dtype to be integer or 16-bit/32-bit/64-bit " + "requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer " + "or 16-bit/32-bit/64-bit " "float type during TF Lite transformation pass"); } - // TODO(haoliang): figure out better way of specify shape. - tensor_type = UnrankedTensorType::get(tf_op.element_dtype()); - } - - if (auto tf_op = llvm::dyn_cast(op)) { - tf_op.input_handle()->setType(tensor_type); - tf_op.getResult()->setType(tensor_type); - } - // Currently we will raise an error if an op other than the following - // contains a DT_VARIANT tensor as its input or output. Below ops already - // have proper transformation patterns that eliminate the need of - // `DT_VARIANT`, we consider it's safe to not raise an error on those ops. - if (llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op) || - llvm::isa(op)) { - continue; - } - // Check if any of the input operand is a DT_VARIANT. - for (Type type : op.getOperandTypes()) { - if (type.isa()) { - return op.emitError( - "op's input contains a DT_VARIANT tensor. Currently we only " - "allow " - "TensorListFromTensor/TensorListReserve/TensorListStack/" - "TensorListSetItem/" - "TensorListGetItem to have DT_VARIANT input/output"); - } - } - // Check if any of the output is a DT_VARIANT. - for (Type type : op.getResultTypes()) { - if (type.isa()) { - return op.emitError( - "op's output contains a DT_VARIANT tensor. Currently we only " - "allow " - "TensorListFromTensor/TensorListReserve/TensorListStack/" - "TensorListSetItem/" - "TensorListGetItem to have DT_VARIANT input/output"); - } + auto c = ConvertTFTensorListReserve(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = TFL::ConvertTFTensorListGetItem(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = ConvertTFTensorListSetItem(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = TFL::ConvertTFTensorListStack(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); + UpdateWhileFunctionType(&tf_op); + } else if (auto tf_op = llvm::dyn_cast(op)) { + if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); + tf_op.getResult()->setType(tf_op.getOperand()->getType()); } } } return success(); } -void LowerStaticTensorListPass::runOnFunction() { - if (failed(ModifyTensorList())) { - signalPassFailure(); - return; +void LowerStaticTensorListPass::runOnModule() { + // TODO(haoliang): currently we process the `main` function first, and the + // remaining functions may be processed in arbitrary order. However, this will + // have a potential issue when one function taking a `DT_VARIANT` is processed + // before the function that produces the `DT_VARIANT`. We need to carefully + // order the functions to be processed. + std::vector funcs_in_module; + for (auto func : getModule().getFunctions()) { + // Always place the main function to be the first in the list. + if (func.getName().is("main")) { + funcs_in_module.insert(funcs_in_module.begin(), func); + } else { + funcs_in_module.push_back(func); + } + } + for (auto func : funcs_in_module) { + TensorListPatternRewriter rewriter(func); + if (failed(RewriteFunction(func, &rewriter))) { + signalPassFailure(); + return; + } } - OwningRewritePatternList patterns; - auto func = getFunction(); - TFL::populateWithGenerated(&getContext(), &patterns); - patterns.push_back( - llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); - applyPatternsGreedily(func, std::move(patterns)); } -// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList -// pass. -FunctionPassBase *TFL::CreateLowerStaticTensorListPass() { +/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList +/// pass. +ModulePassBase *TFL::CreateLowerStaticTensorListPass() { return new LowerStaticTensorListPass(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index e17d054ee8f..524850d68ea 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -18,6 +18,7 @@ limitations under the License. namespace mlir { class FunctionPassBase; +class ModulePassBase; namespace TFL { @@ -32,7 +33,7 @@ FunctionPassBase *CreatePrepareTFPass(); // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList // pass. -FunctionPassBase *CreateLowerStaticTensorListPass(); +ModulePassBase *CreateLowerStaticTensorListPass(); // Creates an instance of the TensorFlow Lite dialect Quantize pass. FunctionPassBase *CreateQuantizePass(); diff --git a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td index 5866afe6d8c..764f8e95f55 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/tensorlist_patterns.td @@ -23,11 +23,14 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" // into regular tensors. We also assume that each element in the `TensorList` has // a same constant shape. //===----------------------------------------------------------------------===// -def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape), - (replaceWithValue $tensor)>; +def ConvertTFTensorListFromTensor : Pat< + (TF_TensorListFromTensorOp $tensor, $element_shape), + (replaceWithValue $tensor)>; -def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements), - (replaceWithValue $input)>; +def ConvertTFTensorListStack : Pat< + (TF_TensorListStackOp $input, $element_shape, $num_elements), + (replaceWithValue $input)>; -def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape), - (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>; +def ConvertTFTensorListGetItem : Pat< + (TF_TensorListGetItemOp $input, $index, $element_shape), + (TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;