Rewrite graph with TensorList ops in support for dynamic RNN use case.
PiperOrigin-RevId: 256272543
This commit is contained in:
parent
9e0e23c59f
commit
05a4e4bf8c
@ -232,6 +232,7 @@ cc_library(
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:TypeUtilities",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1,5 +1,4 @@
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s
|
||||
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
|
||||
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<*x!tf.variant>
|
||||
@ -32,25 +31,28 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10x
|
||||
return %2 : tensor<3x10xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistSetItem
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<3x10xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<10xf32>) -> tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_2 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<3x10xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<10xf32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<3x10xf32>
|
||||
// CHECK: return %15 : tensor<3x10xf32>
|
||||
}
|
||||
|
||||
@ -62,25 +64,28 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i
|
||||
return %2 : tensor<5xf32>
|
||||
|
||||
// CHECK-LABEL: tensorlistSetItemWithScalarElements
|
||||
// CHECK: %cst = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_1 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %0 = "tf.Rank"(%arg0) : (tensor<5xf32>) -> tensor<i32>
|
||||
// CHECK: %1 = "tf.Rank"(%arg3) : (tensor<f32>) -> tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %2 = "tf.ExpandDims"(%0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_0 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %3 = "tf.Fill"(%2, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst) : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_0) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst_0, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst_0) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst_0, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_1) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %cst_1 = constant dense<1> : tensor<1xi32>
|
||||
// CHECK: %4 = "tf.Add"(%arg2, %cst_1) : (tensor<i32>, tensor<1xi32>) -> tensor<*xi32>
|
||||
// CHECK: %5 = "tf.ExpandDims"(%1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_2 = constant dense<0> : tensor<i32>
|
||||
// CHECK: %6 = "tf.Fill"(%5, %cst_2) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %7 = "tf.Concat"(%cst, %4, %6) {N = 2 : i64} : (tensor<i32>, tensor<*xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %8 = "tf.ExpandDims"(%arg2, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %cst_3 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %9 = "tf.Fill"(%5, %cst_3) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %10 = "tf.Concat"(%cst, %8, %9) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %cst_4 = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %11 = "tf.Fill"(%2, %cst_4) : (tensor<1xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
// CHECK: %12 = "tf.Slice"(%arg0, %3, %10) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %13 = "tf.Slice"(%arg0, %7, %11) : (tensor<5xf32>, tensor<?xi32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst_0) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst_0, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
|
||||
// CHECK: %14 = "tf.ExpandDims"(%arg3, %cst) : (tensor<f32>, tensor<i32>) -> tensor<*xf32>
|
||||
// CHECK: %15 = "tf.Concat"(%cst, %12, %14, %13) {N = 3 : i64} : (tensor<i32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<5xf32>
|
||||
// CHECK: return %15 : tensor<5xf32>
|
||||
}
|
||||
|
||||
@ -92,10 +97,56 @@ func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<3xf32
|
||||
|
||||
// CHECK-LABEL: tensorlistReserve
|
||||
// CHECK: %cst = constant dense<0> : tensor<i32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor<i32>, tensor<i32>) -> tensor<1xi32>
|
||||
// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor<i32>, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32>
|
||||
// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor<f32>) -> tensor<*xf32>
|
||||
// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor<*xf32>, tensor<i32>) -> tensor<3xf32>
|
||||
// CHECK: return %3 : tensor<3xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
|
||||
^bb0(%arg0: tensor<2x3xf32>):
|
||||
%cst = constant dense<3> : tensor<1xi32>
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
%cst_1 = constant dense<-1> : tensor<i32>
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant>
|
||||
%1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
|
||||
%2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant>, tensor<i32>) -> tensor<*xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
|
||||
// make sure the variant types in input/output have been updated, and `T` attribute
|
||||
// is removed.
|
||||
// CHECK-LABEL: func @tensorlistWhileLoop
|
||||
// CHECK-NOT: "tf.While"{{.*}}T =
|
||||
// CHECK: "tf.While"
|
||||
// CHECK-SAME: (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
|
||||
// CHECK: return %0#1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileBody(tensor<*xi32>, tensor<*x!tf.variant>) -> (tensor<*xi32>, tensor<*x!tf.variant>) {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%1 = "tf.Identity"(%arg1) : (tensor<*x!tf.variant>) -> tensor<*x!tf.variant>
|
||||
return %0, %1 : tensor<*xi32>, tensor<*x!tf.variant>
|
||||
|
||||
// verify `body` function's signature.
|
||||
// CHECK: func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>)
|
||||
// CHECK: %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
// CHECK-NOT: tensor<*x!tf.variant>
|
||||
// CHECK: %1 = "tf.Identity"(%arg1) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileCond(tensor<*xi32>, tensor<*x!tf.variant>) -> tensor<*xi1> {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*x!tf.variant>):
|
||||
%cst = constant dense<2> : tensor<i32>
|
||||
%0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
return %0 : tensor<*xi1>
|
||||
|
||||
// verify `cond` function's signature.
|
||||
// CHECK: func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor<*xi1>
|
||||
// CHECK: %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
// CHECK: return %0 : tensor<*xi1>
|
||||
}
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Block.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
@ -42,11 +43,13 @@ limitations under the License.
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
|
||||
|
||||
#define DEBUG_TYPE "tf-tfl-legalization"
|
||||
|
||||
@ -56,11 +59,30 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
// Lower TensorList ops in functions for subsequent legalization.
|
||||
class TensorListPatternRewriter : public PatternRewriter {
|
||||
public:
|
||||
explicit TensorListPatternRewriter(Function fn)
|
||||
: PatternRewriter(fn.getBody()) {}
|
||||
|
||||
Operation *createOperation(const OperationState &state) override {
|
||||
return OpBuilder::createOperation(state);
|
||||
}
|
||||
};
|
||||
|
||||
/// Lower TensorList ops in functions for subsequent legalization.
|
||||
// TODO(haoliang): Use DialectConversion infra to simplify the rewriting
|
||||
// process.
|
||||
struct LowerStaticTensorListPass
|
||||
: public FunctionPass<LowerStaticTensorListPass> {
|
||||
void runOnFunction() override;
|
||||
LogicalResult ModifyTensorList();
|
||||
: public ModulePass<LowerStaticTensorListPass> {
|
||||
void runOnModule() override;
|
||||
|
||||
// Apply type and op changes within a function.
|
||||
LogicalResult RewriteFunction(Function func,
|
||||
TensorListPatternRewriter *rewriter);
|
||||
|
||||
// Changes the function type of `cond_func` and `body_func`, and the result
|
||||
// type of the `WhileOp`.
|
||||
LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
|
||||
};
|
||||
|
||||
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
|
||||
@ -121,9 +143,9 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
// Calculate the first dimension, which is index + 1.
|
||||
auto index = tf_op.index();
|
||||
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
|
||||
auto begin =
|
||||
rewriter.create<TF::AddOp>(op->getLoc(), vector_type, index,
|
||||
CreateI32SplatConst(op, &rewriter, {1}, 1));
|
||||
auto begin = rewriter.create<TF::AddOp>(
|
||||
op->getLoc(), rewriter.getTensorType(shape_dtype), index,
|
||||
CreateI32SplatConst(op, &rewriter, {1}, 1));
|
||||
|
||||
// Followed by the first dimension `begin`, are `item_rank` of 0s.
|
||||
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
@ -198,7 +220,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
if (auto type = element_shape->getType().dyn_cast<RankedTensorType>()) {
|
||||
// Note that the first item of the shape array is the element's rank, add
|
||||
// it by 1 to get the input's rank.
|
||||
if (type.hasStaticShape()) {
|
||||
if (type.hasStaticShape() && type.getRank() != 0) {
|
||||
input_rank = type.getShape()[0] + 1;
|
||||
}
|
||||
}
|
||||
@ -236,98 +258,149 @@ namespace {
|
||||
} // namespace
|
||||
} // namespace TFL
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::ModifyTensorList() {
|
||||
// In `runOnFunction`, there is no guarantee about
|
||||
// in which order those patterns will be applied. Our transformation requires
|
||||
// that at runtime each `TensorListSetItem` op takes in a normal tensor type
|
||||
// rather than a `DT_VARIANT` tensor. So here we need to manually walk-through
|
||||
// the IR and change the argument/return types of each `TensorListSetItemOp`.
|
||||
// TODO(haoliang): 1) support modifying more `TensorList` ops that consumes/
|
||||
// produces `DT_VARIANT` tensor. 2) More robust support for handling multiple
|
||||
// different tensorlist types. For example, consider the case like:
|
||||
// l1 = list_ops.tensor_list_from_tensor(t, element_shape1)
|
||||
// l2 = list_ops.tensor_list_from_tensor(t, element_shape2)
|
||||
// l1 = list_ops.tensor_list_set_item(l1, 0, item1)
|
||||
// l2 = list_ops.tensor_list_set_item(l2, 0, item2)
|
||||
// 3) Handle the case where a tensorlist output is passed to multiple
|
||||
// functions.
|
||||
for (Block &block : getFunction()) {
|
||||
Type tensor_type;
|
||||
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
|
||||
TF::WhileOp *while_op) {
|
||||
SmallVector<Type, 8> unranked_argument_types;
|
||||
for (const auto &operand : while_op->getOperands()) {
|
||||
unranked_argument_types.push_back(
|
||||
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
|
||||
}
|
||||
|
||||
auto *context = &getContext();
|
||||
auto module = getModule();
|
||||
Function cond_func = module.getNamedFunction(while_op->getCond());
|
||||
Function body_func = module.getNamedFunction(while_op->getBody());
|
||||
|
||||
if (cond_func) {
|
||||
// Change `cond_func`'s argument types to `unranked_argument_types`.
|
||||
cond_func.setType(FunctionType::get(
|
||||
unranked_argument_types, cond_func.getType().getResults(), context));
|
||||
// Change the argument type for the first block.
|
||||
Block &cond_first_bb = cond_func.front();
|
||||
for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
|
||||
cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (body_func) {
|
||||
SmallVector<Type, 8> updated_result_types;
|
||||
for (int i = 0; i < body_func.getType().getNumResults(); ++i) {
|
||||
auto result_type = body_func.getType().getResult(i);
|
||||
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
|
||||
// For variant type, use the corresponding unranked type.
|
||||
updated_result_types.push_back(unranked_argument_types[i]);
|
||||
} else {
|
||||
updated_result_types.push_back(result_type);
|
||||
}
|
||||
}
|
||||
// Change `body_func`'s argument type to `unranked_argument_types`. If it
|
||||
// return types contain a `DT_VARIANT`, change it to the unranked type
|
||||
// derived from the corresponding argument.
|
||||
body_func.setType(FunctionType::get(unranked_argument_types,
|
||||
updated_result_types, context));
|
||||
// Change the argument type for the first block.
|
||||
Block &body_first_bb = body_func.front();
|
||||
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
|
||||
body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < while_op->getNumOperands(); ++i) {
|
||||
auto operand = while_op->getOperand(i);
|
||||
auto result = while_op->getResult(i);
|
||||
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
|
||||
// If we notice the result type is a DT_VARIANT, we change the
|
||||
// corresponding result type to unranked tensor type.
|
||||
result->setType(
|
||||
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
Function func, TensorListPatternRewriter *rewriter) {
|
||||
auto *context = &getContext();
|
||||
|
||||
for (Block &block : func) {
|
||||
// Buffer the op pointers inside the current block into a vector, since
|
||||
// the block iterator might be invalidated if we rewrite ops during looping.
|
||||
std::vector<Operation *> ops_in_block;
|
||||
for (Operation &op : block) {
|
||||
ops_in_block.push_back(&op);
|
||||
}
|
||||
|
||||
for (Operation *op : ops_in_block) {
|
||||
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
|
||||
tensor_type = tf_op.tensor()->getType();
|
||||
auto c = TFL::ConvertTFTensorListFromTensor(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
|
||||
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
|
||||
tf_op.element_dtype().isF64() ||
|
||||
tf_op.element_dtype().isa<IntegerType>())) {
|
||||
tf_op.element_dtype().isInteger(8) ||
|
||||
tf_op.element_dtype().isInteger(16) ||
|
||||
tf_op.element_dtype().isInteger(32) ||
|
||||
tf_op.element_dtype().isInteger(64))) {
|
||||
return tf_op.emitError(
|
||||
"requires element_dtype to be integer or 16-bit/32-bit/64-bit "
|
||||
"requires element_dtype to be 8-bit/16-bit/32-bit/64-bit integer "
|
||||
"or 16-bit/32-bit/64-bit "
|
||||
"float type during TF Lite transformation pass");
|
||||
}
|
||||
// TODO(haoliang): figure out better way of specify shape.
|
||||
tensor_type = UnrankedTensorType::get(tf_op.element_dtype());
|
||||
}
|
||||
|
||||
if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
|
||||
tf_op.input_handle()->setType(tensor_type);
|
||||
tf_op.getResult()->setType(tensor_type);
|
||||
}
|
||||
// Currently we will raise an error if an op other than the following
|
||||
// contains a DT_VARIANT tensor as its input or output. Below ops already
|
||||
// have proper transformation patterns that eliminate the need of
|
||||
// `DT_VARIANT`, we consider it's safe to not raise an error on those ops.
|
||||
if (llvm::isa<TF::TensorListFromTensorOp>(op) ||
|
||||
llvm::isa<TF::TensorListReserveOp>(op) ||
|
||||
llvm::isa<TF::TensorListSetItemOp>(op) ||
|
||||
llvm::isa<TF::TensorListStackOp>(op) ||
|
||||
llvm::isa<TF::TensorListGetItemOp>(op)) {
|
||||
continue;
|
||||
}
|
||||
// Check if any of the input operand is a DT_VARIANT.
|
||||
for (Type type : op.getOperandTypes()) {
|
||||
if (type.isa<TF::VariantType>()) {
|
||||
return op.emitError(
|
||||
"op's input contains a DT_VARIANT tensor. Currently we only "
|
||||
"allow "
|
||||
"TensorListFromTensor/TensorListReserve/TensorListStack/"
|
||||
"TensorListSetItem/"
|
||||
"TensorListGetItem to have DT_VARIANT input/output");
|
||||
}
|
||||
}
|
||||
// Check if any of the output is a DT_VARIANT.
|
||||
for (Type type : op.getResultTypes()) {
|
||||
if (type.isa<TF::VariantType>()) {
|
||||
return op.emitError(
|
||||
"op's output contains a DT_VARIANT tensor. Currently we only "
|
||||
"allow "
|
||||
"TensorListFromTensor/TensorListReserve/TensorListStack/"
|
||||
"TensorListSetItem/"
|
||||
"TensorListGetItem to have DT_VARIANT input/output");
|
||||
}
|
||||
auto c = ConvertTFTensorListReserve(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListGetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
|
||||
auto c = ConvertTFTensorListSetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListStack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
UpdateWhileFunctionType(&tf_op);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
tf_op.getResult()->setType(tf_op.getOperand()->getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnFunction() {
|
||||
if (failed(ModifyTensorList())) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
void LowerStaticTensorListPass::runOnModule() {
|
||||
// TODO(haoliang): currently we process the `main` function first, and the
|
||||
// remaining functions may be processed in arbitrary order. However, this will
|
||||
// have a potential issue when one function taking a `DT_VARIANT` is processed
|
||||
// before the function that produces the `DT_VARIANT`. We need to carefully
|
||||
// order the functions to be processed.
|
||||
std::vector<Function> funcs_in_module;
|
||||
for (auto func : getModule().getFunctions()) {
|
||||
// Always place the main function to be the first in the list.
|
||||
if (func.getName().is("main")) {
|
||||
funcs_in_module.insert(funcs_in_module.begin(), func);
|
||||
} else {
|
||||
funcs_in_module.push_back(func);
|
||||
}
|
||||
}
|
||||
for (auto func : funcs_in_module) {
|
||||
TensorListPatternRewriter rewriter(func);
|
||||
if (failed(RewriteFunction(func, &rewriter))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
OwningRewritePatternList patterns;
|
||||
auto func = getFunction();
|
||||
TFL::populateWithGenerated(&getContext(), &patterns);
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConvertTFTensorListReserve>(&getContext()));
|
||||
patterns.push_back(
|
||||
llvm::make_unique<ConvertTFTensorListSetItem>(&getContext()));
|
||||
applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
FunctionPassBase *TFL::CreateLowerStaticTensorListPass() {
|
||||
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
/// pass.
|
||||
ModulePassBase *TFL::CreateLowerStaticTensorListPass() {
|
||||
return new LowerStaticTensorListPass();
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
namespace mlir {
|
||||
class FunctionPassBase;
|
||||
class ModulePassBase;
|
||||
|
||||
namespace TFL {
|
||||
|
||||
@ -32,7 +33,7 @@ FunctionPassBase *CreatePrepareTFPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
// pass.
|
||||
FunctionPassBase *CreateLowerStaticTensorListPass();
|
||||
ModulePassBase *CreateLowerStaticTensorListPass();
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect Quantize pass.
|
||||
FunctionPassBase *CreateQuantizePass();
|
||||
|
@ -23,11 +23,14 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
// into regular tensors. We also assume that each element in the `TensorList` has
|
||||
// a same constant shape.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def : Pat<(TF_TensorListFromTensorOp $tensor, $element_shape),
|
||||
(replaceWithValue $tensor)>;
|
||||
def ConvertTFTensorListFromTensor : Pat<
|
||||
(TF_TensorListFromTensorOp $tensor, $element_shape),
|
||||
(replaceWithValue $tensor)>;
|
||||
|
||||
def : Pat<(TF_TensorListStackOp $input, $element_shape, $num_elements),
|
||||
(replaceWithValue $input)>;
|
||||
def ConvertTFTensorListStack : Pat<
|
||||
(TF_TensorListStackOp $input, $element_shape, $num_elements),
|
||||
(replaceWithValue $input)>;
|
||||
|
||||
def : Pat<(TF_TensorListGetItemOp $input, $index, $element_shape),
|
||||
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
|
||||
def ConvertTFTensorListGetItem : Pat<
|
||||
(TF_TensorListGetItemOp $input, $index, $element_shape),
|
||||
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
|
||||
|
Loading…
Reference in New Issue
Block a user