NFC: Use DialectConversion framework for TFLite TensorList rewrite pass

Currently, it invokes individual rewrite patterns for each of the ops to support type conversions. DialectConversion framework supports this use-case and making use of it simplifies the pass.

This pass still uses explicit type rewrite for While op functions because it involves type conversion depending on the context. This can not be done with type conversion in DialectConverter. For example, variant argument type can change to tensor<*xf32> or tensor<*xi32> depending on the caller. This can be improved on once we have shape inference support to rewrite functions with appropriate types.

Specific changes:
- Switch to ConversionPattern to implement individual op conversions
- Define patterns in C++ for some of the ops currently defined in TableGen. Conversion
  patterns are not supported by TableGen
- Define pattern for the While op instead of manually changing the result types
- Define target for the pass and invoke DialectConversion

PiperOrigin-RevId: 270103339
This commit is contained in:
Smit Hinsu 2019-09-19 12:48:04 -07:00 committed by TensorFlower Gardener
parent c0687c0e12
commit 39db8af832
3 changed files with 206 additions and 167 deletions

View File

@ -273,6 +273,7 @@ cc_library(
"@local_config_mlir//:QuantOps", "@local_config_mlir//:QuantOps",
"@local_config_mlir//:StandardOps", "@local_config_mlir//:StandardOps",
"@local_config_mlir//:Support", "@local_config_mlir//:Support",
"@local_config_mlir//:Transforms",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <climits> #include <climits>
#include <cstdint> #include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
@ -44,6 +45,7 @@ limitations under the License.
#include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/Functional.h" // TF:local_config_mlir
#include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
@ -79,10 +81,6 @@ struct LowerStaticTensorListPass
// Apply type and op changes within a function. // Apply type and op changes within a function.
LogicalResult RewriteFunction(FuncOp func, LogicalResult RewriteFunction(FuncOp func,
TensorListPatternRewriter *rewriter); TensorListPatternRewriter *rewriter);
// Changes the function type of `cond_func` and `body_func`, and the result
// type of the `WhileOp`.
LogicalResult UpdateWhileFunctionType(TF::WhileOp op);
}; };
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter, Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
@ -100,10 +98,11 @@ Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
shape_tensor, scalar_val); shape_tensor, scalar_val);
} }
struct ConvertTFTensorListSetItem struct ConvertTensorListSetItem : public ConversionPattern {
: public OpRewritePattern<TF::TensorListSetItemOp> { explicit ConvertTensorListSetItem(MLIRContext *context)
explicit ConvertTFTensorListSetItem(MLIRContext *context) : ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1,
: OpRewritePattern<TF::TensorListSetItemOp>(context, 1) {} context) {}
// This function rewrites the original op into a series of slice and concat op // This function rewrites the original op into a series of slice and concat op
// to produce the same result. It first slices the first `$index` rows. Then // to produce the same result. It first slices the first `$index` rows. Then
// expands the dimension of the `$item`, followed by another slice of the // expands the dimension of the `$item`, followed by another slice of the
@ -116,13 +115,17 @@ struct ConvertTFTensorListSetItem
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim = // (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op, PatternMatchResult matchAndRewrite(
PatternRewriter &rewriter) const override { Operation *operation, ArrayRef<Value *> operands,
auto input = op.input_handle(); ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
Value *input = operands[0];
Value *index = operands[1];
Value *item = operands[2];
auto shape_dtype = rewriter.getIntegerType(32); auto shape_dtype = rewriter.getIntegerType(32);
auto input_rank = rewriter.create<TF::RankOp>( auto input_rank = rewriter.create<TF::RankOp>(
op.getLoc(), rewriter.getTensorType({}, shape_dtype), input); op.getLoc(), rewriter.getTensorType({}, shape_dtype), input);
auto item = op.item();
auto item_rank = rewriter.create<TF::RankOp>( auto item_rank = rewriter.create<TF::RankOp>(
op.getLoc(), rewriter.getTensorType({}, shape_dtype), item); op.getLoc(), rewriter.getTensorType({}, shape_dtype), item);
@ -139,7 +142,6 @@ struct ConvertTFTensorListSetItem
// Prepare the start position for the second slice op, which is // Prepare the start position for the second slice op, which is
// [index + 1, 0, 0 .. 0]. // [index + 1, 0, 0 .. 0].
// Calculate the first dimension, which is index + 1. // Calculate the first dimension, which is index + 1.
auto index = op.index();
auto vector_type = rewriter.getTensorType({1}, shape_dtype); auto vector_type = rewriter.getTensorType({1}, shape_dtype);
auto begin = rewriter.create<TF::AddOp>( auto begin = rewriter.create<TF::AddOp>(
op.getLoc(), rewriter.getTensorType(shape_dtype), index, op.getLoc(), rewriter.getTensorType(shape_dtype), index,
@ -191,7 +193,6 @@ struct ConvertTFTensorListSetItem
op, input->getType(), scalar_zero, op, input->getType(), scalar_zero,
ArrayRef<Value *>({slice1, expanded_item, slice2}), ArrayRef<Value *>({slice1, expanded_item, slice2}),
rewriter.getI64IntegerAttr(3)); rewriter.getI64IntegerAttr(3));
return matchSuccess(); return matchSuccess();
} }
}; };
@ -200,20 +201,35 @@ struct ConvertTFTensorListSetItem
// to generate an equivalent raw tensor. Derived classes are required to // to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method. // override GetNumElements method.
template <typename OpT> template <typename OpT>
struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> { struct ConvertTensorListInitOp : public ConversionPattern {
explicit ConvertTFTensorListInitOp(MLIRContext *context) explicit ConvertTensorListInitOp(MLIRContext *context)
: OpRewritePattern<OpT>(context, 1) {} : ConversionPattern(OpT::getOperationName(), 1, context) {}
// Create and return a 1-d tensor with exactly one element equal to the number // Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with. // of list elements to initialize the output tensor list with.
virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0; virtual Value *GetNumElements(OpT op, ArrayRef<Value *> operands,
PatternRewriter *rewriter) const = 0;
// Rewrites the original op into `tf.fill`. The result tensor shape is // Rewrites the original op into `tf.fill`. The result tensor shape is
// [num_element, element_shape]. All the values in the result tensor will be // [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0. // initialized to 0.
PatternMatchResult matchAndRewrite(OpT op, PatternMatchResult matchAndRewrite(
PatternRewriter &rewriter) const override { Operation *operation, ArrayRef<Value *> operands,
auto element_shape = op.element_shape(); ConversionPatternRewriter &rewriter) const override {
OpT op = llvm::cast<OpT>(operation);
Type dtype = op.element_dtype();
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
dtype.isInteger(32) || dtype.isInteger(64))) {
op.emitError(
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
return matchFailure();
}
Value *element_shape = operands[0];
auto shape_dtype = getElementTypeOrSelf(element_shape->getType()); auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
Type element_dtype = op.element_dtype(); Type element_dtype = op.element_dtype();
@ -240,7 +256,7 @@ struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
// Add number of elements as the prefix to the element shape to get shape of // Add number of elements as the prefix to the element shape to get shape of
// the output tensor. // the output tensor.
auto leading_dim = GetNumElements(op, &rewriter); auto leading_dim = GetNumElements(op, operands, &rewriter);
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
auto list_shape = rewriter.create<TF::ConcatOp>( auto list_shape = rewriter.create<TF::ConcatOp>(
op.getLoc(), shape_type, scalar_zero, op.getLoc(), shape_type, scalar_zero,
@ -258,47 +274,47 @@ struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
} }
}; };
struct ConvertTFTensorListReserve struct ConvertTensorListReserve
: public ConvertTFTensorListInitOp<TF::TensorListReserveOp> { : public ConvertTensorListInitOp<TF::TensorListReserveOp> {
explicit ConvertTFTensorListReserve(MLIRContext *context) explicit ConvertTensorListReserve(MLIRContext *context)
: ConvertTFTensorListInitOp(context) {} : ConvertTensorListInitOp(context) {}
Value *GetNumElements(TF::TensorListReserveOp op, Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value *> operands,
PatternRewriter *rewriter) const override { PatternRewriter *rewriter) const override {
auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0); auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0);
auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType()); auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
Value *num_elements = operands[1];
return rewriter->create<TF::ExpandDimsOp>( return rewriter->create<TF::ExpandDimsOp>(
op.getLoc(), rewriter->getTensorType({1}, shape_dtype), op.getLoc(), rewriter->getTensorType({1}, shape_dtype), num_elements,
op.num_elements(), scalar_zero); scalar_zero);
} }
}; };
// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
// supports patterns involving variadic operand ops.
//
// Note that we ignore the second operand `max_num_elements` as we don't have // Note that we ignore the second operand `max_num_elements` as we don't have
// any restrictions on the number of elements we can support. So this may // any restrictions on the number of elements we can support. So this may
// have a different behavior compared to TensorFlow in case of errors. // have a different behavior compared to TensorFlow in case of errors.
struct ConvertTFEmptyTensorList struct ConvertEmptyTensorList
: public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> { : public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
explicit ConvertTFEmptyTensorList(MLIRContext *context) explicit ConvertEmptyTensorList(MLIRContext *context)
: ConvertTFTensorListInitOp(context) {} : ConvertTensorListInitOp(context) {}
Value *GetNumElements(TF::EmptyTensorListOp op, Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value *> operands,
PatternRewriter *rewriter) const override { PatternRewriter *rewriter) const override {
return CreateI32SplatConst(op, rewriter, {1}, 0); return CreateI32SplatConst(op, rewriter, {1}, 0);
} }
}; };
struct ConvertTFTensorListPushBack : public RewritePattern { struct ConvertTensorListPushBack : public ConversionPattern {
explicit ConvertTFTensorListPushBack(MLIRContext *context) explicit ConvertTensorListPushBack(MLIRContext *context)
: RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1, : ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1,
context) {} context) {}
PatternMatchResult matchAndRewrite(Operation *op, PatternMatchResult matchAndRewrite(
PatternRewriter &rewriter) const override { Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op); TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
Value *item = push_back_op.tensor(); Value *input_handle = operands[0];
Value *item = operands[1];
Type dtype = getElementTypeOrSelf(*item); Type dtype = getElementTypeOrSelf(*item);
// Returns a new type by prepending the specified dimension to the shape of // Returns a new type by prepending the specified dimension to the shape of
@ -334,149 +350,182 @@ struct ConvertTFTensorListPushBack : public RewritePattern {
// Concatenate tensor stored in the input handle with the expanded item to // Concatenate tensor stored in the input handle with the expanded item to
// get a tensor equivalent to the TensorList generated by this op. // get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>( rewriter.replaceOpWithNewOp<TF::ConcatOp>(
op, result_type, scalar_zero, push_back_op, result_type, scalar_zero,
ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}), ArrayRef<Value *>({input_handle, expanded_item}),
rewriter.getI64IntegerAttr(2)); rewriter.getI64IntegerAttr(2));
return matchSuccess(); return matchSuccess();
} }
}; };
} // namespace struct ConvertTensorListGetItem : public ConversionPattern {
explicit ConvertTensorListGetItem(MLIRContext *context)
: ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1,
context) {}
namespace TFL { PatternMatchResult matchAndRewrite(
namespace { Operation *operation, ArrayRef<Value *> operands,
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc" ConversionPatternRewriter &rewriter) const override {
} // namespace auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
} // namespace TFL Value *input = operands[0];
Value *index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
return matchSuccess();
}
};
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType( struct ConvertTensorListLength : public ConversionPattern {
TF::WhileOp op) { explicit ConvertTensorListLength(MLIRContext *context)
: ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1,
context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
Location loc = op.getLoc();
Value *input_handle = operands[0];
BoolAttr true_attr = rewriter.getBoolAttr(true);
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
/*use_32bit=*/true_attr);
rewriter.replaceOpWithNewOp<TF::GatherOp>(
op, op.getType(), shape, CreateI32SplatConst(op, &rewriter, {}, 0),
/*validate_indices=*/true_attr);
return matchSuccess();
}
};
struct ConvertIdentity : public ConversionPattern {
explicit ConvertIdentity(MLIRContext *context)
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
Value *input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
op.getAttrs());
return matchSuccess();
}
};
// Changes the function type of `cond_func` and `body_func` for the given While
// op.
static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
SmallVector<Type, 8> unranked_argument_types; SmallVector<Type, 8> unranked_argument_types;
for (const auto &operand : op.getOperands()) { for (const auto &operand : op.getOperands()) {
unranked_argument_types.push_back( unranked_argument_types.push_back(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType()))); UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
} }
auto *context = &getContext(); auto module = op.getParentOfType<ModuleOp>();
auto module = getModule(); auto *context = module.getContext();
FuncOp cond_func = module.lookupSymbol<FuncOp>(op.cond());
FuncOp body_func = module.lookupSymbol<FuncOp>(op.body());
if (cond_func) { for (StringRef func_name : {op.cond(), op.body()}) {
// Change `cond_func`'s argument types to `unranked_argument_types`. FuncOp func = module.lookupSymbol<FuncOp>(func_name);
cond_func.setType(FunctionType::get( if (!func) continue;
unranked_argument_types, cond_func.getType().getResults(), context)); auto num_results = func.getType().getNumResults();
// Change the argument type for the first block.
Block &cond_first_bb = cond_func.front();
for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
}
}
if (body_func) {
SmallVector<Type, 8> updated_result_types; SmallVector<Type, 8> updated_result_types;
for (int i = 0; i < body_func.getType().getNumResults(); ++i) { updated_result_types.reserve(num_results);
auto result_type = body_func.getType().getResult(i); for (int i = 0; i < num_results; ++i) {
Type result_type = func.getType().getResult(i);
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) { if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
// For variant type, use the corresponding unranked type. // For variant type, use the corresponding unranked type.
updated_result_types.push_back(unranked_argument_types[i]); result_type = unranked_argument_types[i];
} else {
updated_result_types.push_back(result_type);
} }
updated_result_types.push_back(result_type);
} }
// Change `body_func`'s argument type to `unranked_argument_types`. If it
// Change `func`'s argument type to `unranked_argument_types`. If it
// return types contain a `DT_VARIANT`, change it to the unranked type // return types contain a `DT_VARIANT`, change it to the unranked type
// derived from the corresponding argument. // derived from the corresponding argument.
body_func.setType(FunctionType::get(unranked_argument_types, func.setType(FunctionType::get(unranked_argument_types,
updated_result_types, context)); updated_result_types, context));
// Change the argument type for the first block. // Change the argument type for the first block.
Block &body_first_bb = body_func.front(); Block &body_first_bb = func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) { for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i)->setType(unranked_argument_types[i]); body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
} }
} }
for (int i = 0; i < op.getNumOperands(); ++i) {
auto operand = op.getOperand(i);
auto result = op.getResult(i);
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
result->setType(
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
}
}
return success(); return success();
} }
struct ConvertWhile : public ConversionPattern {
explicit ConvertWhile(MLIRContext *context)
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::WhileOp>(operation);
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) {
Type result_ty = op.getResult(i)->getType();
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
result_ty = UnrankedTensorType::get(element_ty);
}
result_types.push_back(result_ty);
}
// Clone original while op with new operands and updated result types.
auto cloned = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
cloned.removeAttr("T");
UpdateFunctionTypes(cloned);
SmallVector<Value *, 8> results(cloned.getResults());
rewriter.replaceOp(op, results);
return matchSuccess();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
LogicalResult LowerStaticTensorListPass::RewriteFunction( LogicalResult LowerStaticTensorListPass::RewriteFunction(
FuncOp func, TensorListPatternRewriter *rewriter) { FuncOp func, TensorListPatternRewriter *rewriter) {
auto *context = &getContext(); auto *context = &getContext();
for (Block &block : func) { // TensorFlow operations that doesn't have operands and results of type
// Buffer the op pointers inside the current block into a vector, since // variant are legal. Here, we don't distinguish between variants encoding
// the block iterator might be invalidated if we rewrite ops during looping. // TensorList or some other type as that information is not available here.
std::vector<Operation *> ops_in_block; // This constraint should be relaxed to support other variant types in TFLite.
for (Operation &op : block) { auto is_legal = [](Operation *op) {
ops_in_block.push_back(&op); auto is_not_variant = [](Type ty) {
} return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
};
return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
llvm::all_of(op->getResultTypes(), is_not_variant);
};
for (Operation *op : ops_in_block) { ConversionTarget target(*context);
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) { target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(
auto c = TFL::ConvertTFTensorListFromTensor(context); llvm::Optional<ConversionTarget::DynamicLegalityCallbackFn>(is_legal));
rewriter->setInsertionPoint(op); target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
c.matchAndRewrite(op, *rewriter); TF::TensorListGetItemOp, TF::TensorListLengthOp,
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) { TF::TensorListPushBackOp, TF::TensorListReserveOp,
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() || TF::TensorListSetItemOp, TF::TensorListStackOp>();
tf_op.element_dtype().isF64() || // TODO(hinsu): Use TFLite constant op for constants.
tf_op.element_dtype().isInteger(1) || target.addLegalOp<ConstantOp>();
tf_op.element_dtype().isInteger(8) || target.addLegalOp<FuncOp>();
tf_op.element_dtype().isInteger(16) || target.addLegalOp<ReturnOp>();
tf_op.element_dtype().isInteger(32) ||
tf_op.element_dtype().isInteger(64))) { OwningRewritePatternList patterns;
return tf_op.emitError( patterns.insert<ConvertEmptyTensorList, ConvertIdentity,
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " ConvertTensorListFromTensor, ConvertTensorListGetItem,
"integer " ConvertTensorListLength, ConvertTensorListPushBack,
"or 16-bit/32-bit/64-bit " ConvertTensorListReserve, ConvertTensorListSetItem,
"float type during TF Lite transformation pass"); ConvertTensorListStack, ConvertWhile>(context);
} return applyFullConversion(func, target, patterns);
auto c = ConvertTFTensorListReserve(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) {
auto c = ConvertTFEmptyTensorList(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
auto c = TFL::ConvertTFTensorListGetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
auto c = ConvertTFTensorListSetItem(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(tf_op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
auto c = TFL::ConvertTFTensorListStack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) {
auto c = ConvertTFTensorListPushBack(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
auto c = TFL::ConvertTFTensorListLength(context);
rewriter->setInsertionPoint(op);
c.matchAndRewrite(op, *rewriter);
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
UpdateWhileFunctionType(tf_op);
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
tf_op.getResult()->setType(tf_op.getOperand()->getType());
}
}
}
return success();
} }
void LowerStaticTensorListPass::runOnModule() { void LowerStaticTensorListPass::runOnModule() {
@ -503,6 +552,8 @@ void LowerStaticTensorListPass::runOnModule() {
} }
} }
} // namespace
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
/// pass. /// pass.
std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() { std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {

View File

@ -15,7 +15,6 @@ limitations under the License.
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/Dialect/StandardOps/Ops.td" include "mlir/Dialect/StandardOps/Ops.td"
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
def CreateTFShapeOp : NativeCodeCall< def CreateTFShapeOp : NativeCodeCall<
@ -27,22 +26,10 @@ def CreateTFShapeOp : NativeCodeCall<
// into regular tensors. We also assume that each element in the `TensorList` has // into regular tensors. We also assume that each element in the `TensorList` has
// a same constant shape. // a same constant shape.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def ConvertTFTensorListFromTensor : Pat< def ConvertTensorListFromTensor : Pat<
(TF_TensorListFromTensorOp $tensor, $element_shape), (TF_TensorListFromTensorOp $tensor, $element_shape),
(replaceWithValue $tensor)>; (replaceWithValue $tensor)>;
def ConvertTFTensorListStack : Pat< def ConvertTensorListStack : Pat<
(TF_TensorListStackOp $input, $element_shape, $num_elements), (TF_TensorListStackOp $input, $element_shape, $num_elements),
(replaceWithValue $input)>; (replaceWithValue $input)>;
def ConvertTFTensorListGetItem : Pat<
(TF_TensorListGetItemOp $input, $index, $element_shape),
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
// TensorListLength is equivalent to the size of the first dimension of the
// input tensorlist, rewrite it to a combination of Gather and Shape op.
def ConvertTFTensorListLength: Pat<
(TF_TensorListLengthOp:$old_value $input),
(TF_GatherOp
(CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
(ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;