NFC: Improve tfl-lower-static-tensor-list pass and tests
Specifically, * Use better syntax in tests by naming arguments in function signatures and not having any explicit block * Pass WhileOp by value * Use OpRewritePattern instead of RewritePattern for the patterns PiperOrigin-RevId: 261007552
This commit is contained in:
parent
58dfc47ef1
commit
6b31ac341a
@ -1,6 +1,6 @@
|
||||
// RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure
|
||||
func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
|
||||
|
||||
func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<10xf32>, tensor<3x10xf32>) {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<10xf32>
|
||||
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
|
||||
@ -11,8 +11,7 @@ func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>) -> (tensor
|
||||
// CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32>
|
||||
}
|
||||
|
||||
func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>):
|
||||
func @tensorlistGetItemWithUnknownRank(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<*xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<1xi32>) -> tensor<*xf32>
|
||||
%2 = "tf.TensorListStack"(%0, %arg1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<1xi32>) -> tensor<*xf32>
|
||||
@ -23,8 +22,7 @@ func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor<i32>
|
||||
// CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10xf32>) -> tensor<3x10xf32> {
|
||||
^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>):
|
||||
func @tensorlistSetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>, %arg3: tensor<10xf32>) -> tensor<3x10xf32> {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>, tensor<10xf32>) -> tensor<!tf.variant<tensor<10xf32>>>
|
||||
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<1xi32>) -> tensor<3x10xf32>
|
||||
@ -56,8 +54,7 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor<i32>, tensor<10x
|
||||
// CHECK: return %15 : tensor<3x10xf32>
|
||||
}
|
||||
|
||||
func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i32>, tensor<f32>) -> tensor<5xf32> {
|
||||
^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||
func @tensorlistSetItemWithScalarElements(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor<i32>, %arg3: tensor<f32>) -> tensor<5xf32> {
|
||||
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
%1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor<!tf.variant<tensor<f32>>>, tensor<i32>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
|
||||
%2 = "tf.TensorListStack"(%1, %arg1) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>) -> tensor<5xf32>
|
||||
@ -89,8 +86,7 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor<i
|
||||
// CHECK: return %15 : tensor<5xf32>
|
||||
}
|
||||
|
||||
func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
^bb0(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
|
||||
func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<?x?x?xf32>>>, tensor<i32>, tensor<3xi32>) -> tensor<?x?x?xf32>
|
||||
return %1 : tensor<?x?x?xf32>
|
||||
@ -105,8 +101,7 @@ func @tensorlistReserve(tensor<3xi32>, tensor<i32>, tensor<i32>) -> tensor<?x?x?
|
||||
// CHECK: return %3 : tensor<?x?x?xf32>
|
||||
}
|
||||
|
||||
func @tensorlistReserveUnrankedElements(tensor<?xi32>, tensor<i32>, tensor<i32>) -> tensor<*xf32> {
|
||||
^bb0(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>):
|
||||
func @tensorlistReserveUnrankedElements(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<?xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<*xf32>>>
|
||||
%1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>, tensor<?xi32>) -> tensor<*xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
@ -117,8 +112,7 @@ func @tensorlistReserveUnrankedElements(tensor<?xi32>, tensor<i32>, tensor<i32>)
|
||||
// CHECK: return [[RESULT2]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
|
||||
^bb0(%arg0: tensor<2x3xf32>):
|
||||
func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant dense<3> : tensor<1xi32>
|
||||
%cst_0 = constant dense<0> : tensor<i32>
|
||||
%cst_1 = constant dense<-1> : tensor<i32>
|
||||
@ -136,8 +130,7 @@ func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> {
|
||||
// CHECK: return %0#1 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileBody(tensor<*xi32>, tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
|
||||
func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> (tensor<*xi32>, tensor<!tf.variant>) {
|
||||
%cst = constant dense<1> : tensor<i32>
|
||||
%0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
|
||||
%1 = "tf.Identity"(%arg1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
|
||||
@ -151,8 +144,7 @@ func @tensorlistWhileBody(tensor<*xi32>, tensor<!tf.variant>) -> (tensor<*xi32>,
|
||||
// CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tensorlistWhileCond(tensor<*xi32>, tensor<!tf.variant>) -> tensor<*xi1> {
|
||||
^bb0(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>):
|
||||
func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor<!tf.variant>) -> tensor<*xi1> {
|
||||
%cst = constant dense<2> : tensor<i32>
|
||||
%0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
|
||||
return %0 : tensor<*xi1>
|
||||
|
@ -82,7 +82,7 @@ struct LowerStaticTensorListPass
|
||||
|
||||
// Changes the function type of `cond_func` and `body_func`, and the result
|
||||
// type of the `WhileOp`.
|
||||
LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op);
|
||||
LogicalResult UpdateWhileFunctionType(TF::WhileOp op);
|
||||
};
|
||||
|
||||
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
|
||||
@ -100,10 +100,10 @@ Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
|
||||
shape_tensor, scalar_val);
|
||||
}
|
||||
|
||||
struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
struct ConvertTFTensorListSetItem
|
||||
: public OpRewritePattern<TF::TensorListSetItemOp> {
|
||||
explicit ConvertTFTensorListSetItem(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: OpRewritePattern<TF::TensorListSetItemOp>(context, 1) {}
|
||||
// This function rewrites the original op into a series of slice and concat op
|
||||
// to produce the same result. It first slices the first `$index` rows. Then
|
||||
// expands the dimension of the `$item`, followed by another slice of the
|
||||
@ -116,23 +116,21 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::TensorListSetItemOp tf_op = cast<TF::TensorListSetItemOp>(op);
|
||||
|
||||
auto input = tf_op.input_handle();
|
||||
auto input = op.input_handle();
|
||||
auto shape_dtype = rewriter.getIntegerType(32);
|
||||
auto input_rank = rewriter.create<TF::RankOp>(
|
||||
op->getLoc(), rewriter.getTensorType({}, shape_dtype), input);
|
||||
auto item = tf_op.item();
|
||||
op.getLoc(), rewriter.getTensorType({}, shape_dtype), input);
|
||||
auto item = op.item();
|
||||
auto item_rank = rewriter.create<TF::RankOp>(
|
||||
op->getLoc(), rewriter.getTensorType({}, shape_dtype), item);
|
||||
op.getLoc(), rewriter.getTensorType({}, shape_dtype), item);
|
||||
|
||||
// Prepare the start position for the first slice op, which is [0, 0, ..,
|
||||
// 0].
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
|
||||
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank,
|
||||
scalar_zero);
|
||||
// Fill all 0s into the first position tensor.
|
||||
auto first_start_position =
|
||||
@ -141,33 +139,33 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
// Prepare the start position for the second slice op, which is
|
||||
// [index + 1, 0, 0 .. 0].
|
||||
// Calculate the first dimension, which is index + 1.
|
||||
auto index = tf_op.index();
|
||||
auto index = op.index();
|
||||
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
|
||||
auto begin = rewriter.create<TF::AddOp>(
|
||||
op->getLoc(), rewriter.getTensorType(shape_dtype), index,
|
||||
op.getLoc(), rewriter.getTensorType(shape_dtype), index,
|
||||
CreateI32SplatConst(op, &rewriter, {1}, 1));
|
||||
|
||||
// Followed by the first dimension `begin`, are `item_rank` of 0s.
|
||||
auto item_position_shape = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
|
||||
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank,
|
||||
scalar_zero);
|
||||
auto partial_second_start_position =
|
||||
CreateI32SplatTensor(op, &rewriter, item_position_shape, 0);
|
||||
auto position_type = first_start_position->getType();
|
||||
// Concatenate `begin` with the remaining 0s.
|
||||
auto second_start_position = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), position_type, scalar_zero,
|
||||
op.getLoc(), position_type, scalar_zero,
|
||||
ArrayRef<Value *>({begin, partial_second_start_position}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
// Create the size parameter for the first slice op, which is [index, -1,
|
||||
// -1, .., -1].
|
||||
auto size1_leading_dim = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), vector_type, index, scalar_zero);
|
||||
op.getLoc(), vector_type, index, scalar_zero);
|
||||
auto partial_size1 =
|
||||
CreateI32SplatTensor(op, &rewriter, item_position_shape, -1);
|
||||
auto size1 = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), position_type, scalar_zero,
|
||||
op.getLoc(), position_type, scalar_zero,
|
||||
ArrayRef<Value *>({size1_leading_dim, partial_size1}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
@ -179,14 +177,14 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
auto element_type = input->getType().cast<TensorType>().getElementType();
|
||||
auto unranked_tensor = rewriter.getTensorType(element_type);
|
||||
auto slice1 = rewriter.create<TF::SliceOp>(
|
||||
op->getLoc(), unranked_tensor, input, first_start_position, size1);
|
||||
op.getLoc(), unranked_tensor, input, first_start_position, size1);
|
||||
auto slice2 = rewriter.create<TF::SliceOp>(
|
||||
op->getLoc(), unranked_tensor, input, second_start_position, size2);
|
||||
op.getLoc(), unranked_tensor, input, second_start_position, size2);
|
||||
|
||||
// Expand the dimension of item so that it will have the same rank with
|
||||
// input.
|
||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), unranked_tensor, item, scalar_zero);
|
||||
op.getLoc(), unranked_tensor, item, scalar_zero);
|
||||
|
||||
// Concatenate three parts together to generate the final result.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
@ -198,26 +196,24 @@ struct ConvertTFTensorListSetItem : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
struct ConvertTFTensorListReserve
|
||||
: public OpRewritePattern<TF::TensorListReserveOp> {
|
||||
explicit ConvertTFTensorListReserve(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListReserveOp::getOperationName(), 1,
|
||||
context) {}
|
||||
: OpRewritePattern<TF::TensorListReserveOp>(context, 1) {}
|
||||
|
||||
// Rewrites the original op into `tf.fill`. The result tensor shape is
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternMatchResult matchAndRewrite(TF::TensorListReserveOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
TF::TensorListReserveOp tf_op = cast<TF::TensorListReserveOp>(op);
|
||||
|
||||
auto element_shape = tf_op.element_shape();
|
||||
auto element_shape = op.element_shape();
|
||||
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
|
||||
auto num_elements = tf_op.num_elements();
|
||||
Type element_dtype = tf_op.element_dtype();
|
||||
auto num_elements = op.num_elements();
|
||||
Type element_dtype = op.element_dtype();
|
||||
|
||||
int64_t result_rank = -1; // -1 means unknown result rank.
|
||||
Type result_type = rewriter.getTensorType(element_dtype);
|
||||
if (auto element_type = tf_op.element_type().dyn_cast<RankedTensorType>()) {
|
||||
if (auto element_type = op.element_type().dyn_cast<RankedTensorType>()) {
|
||||
result_rank = element_type.getRank() + 1;
|
||||
// If element type is ranked, then result type will have unknown leading
|
||||
// dimension and element shape for the following dimensions.
|
||||
@ -234,7 +230,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
// element_shape].
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto leading_dim = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
|
||||
op.getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements,
|
||||
scalar_zero);
|
||||
|
||||
// Create a 1-D RankedTensorType for result's shape. Number of elements in
|
||||
@ -243,7 +239,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
// specify dimension using rank of the result.
|
||||
Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype);
|
||||
auto list_shape = rewriter.create<TF::ConcatOp>(
|
||||
op->getLoc(), shape_type, scalar_zero,
|
||||
op.getLoc(), shape_type, scalar_zero,
|
||||
ArrayRef<Value *>({leading_dim, element_shape}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
|
||||
@ -251,7 +247,7 @@ struct ConvertTFTensorListReserve : public RewritePattern {
|
||||
// as specified by element_dtype.
|
||||
auto zero_type = rewriter.getTensorType({}, element_dtype);
|
||||
auto zero_attr = rewriter.getZeroAttr(zero_type);
|
||||
auto zero = rewriter.create<ConstantOp>(op->getLoc(), zero_type, zero_attr);
|
||||
auto zero = rewriter.create<ConstantOp>(op.getLoc(), zero_type, zero_attr);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TF::FillOp>(op, result_type, list_shape, zero);
|
||||
return matchSuccess();
|
||||
@ -267,17 +263,17 @@ namespace {
|
||||
} // namespace TFL
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
|
||||
TF::WhileOp *while_op) {
|
||||
TF::WhileOp op) {
|
||||
SmallVector<Type, 8> unranked_argument_types;
|
||||
for (const auto &operand : while_op->getOperands()) {
|
||||
for (const auto &operand : op.getOperands()) {
|
||||
unranked_argument_types.push_back(
|
||||
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
|
||||
}
|
||||
|
||||
auto *context = &getContext();
|
||||
auto module = getModule();
|
||||
FuncOp cond_func = module.lookupSymbol<FuncOp>(while_op->cond());
|
||||
FuncOp body_func = module.lookupSymbol<FuncOp>(while_op->body());
|
||||
FuncOp cond_func = module.lookupSymbol<FuncOp>(op.cond());
|
||||
FuncOp body_func = module.lookupSymbol<FuncOp>(op.body());
|
||||
|
||||
if (cond_func) {
|
||||
// Change `cond_func`'s argument types to `unranked_argument_types`.
|
||||
@ -313,9 +309,9 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < while_op->getNumOperands(); ++i) {
|
||||
auto operand = while_op->getOperand(i);
|
||||
auto result = while_op->getResult(i);
|
||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto operand = op.getOperand(i);
|
||||
auto result = op.getResult(i);
|
||||
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
|
||||
// If we notice the result type is a DT_VARIANT, we change the
|
||||
// corresponding result type to unranked tensor type.
|
||||
@ -357,7 +353,7 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
}
|
||||
auto c = ConvertTFTensorListReserve(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListGetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
@ -365,14 +361,14 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
|
||||
auto c = ConvertTFTensorListSetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListStack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
UpdateWhileFunctionType(&tf_op);
|
||||
UpdateWhileFunctionType(tf_op);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
tf_op.getResult()->setType(tf_op.getOperand()->getType());
|
||||
|
Loading…
x
Reference in New Issue
Block a user