NFC: Improve tfl-lower-static-tensor-list pass and tests

Specifically,
* Use better syntax in tests by naming arguments in function signatures and not having any explicit block
* Pass WhileOp by value
* Use OpRewritePattern instead of RewritePattern for the patterns

PiperOrigin-RevId: 261007552
This commit is contained in:
Smit Hinsu 2019-07-31 15:11:30 -07:00 committed by TensorFlower Gardener
parent 58dfc47ef1
commit 6b31ac341a
2 changed files with 50 additions and 62 deletions

View File

@ -1,6 +1,6 @@
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
@ -11,8 +11,7 @@ func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor
// CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
}
func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
func @tensorlistGetItemWithUnknownRank(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<*xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<*xf32>
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<1xi32>) -> tensor<*xf32>
@ -23,8 +22,7 @@ func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>
// CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
}
func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10xf32>) -> tensor<3x10xf32> {
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>):
func @tensorlistSetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>) -> tensor<3x10xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
@ -56,8 +54,7 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10x
// CHECK: return %15 : tensor<3x10xf32>
}
func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i32>, tensor<f32>) -> tensor<5xf32> {
^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
func @tensorlistSetItemWithScalarElements(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> tensor<5xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>) -> tensor<5xf32>
@ -89,8 +86,7 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i
// CHECK: return %15 : tensor<5xf32>
}
func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32> {
^bb0(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
@ -105,8 +101,7 @@ func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<?x?x?
// CHECK: return %3 : tensor<?x?x?xf32>
}
func @tensorlistReserveUnrankedElements(tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<?xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<*xf32>>>
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<?xi32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
@ -117,8 +112,7 @@ func @tensorlistReserveUnrankedElements(tensor<?xi32>, tensor<i32>, tensor<i32>)
// CHECK: return [[RESULT2]] : tensor<*xf32>
}
func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<2x3xf32>):
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>
%cst_1 = constant dense<-1> : tensor<i32>
@ -136,8 +130,7 @@ func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
// CHECK: return %0#1 : tensor<*xf32>
}
func @tensorlistWhileBody(tensor<*xi32>, tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
%cst = constant dense<1> : tensor<i32>
%0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = "tf.Identity"(%arg1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
@ -151,8 +144,7 @@ func @tensorlistWhileBody(tensor<*xi32>, tensor<!tf.variant>) -> (tensor<*xi32>,
// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32>
}
func @tensorlistWhileCond(tensor<*xi32>, tensor<!tf.variant>) -> tensor<*xi1> {
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> tensor<*xi1> {
%cst = constant dense<2> : tensor<i32>
%0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
return %0 : tensor<*xi1>

View File

@ -82,7 +82,7 @@ struct LowerStaticTensorListPass
// Changes the function type of `cond_func` and `body_func`, and the result
// type of the `WhileOp`.
LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
LogicalResult UpdateWhileFunctionType(TF::WhileOp op);
};
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
@ -100,10 +100,10 @@ Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
shape_tensor, scalar_val);
}
struct ConvertTFTensorListSetItem : public RewritePattern {
struct ConvertTFTensorListSetItem
: public OpRewritePattern<TF::TensorListSetItemOp> {
explicit ConvertTFTensorListSetItem(MLIRContext *context)
: RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
context) {}
: OpRewritePattern<TF::TensorListSetItemOp>(context, 1) {}
// This function rewrites the original op into a series of slice and concat op
// to produce the same result. It first slices the first `$index` rows. Then
// expands the dimension of the `$item`, followed by another slice of the
@ -116,23 +116,21 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op,
PatternRewriter &rewriter) const override {
TF::TensorListSetItemOp tf_op = cast<TF::TensorListSetItemOp>(op);
auto input = tf_op.input_handle();
auto input = op.input_handle();
auto shape_dtype = rewriter.getIntegerType(32);
auto input_rank = rewriter.create<TF::RankOp>(
op->getLoc(), rewriter.getTensorType({}, shape_dtype), input);
auto item = tf_op.item();
op.getLoc(), rewriter.getTensorType({}, shape_dtype), input);
auto item = op.item();
auto item_rank = rewriter.create<TF::RankOp>(
op->getLoc(), rewriter.getTensorType({}, shape_dtype), item);
op.getLoc(), rewriter.getTensorType({}, shape_dtype), item);
// Prepare the start position for the first slice op, which is [0, 0, ..,
// 0].
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto position_shape = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
scalar_zero);
// Fill all 0s into the first position tensor.
auto first_start_position =
@ -141,33 +139,33 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
// Prepare the start position for the second slice op, which is
// [index + 1, 0, 0 .. 0].
// Calculate the first dimension, which is index + 1.
auto index = tf_op.index();
auto index = op.index();
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
auto begin = rewriter.create<TF::AddOp>(
op->getLoc(), rewriter.getTensorType(shape_dtype), index,
op.getLoc(), rewriter.getTensorType(shape_dtype), index,
CreateI32SplatConst(op, &rewriter, {1}, 1));
// Followed by the first dimension `begin`, are `item_rank` of 0s.
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
scalar_zero);
auto partial_second_start_position =
CreateI32SplatTensor(op, &rewriter, item_position_shape, 0);
auto position_type = first_start_position->getType();
// Concatenate `begin` with the remaining 0s.
auto second_start_position = rewriter.create<TF::ConcatOp>(
op->getLoc(), position_type, scalar_zero,
op.getLoc(), position_type, scalar_zero,
ArrayRef<Value *>({begin, partial_second_start_position}),
rewriter.getI64IntegerAttr(2));
// Create the size parameter for the first slice op, which is [index, -1,
// -1, .., -1].
auto size1_leading_dim = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), vector_type, index, scalar_zero);
op.getLoc(), vector_type, index, scalar_zero);
auto partial_size1 =
CreateI32SplatTensor(op, &rewriter, item_position_shape, -1);
auto size1 = rewriter.create<TF::ConcatOp>(
op->getLoc(), position_type, scalar_zero,
op.getLoc(), position_type, scalar_zero,
ArrayRef<Value *>({size1_leading_dim, partial_size1}),
rewriter.getI64IntegerAttr(2));
@ -179,14 +177,14 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
auto element_type = input->getType().cast<TensorType>().getElementType();
auto unranked_tensor = rewriter.getTensorType(element_type);
auto slice1 = rewriter.create<TF::SliceOp>(
op->getLoc(), unranked_tensor, input, first_start_position, size1);
op.getLoc(), unranked_tensor, input, first_start_position, size1);
auto slice2 = rewriter.create<TF::SliceOp>(
op->getLoc(), unranked_tensor, input, second_start_position, size2);
op.getLoc(), unranked_tensor, input, second_start_position, size2);
// Expand the dimension of item so that it will have the same rank with
// input.
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), unranked_tensor, item, scalar_zero);
op.getLoc(), unranked_tensor, item, scalar_zero);
// Concatenate three parts together to generate the final result.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
@ -198,26 +196,24 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
}
};
struct ConvertTFTensorListReserve : public RewritePattern {
struct ConvertTFTensorListReserve
: public OpRewritePattern<TF::TensorListReserveOp> {
explicit ConvertTFTensorListReserve(MLIRContext *context)
: RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
context) {}
: OpRewritePattern<TF::TensorListReserveOp>(context, 1) {}
// Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(Operation *op,
PatternMatchResult matchAndRewrite(TF::TensorListReserveOp op,
PatternRewriter &rewriter) const override {
TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);
auto element_shape = tf_op.element_shape();
auto element_shape = op.element_shape();
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
auto num_elements = tf_op.num_elements();
Type element_dtype = tf_op.element_dtype();
auto num_elements = op.num_elements();
Type element_dtype = op.element_dtype();
int64_t result_rank = -1; // -1 means unknown result rank.
Type result_type = rewriter.getTensorType(element_dtype);
if (auto element_type = tf_op.element_type().dyn_cast<RankedTensorType>()) {
if (auto element_type = op.element_type().dyn_cast<RankedTensorType>()) {
result_rank = element_type.getRank() + 1;
// If element type is ranked, then result type will have unknown leading
// dimension and element shape for the following dimensions.
@ -234,7 +230,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
// element_shape].
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
scalar_zero);
// Create a 1-D RankedTensorType for result's shape. Number of elements in
@ -243,7 +239,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
// specify dimension using rank of the result.
Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
auto list_shape = rewriter.create<TF::ConcatOp>(
op->getLoc(), shape_type, scalar_zero,
op.getLoc(), shape_type, scalar_zero,
ArrayRef<Value *>({leading_dim, element_shape}),
rewriter.getI64IntegerAttr(2));
@ -251,7 +247,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
// as specified by element_dtype.
auto zero_type = rewriter.getTensorType({}, element_dtype);
auto zero_attr = rewriter.getZeroAttr(zero_type);
auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);
auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr);
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
return matchSuccess();
@ -267,17 +263,17 @@ namespace {
} // namespace TFL
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
TF::WhileOp *while_op) {
TF::WhileOp op) {
SmallVector<Type, 8> unranked_argument_types;
for (const auto &operand : while_op->getOperands()) {
for (const auto &operand : op.getOperands()) {
unranked_argument_types.push_back(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
auto *context = &getContext();
auto module = getModule();
FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->cond());
FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->body());
FuncOp cond_func = module.lookupSymbol<FuncOp>(op.cond());
FuncOp body_func = module.lookupSymbol<FuncOp>(op.body());
if (cond_func) {
// Change `cond_func`'s argument types to `unranked_argument_types`.
@ -313,9 +309,9 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
}
}
for (int i = 0; i < while_op->getNumOperands(); ++i) {
auto operand = while_op->getOperand(i);
auto result = while_op->getResult(i);
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
auto result = op.getResult(i);
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
@ -357,7 +353,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
}
auto c = ConvertTFTensorListReserve(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
auto c = TFL::ConvertTFTensorListGetItem(context);
rewriter->setInsertionPoint(op);
@ -365,14 +361,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
auto c = ConvertTFTensorListSetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
auto c = TFL::ConvertTFTensorListStack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(&tf_op);
UpdateWhileFunctionType(tf_op);
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
tf_op.getResult()->setType(tf_op.getOperand()->getType());