Rewrite graph with TensorList ops to support the dynamic RNN use case.

PiperOrigin-RevId: 256272543
Haoliang Zhang 2019-07-02 17:28:29 -07:00 committed by TensorFlower Gardener
parent 9e0e23c59f
commit 05a4e4bf8c
5 changed files with 250 additions and 121 deletions


@@ -232,6 +232,7 @@ cc_library(
"@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support",
"@local_config_mlir//:TypeUtilities",
],
alwayslink = 1,
)


@@ -1,5 +1,4 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
@@ -32,25 +31,28 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10x
return %2 : tensor<3x10xf32>
// CHECK-LABEL: tensorlistSetItem
// CHECK: %cst = constant dense<1> : tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_2 = constant dense<0> : tensor<i32>
// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
// CHECK: return %15 : tensor<3x10xf32>
}
@@ -62,25 +64,28 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i
return %2 : tensor<5xf32>
// CHECK-LABEL: tensorlistSetItemWithScalarElements
// CHECK: %cst = constant dense<1> : tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor<i32>
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<f32>) -> tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_2 = constant dense<0> : tensor<i32>
// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
// CHECK: return %15 : tensor<5xf32>
}
@@ -92,10 +97,56 @@ func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xf32
// CHECK-LABEL: tensorlistReserve
// CHECK: %cst = constant dense<0> : tensor<i32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<*xf32>
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<3xf32>
// CHECK: return %3 : tensor<3xf32>
}
func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<2x3xf32>):
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>
%cst_1 = constant dense<-1> : tensor<i32>
%0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant>
%1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
%2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant>, tensor<i32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
// Make sure the variant types in the inputs/outputs have been updated, and the
// `T` attribute has been removed.
// CHECK-LABEL: func @tensorlistWhileLoop
// CHECK-NOT: "tf.While"{{.*}}T =
// CHECK: "tf.While"
// CHECK-SAME: (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
// CHECK: return %0#1 : tensor<*xf32>
}
func @tensorlistWhileBody(tensor<*xi32>, tensor<*x!tf.variant>) -> (tensor<*xi32>, tensor<*x!tf.variant>) {
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
%cst = constant dense<1> : tensor<i32>
%0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = "tf.Identity"(%arg1) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant>
return %0, %1 : tensor<*xi32>, tensor<*x!tf.variant>
// Verify the `body` function's signature.
// CHECK: func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
// CHECK: %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK-NOT: tensor<*x!tf.variant>
// CHECK: %1 = "tf.Identity"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32>
}
func @tensorlistWhileCond(tensor<*xi32>, tensor<*x!tf.variant>) -> tensor<*xi1> {
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
%cst = constant dense<2> : tensor<i32>
%0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
return %0 : tensor<*xi1>
// Verify the `cond` function's signature.
// CHECK: func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xi1>
// CHECK: %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
// CHECK: return %0 : tensor<*xi1>
}


@@ -32,6 +32,7 @@ limitations under the License.
#include "mlir/IR/Block.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
@@ -42,11 +43,13 @@ limitations under the License.
#include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#define DEBUG_TYPE "tf-tfl-legalization"
@@ -56,11 +59,30 @@ limitations under the License.
namespace mlir {
namespace {
// Lower TensorList ops in functions for subsequent legalization.
class TensorListPatternRewriter : public PatternRewriter {
public:
explicit TensorListPatternRewriter(Function fn)
: PatternRewriter(fn.getBody()) {}
Operation *createOperation(const OperationState &state) override {
return OpBuilder::createOperation(state);
}
};
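// Note: at this revision `PatternRewriter` leaves its `createOperation`
// hook abstract, so this trivial subclass (which simply forwards to
// `OpBuilder::createOperation`) is what lets the pass drive patterns by
// hand instead of going through `applyPatternsGreedily`. A minimal usage
// sketch, mirroring what `RewriteFunction` does below:
//
//   TensorListPatternRewriter rewriter(func);
//   rewriter.setInsertionPoint(op);
//   auto pattern = ConvertTFTensorListSetItem(context);
//   pattern.matchAndRewrite(op, rewriter);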
/// Lower TensorList ops in functions for subsequent legalization.
// TODO(haoliang): Use DialectConversion infra to simplify the rewriting
// process.
struct LowerStaticTensorListPass
: public FunctionPass<LowerStaticTensorListPass> {
void runOnFunction() override;
LogicalResult ModifyTensorList();
: public ModulePass<LowerStaticTensorListPass> {
void runOnModule() override;
// Apply type and op changes within a function.
LogicalResult RewriteFunction(Function func,
TensorListPatternRewriter *rewriter);
// Changes the function type of `cond_func` and `body_func`, and the result
// type of the `WhileOp`.
LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
};
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
@@ -121,9 +143,9 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
// Calculate the first dimension, which is index + 1.
auto index = tf_op.index();
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
auto begin =
rewriter.create<TF::AddOp>(op->getLoc(), vector_type, index,
CreateI32SplatConst(op, &rewriter, {1}, 1));
auto begin = rewriter.create<TF::AddOp>(
op->getLoc(), rewriter.getTensorType(shape_dtype), index,
CreateI32SplatConst(op, &rewriter, {1}, 1));
// The first dimension `begin` is followed by `item_rank` zeros.
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
@@ -198,7 +220,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
if (auto type = element_shape->getType().dyn_cast<RankedTensorType>()) {
// Note that the first item of the shape array is the element's rank; add
// 1 to it to get the input's rank.
if (type.hasStaticShape()) {
if (type.hasStaticShape() && type.getRank() != 0) {
input_rank = type.getShape()[0] + 1;
}
}
@@ -236,98 +258,149 @@ namespace {
} // namespace
} // namespace TFL
LogicalResult LowerStaticTensorListPass::ModifyTensorList() {
// In `runOnFunction`, there is no guarantee about the order in which those
// patterns will be applied. Our transformation requires that at runtime each
// `TensorListSetItem` op takes in a normal tensor type rather than a
// `DT_VARIANT` tensor, so here we need to manually walk the IR and change
// the argument/return types of each `TensorListSetItemOp`.
// TODO(haoliang): 1) Support modifying more `TensorList` ops that consume/
// produce `DT_VARIANT` tensors. 2) More robust support for handling multiple
// different tensorlist types. For example, consider the case like:
// l1 = list_ops.tensor_list_from_tensor(t, element_shape1)
// l2 = list_ops.tensor_list_from_tensor(t, element_shape2)
// l1 = list_ops.tensor_list_set_item(l1, 0, item1)
// l2 = list_ops.tensor_list_set_item(l2, 0, item2)
// 3) Handle the case where a tensorlist output is passed to multiple
// functions.
for (Block &block : getFunction()) {
Type tensor_type;
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
TF::WhileOp *while_op) {
SmallVector<Type, 8> unranked_argument_types;
for (const auto &operand : while_op->getOperands()) {
unranked_argument_types.push_back(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
auto *context = &getContext();
auto module = getModule();
Function cond_func = module.getNamedFunction(while_op->getCond());
Function body_func = module.getNamedFunction(while_op->getBody());
if (cond_func) {
// Change `cond_func`'s argument types to `unranked_argument_types`.
cond_func.setType(FunctionType::get(
unranked_argument_types, cond_func.getType().getResults(), context));
// Change the argument type for the first block.
Block &cond_first_bb = cond_func.front();
for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
}
}
if (body_func) {
SmallVector<Type, 8> updated_result_types;
for (int i = 0; i < body_func.getType().getNumResults(); ++i) {
auto result_type = body_func.getType().getResult(i);
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
// For variant type, use the corresponding unranked type.
updated_result_types.push_back(unranked_argument_types[i]);
} else {
updated_result_types.push_back(result_type);
}
}
// Change `body_func`'s argument types to `unranked_argument_types`. If its
// return types contain a `DT_VARIANT`, change them to the unranked type
// derived from the corresponding argument.
body_func.setType(FunctionType::get(unranked_argument_types,
updated_result_types, context));
// Change the argument type for the first block.
Block &body_first_bb = body_func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
}
}
for (int i = 0; i < while_op->getNumOperands(); ++i) {
auto operand = while_op->getOperand(i);
auto result = while_op->getResult(i);
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
// If the result type is a DT_VARIANT, change it to the unranked tensor
// type derived from the corresponding operand.
result->setType(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
}
return success();
}
LogicalResult LowerStaticTensorListPass::RewriteFunction(
Function func, TensorListPatternRewriter *rewriter) {
auto *context = &getContext();
for (Block &block : func) {
// Buffer the op pointers inside the current block into a vector, since
// the block iterator might be invalidated if we rewrite ops during looping.
std::vector<Operation *> ops_in_block;
for (Operation &op : block) {
ops_in_block.push_back(&op);
}
for (Operation *op : ops_in_block) {
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
tensor_type = tf_op.tensor()->getType();
auto c = TFL::ConvertTFTensorListFromTensor(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
tf_op.element_dtype().isF64() ||
tf_op.element_dtype().isa<IntegerType>())) {
tf_op.element_dtype().isInteger(8) ||
tf_op.element_dtype().isInteger(16) ||
tf_op.element_dtype().isInteger(32) ||
tf_op.element_dtype().isInteger(64))) {
return tf_op.emitError(
"requires element_dtype to be integer or 16-bit/32-bit/64-bit "
"requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
"or 16-bit/32-bit/64-bit "
"float type during TF Lite transformation pass");
}
// TODO(haoliang): Figure out a better way to specify the shape.
tensor_type = UnrankedTensorType::get(tf_op.element_dtype());
}
if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
tf_op.input_handle()->setType(tensor_type);
tf_op.getResult()->setType(tensor_type);
}
// Currently we raise an error if an op other than the following contains a
// DT_VARIANT tensor as its input or output. The ops below already have
// proper transformation patterns that eliminate the need for `DT_VARIANT`,
// so we consider it safe not to raise an error on them.
if (llvm::isa<TF::TensorListFromTensorOp>(op) ||
llvm::isa<TF::TensorListReserveOp>(op) ||
llvm::isa<TF::TensorListSetItemOp>(op) ||
llvm::isa<TF::TensorListStackOp>(op) ||
llvm::isa<TF::TensorListGetItemOp>(op)) {
continue;
}
// Check whether any of the input operands is a DT_VARIANT.
for (Type type : op.getOperandTypes()) {
if (type.isa<TF::VariantType>()) {
return op.emitError(
"op's input contains a DT_VARIANT tensor. Currently we only allow "
"TensorListFromTensor/TensorListReserve/TensorListStack/"
"TensorListSetItem/TensorListGetItem to have DT_VARIANT input/output");
}
}
// Check whether any of the outputs is a DT_VARIANT.
for (Type type : op.getResultTypes()) {
if (type.isa<TF::VariantType>()) {
return op.emitError(
"op's output contains a DT_VARIANT tensor. Currently we only allow "
"TensorListFromTensor/TensorListReserve/TensorListStack/"
"TensorListSetItem/TensorListGetItem to have DT_VARIANT input/output");
}
auto c = ConvertTFTensorListReserve(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
auto c = TFL::ConvertTFTensorListGetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
auto c = ConvertTFTensorListSetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
auto c = TFL::ConvertTFTensorListStack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(&tf_op);
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
tf_op.getResult()->setType(tf_op.getOperand()->getType());
}
}
}
return success();
}
void LowerStaticTensorListPass::runOnFunction() {
if (failed(ModifyTensorList())) {
signalPassFailure();
return;
void LowerStaticTensorListPass::runOnModule() {
// TODO(haoliang): Currently we process the `main` function first, and the
// remaining functions may be processed in arbitrary order. This is a
// potential issue when a function that takes a `DT_VARIANT` is processed
// before the function that produces that `DT_VARIANT`. We need to carefully
// order the functions to be processed.
std::vector<Function> funcs_in_module;
for (auto func : getModule().getFunctions()) {
// Always place the main function first in the list.
if (func.getName().is("main")) {
funcs_in_module.insert(funcs_in_module.begin(), func);
} else {
funcs_in_module.push_back(func);
}
}
for (auto func : funcs_in_module) {
TensorListPatternRewriter rewriter(func);
if (failed(RewriteFunction(func, &rewriter))) {
signalPassFailure();
return;
}
}
OwningRewritePatternList patterns;
auto func = getFunction();
TFL::populateWithGenerated(&getContext(), &patterns);
patterns.push_back(
llvm::make_unique<ConvertTFTensorListReserve>(&getContext()));
patterns.push_back(
llvm::make_unique<ConvertTFTensorListSetItem>(&getContext()));
applyPatternsGreedily(func, std::move(patterns));
}
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
FunctionPassBase *TFL::CreateLowerStaticTensorListPass() {
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass.
ModulePassBase *TFL::CreateLowerStaticTensorListPass() {
return new LowerStaticTensorListPass();
}
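// Illustrative sketch only: with the switch from FunctionPass to ModulePass,
// the pass would now be scheduled on a module pass manager, roughly like the
// following (the exact PassManager construction/run signatures at this
// revision are assumptions):
//
//   mlir::PassManager pm;
//   pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass());
//   if (failed(pm.run(module))) {
//     // Handle pass failure.
//   }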


@@ -18,6 +18,7 @@ limitations under the License.
namespace mlir {
class FunctionPassBase;
class ModulePassBase;
namespace TFL {
@@ -32,7 +33,7 @@ FunctionPassBase *CreatePrepareTFPass();
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
// pass.
FunctionPassBase *CreateLowerStaticTensorListPass();
ModulePassBase *CreateLowerStaticTensorListPass();
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
FunctionPassBase *CreateQuantizePass();


@@ -23,11 +23,14 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
// into regular tensors. We also assume that each element in the `TensorList`
// has the same constant shape.
//===----------------------------------------------------------------------===//
def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
(replaceWithValue $tensor)>;
def ConvertTFTensorListFromTensor : Pat<
(TF_TensorListFromTensorOp $tensor, $element_shape),
(replaceWithValue $tensor)>;
def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements),
(replaceWithValue $input)>;
def ConvertTFTensorListStack : Pat<
(TF_TensorListStackOp $input, $element_shape, $num_elements),
(replaceWithValue $input)>;
def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
def ConvertTFTensorListGetItem : Pat<
(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
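// Naming these `Pat` defs (instead of leaving them anonymous) makes
// mlir-tblgen emit a `RewritePattern` subclass of the same name, which is
// what allows `RewriteFunction` in the C++ pass above to instantiate them
// directly. The generated code has roughly this shape (illustrative; the
// exact output of the rewriter generator may differ):
//
//   struct ConvertTFTensorListGetItem : public RewritePattern {
//     explicit ConvertTFTensorListGetItem(MLIRContext *context)
//         : RewritePattern("tf.TensorListGetItem", 1, context) {}
//     PatternMatchResult matchAndRewrite(
//         Operation *op, PatternRewriter &rewriter) const override;
//   };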