NFC: Use DialectConversion framework for TFLite TensorList rewrite pass
Currently, it invokes individual rewrite patterns for each of the ops to support type conversions. DialectConversion framework supports this use-case and making use of it simplifies the pass. This pass still uses explicit type rewrite for While op functions because it involves type conversion depending on the context. This can not be done with type conversion in DialectConverter. For example, variant argument type can change to tensor<*xf32> or tensor<*xi32> depending on the caller. This can be improved on once we have shape inference support to rewrite functions with appropriate types. Specific changes: - Switch to ConversionPattern to implement individual op conversions - Define patterns in C++ for some of the ops currently defined in TableGen. Conversion patterns are not supported by TableGen - Define pattern for the While op instead of manually changing the result types - Define target for the pass and invoke DialectConversion PiperOrigin-RevId: 270103339
This commit is contained in:
parent
c0687c0e12
commit
39db8af832
@ -273,6 +273,7 @@ cc_library(
|
||||
"@local_config_mlir//:QuantOps",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:Support",
|
||||
"@local_config_mlir//:Transforms",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
@ -44,6 +45,7 @@ limitations under the License.
|
||||
#include "mlir/Support/Functional.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LLVM.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
|
||||
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h"
|
||||
@ -79,10 +81,6 @@ struct LowerStaticTensorListPass
|
||||
// Apply type and op changes within a function.
|
||||
LogicalResult RewriteFunction(FuncOp func,
|
||||
TensorListPatternRewriter *rewriter);
|
||||
|
||||
// Changes the function type of `cond_func` and `body_func`, and the result
|
||||
// type of the `WhileOp`.
|
||||
LogicalResult UpdateWhileFunctionType(TF::WhileOp op);
|
||||
};
|
||||
|
||||
Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter,
|
||||
@ -100,10 +98,11 @@ Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter,
|
||||
shape_tensor, scalar_val);
|
||||
}
|
||||
|
||||
struct ConvertTFTensorListSetItem
|
||||
: public OpRewritePattern<TF::TensorListSetItemOp> {
|
||||
explicit ConvertTFTensorListSetItem(MLIRContext *context)
|
||||
: OpRewritePattern<TF::TensorListSetItemOp>(context, 1) {}
|
||||
struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
explicit ConvertTensorListSetItem(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
// This function rewrites the original op into a series of slice and concat op
|
||||
// to produce the same result. It first slices the first `$index` rows. Then
|
||||
// expands the dimension of the `$item`, followed by another slice of the
|
||||
@ -116,13 +115,17 @@ struct ConvertTFTensorListSetItem
|
||||
// (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim =
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
|
||||
PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto input = op.input_handle();
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
|
||||
Value *input = operands[0];
|
||||
Value *index = operands[1];
|
||||
Value *item = operands[2];
|
||||
|
||||
auto shape_dtype = rewriter.getIntegerType(32);
|
||||
auto input_rank = rewriter.create<TF::RankOp>(
|
||||
op.getLoc(), rewriter.getTensorType({}, shape_dtype), input);
|
||||
auto item = op.item();
|
||||
auto item_rank = rewriter.create<TF::RankOp>(
|
||||
op.getLoc(), rewriter.getTensorType({}, shape_dtype), item);
|
||||
|
||||
@ -139,7 +142,6 @@ struct ConvertTFTensorListSetItem
|
||||
// Prepare the start position for the second slice op, which is
|
||||
// [index + 1, 0, 0 .. 0].
|
||||
// Calculate the first dimension, which is index + 1.
|
||||
auto index = op.index();
|
||||
auto vector_type = rewriter.getTensorType({1}, shape_dtype);
|
||||
auto begin = rewriter.create<TF::AddOp>(
|
||||
op.getLoc(), rewriter.getTensorType(shape_dtype), index,
|
||||
@ -191,7 +193,6 @@ struct ConvertTFTensorListSetItem
|
||||
op, input->getType(), scalar_zero,
|
||||
ArrayRef<Value *>({slice1, expanded_item, slice2}),
|
||||
rewriter.getI64IntegerAttr(3));
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
@ -200,20 +201,35 @@ struct ConvertTFTensorListSetItem
|
||||
// to generate an equivalent raw tensor. Derived classes are required to
|
||||
// override GetNumElements method.
|
||||
template <typename OpT>
|
||||
struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
|
||||
explicit ConvertTFTensorListInitOp(MLIRContext *context)
|
||||
: OpRewritePattern<OpT>(context, 1) {}
|
||||
struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
explicit ConvertTensorListInitOp(MLIRContext *context)
|
||||
: ConversionPattern(OpT::getOperationName(), 1, context) {}
|
||||
|
||||
// Create and return a 1-d tensor with exactly one element equal to the number
|
||||
// of list elements to initialize the output tensor list with.
|
||||
virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0;
|
||||
virtual Value *GetNumElements(OpT op, ArrayRef<Value *> operands,
|
||||
PatternRewriter *rewriter) const = 0;
|
||||
|
||||
// Rewrites the original op into `tf.fill`. The result tensor shape is
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(OpT op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto element_shape = op.element_shape();
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OpT op = llvm::cast<OpT>(operation);
|
||||
|
||||
Type dtype = op.element_dtype();
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
dtype.isInteger(32) || dtype.isInteger(64))) {
|
||||
op.emitError(
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
"transformation pass");
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
Value *element_shape = operands[0];
|
||||
auto shape_dtype = getElementTypeOrSelf(element_shape->getType());
|
||||
Type element_dtype = op.element_dtype();
|
||||
|
||||
@ -240,7 +256,7 @@ struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
|
||||
|
||||
// Add number of elements as the prefix to the element shape to get shape of
|
||||
// the output tensor.
|
||||
auto leading_dim = GetNumElements(op, &rewriter);
|
||||
auto leading_dim = GetNumElements(op, operands, &rewriter);
|
||||
auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0);
|
||||
auto list_shape = rewriter.create<TF::ConcatOp>(
|
||||
op.getLoc(), shape_type, scalar_zero,
|
||||
@ -258,47 +274,47 @@ struct ConvertTFTensorListInitOp : public OpRewritePattern<OpT> {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListReserve
|
||||
: public ConvertTFTensorListInitOp<TF::TensorListReserveOp> {
|
||||
explicit ConvertTFTensorListReserve(MLIRContext *context)
|
||||
: ConvertTFTensorListInitOp(context) {}
|
||||
struct ConvertTensorListReserve
|
||||
: public ConvertTensorListInitOp<TF::TensorListReserveOp> {
|
||||
explicit ConvertTensorListReserve(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::TensorListReserveOp op,
|
||||
Value *GetNumElements(TF::TensorListReserveOp op, ArrayRef<Value *> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0);
|
||||
auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType());
|
||||
Value *num_elements = operands[1];
|
||||
return rewriter->create<TF::ExpandDimsOp>(
|
||||
op.getLoc(), rewriter->getTensorType({1}, shape_dtype),
|
||||
op.num_elements(), scalar_zero);
|
||||
op.getLoc(), rewriter->getTensorType({1}, shape_dtype), num_elements,
|
||||
scalar_zero);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra
|
||||
// supports patterns involving variadic operand ops.
|
||||
//
|
||||
// Note that we ignore the second operand `max_num_elements` as we don't have
|
||||
// any restrictions on the number of elements we can support. So this may
|
||||
// have a different behavior compared to TensorFlow in case of errors.
|
||||
struct ConvertTFEmptyTensorList
|
||||
: public ConvertTFTensorListInitOp<TF::EmptyTensorListOp> {
|
||||
explicit ConvertTFEmptyTensorList(MLIRContext *context)
|
||||
: ConvertTFTensorListInitOp(context) {}
|
||||
struct ConvertEmptyTensorList
|
||||
: public ConvertTensorListInitOp<TF::EmptyTensorListOp> {
|
||||
explicit ConvertEmptyTensorList(MLIRContext *context)
|
||||
: ConvertTensorListInitOp(context) {}
|
||||
|
||||
Value *GetNumElements(TF::EmptyTensorListOp op,
|
||||
Value *GetNumElements(TF::EmptyTensorListOp op, ArrayRef<Value *> operands,
|
||||
PatternRewriter *rewriter) const override {
|
||||
return CreateI32SplatConst(op, rewriter, {1}, 0);
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTFTensorListPushBack : public RewritePattern {
|
||||
explicit ConvertTFTensorListPushBack(MLIRContext *context)
|
||||
: RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
explicit ConvertTensorListPushBack(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
|
||||
Value *item = push_back_op.tensor();
|
||||
Value *input_handle = operands[0];
|
||||
Value *item = operands[1];
|
||||
Type dtype = getElementTypeOrSelf(*item);
|
||||
|
||||
// Returns a new type by prepending the specified dimension to the shape of
|
||||
@ -334,149 +350,182 @@ struct ConvertTFTensorListPushBack : public RewritePattern {
|
||||
// Concatenate tensor stored in the input handle with the expanded item to
|
||||
// get a tensor equivalent to the TensorList generated by this op.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
op, result_type, scalar_zero,
|
||||
ArrayRef<Value *>({push_back_op.input_handle(), expanded_item}),
|
||||
push_back_op, result_type, scalar_zero,
|
||||
ArrayRef<Value *>({input_handle, expanded_item}),
|
||||
rewriter.getI64IntegerAttr(2));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
struct ConvertTensorListGetItem : public ConversionPattern {
|
||||
explicit ConvertTensorListGetItem(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
namespace TFL {
|
||||
namespace {
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
|
||||
} // namespace
|
||||
} // namespace TFL
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
|
||||
Value *input = operands[0];
|
||||
Value *index = operands[1];
|
||||
rewriter.replaceOpWithNewOp<TF::GatherOp>(
|
||||
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType(
|
||||
TF::WhileOp op) {
|
||||
struct ConvertTensorListLength : public ConversionPattern {
|
||||
explicit ConvertTensorListLength(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1,
|
||||
context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value *input_handle = operands[0];
|
||||
|
||||
BoolAttr true_attr = rewriter.getBoolAttr(true);
|
||||
auto shape = rewriter.create<TF::ShapeOp>(loc, input_handle,
|
||||
/*use_32bit=*/true_attr);
|
||||
rewriter.replaceOpWithNewOp<TF::GatherOp>(
|
||||
op, op.getType(), shape, CreateI32SplatConst(op, &rewriter, {}, 0),
|
||||
/*validate_indices=*/true_attr);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertIdentity : public ConversionPattern {
|
||||
explicit ConvertIdentity(MLIRContext *context)
|
||||
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::IdentityOp>(operation);
|
||||
Value *input = operands[0];
|
||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input->getType(), operands,
|
||||
op.getAttrs());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
// Changes the function type of `cond_func` and `body_func` for the given While
|
||||
// op.
|
||||
static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
SmallVector<Type, 8> unranked_argument_types;
|
||||
for (const auto &operand : op.getOperands()) {
|
||||
unranked_argument_types.push_back(
|
||||
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
|
||||
}
|
||||
|
||||
auto *context = &getContext();
|
||||
auto module = getModule();
|
||||
FuncOp cond_func = module.lookupSymbol<FuncOp>(op.cond());
|
||||
FuncOp body_func = module.lookupSymbol<FuncOp>(op.body());
|
||||
auto module = op.getParentOfType<ModuleOp>();
|
||||
auto *context = module.getContext();
|
||||
|
||||
if (cond_func) {
|
||||
// Change `cond_func`'s argument types to `unranked_argument_types`.
|
||||
cond_func.setType(FunctionType::get(
|
||||
unranked_argument_types, cond_func.getType().getResults(), context));
|
||||
// Change the argument type for the first block.
|
||||
Block &cond_first_bb = cond_func.front();
|
||||
for (int i = 0; i < cond_first_bb.getNumArguments(); ++i) {
|
||||
cond_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
|
||||
}
|
||||
}
|
||||
for (StringRef func_name : {op.cond(), op.body()}) {
|
||||
FuncOp func = module.lookupSymbol<FuncOp>(func_name);
|
||||
if (!func) continue;
|
||||
auto num_results = func.getType().getNumResults();
|
||||
|
||||
if (body_func) {
|
||||
SmallVector<Type, 8> updated_result_types;
|
||||
for (int i = 0; i < body_func.getType().getNumResults(); ++i) {
|
||||
auto result_type = body_func.getType().getResult(i);
|
||||
updated_result_types.reserve(num_results);
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
Type result_type = func.getType().getResult(i);
|
||||
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
|
||||
// For variant type, use the corresponding unranked type.
|
||||
updated_result_types.push_back(unranked_argument_types[i]);
|
||||
} else {
|
||||
updated_result_types.push_back(result_type);
|
||||
result_type = unranked_argument_types[i];
|
||||
}
|
||||
updated_result_types.push_back(result_type);
|
||||
}
|
||||
// Change `body_func`'s argument type to `unranked_argument_types`. If it
|
||||
|
||||
// Change `func`'s argument type to `unranked_argument_types`. If it
|
||||
// return types contain a `DT_VARIANT`, change it to the unranked type
|
||||
// derived from the corresponding argument.
|
||||
body_func.setType(FunctionType::get(unranked_argument_types,
|
||||
updated_result_types, context));
|
||||
func.setType(FunctionType::get(unranked_argument_types,
|
||||
updated_result_types, context));
|
||||
|
||||
// Change the argument type for the first block.
|
||||
Block &body_first_bb = body_func.front();
|
||||
Block &body_first_bb = func.front();
|
||||
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
|
||||
body_first_bb.getArgument(i)->setType(unranked_argument_types[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < op.getNumOperands(); ++i) {
|
||||
auto operand = op.getOperand(i);
|
||||
auto result = op.getResult(i);
|
||||
if (getElementTypeOrSelf(result->getType()).isa<TF::VariantType>()) {
|
||||
// If we notice the result type is a DT_VARIANT, we change the
|
||||
// corresponding result type to unranked tensor type.
|
||||
result->setType(
|
||||
UnrankedTensorType::get(getElementTypeOrSelf(operand->getType())));
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ConvertWhile : public ConversionPattern {
|
||||
explicit ConvertWhile(MLIRContext *context)
|
||||
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::WhileOp>(operation);
|
||||
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
result_types.reserve(op.getNumOperands());
|
||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||
Type result_ty = op.getResult(i)->getType();
|
||||
|
||||
// If we notice the result type is a DT_VARIANT, we change the
|
||||
// corresponding result type to unranked tensor type.
|
||||
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
|
||||
Type element_ty = getElementTypeOrSelf(operands[i]->getType());
|
||||
result_ty = UnrankedTensorType::get(element_ty);
|
||||
}
|
||||
result_types.push_back(result_ty);
|
||||
}
|
||||
|
||||
// Clone original while op with new operands and updated result types.
|
||||
auto cloned = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
|
||||
operands, op.getAttrs());
|
||||
cloned.removeAttr("T");
|
||||
UpdateFunctionTypes(cloned);
|
||||
|
||||
SmallVector<Value *, 8> results(cloned.getResults());
|
||||
rewriter.replaceOp(op, results);
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/generated_lower_static_tensor_list.inc"
|
||||
|
||||
LogicalResult LowerStaticTensorListPass::RewriteFunction(
|
||||
FuncOp func, TensorListPatternRewriter *rewriter) {
|
||||
auto *context = &getContext();
|
||||
|
||||
for (Block &block : func) {
|
||||
// Buffer the op pointers inside the current block into a vector, since
|
||||
// the block iterator might be invalidated if we rewrite ops during looping.
|
||||
std::vector<Operation *> ops_in_block;
|
||||
for (Operation &op : block) {
|
||||
ops_in_block.push_back(&op);
|
||||
}
|
||||
// TensorFlow operations that doesn't have operands and results of type
|
||||
// variant are legal. Here, we don't distinguish between variants encoding
|
||||
// TensorList or some other type as that information is not available here.
|
||||
// This constraint should be relaxed to support other variant types in TFLite.
|
||||
auto is_legal = [](Operation *op) {
|
||||
auto is_not_variant = [](Type ty) {
|
||||
return !ty.cast<ShapedType>().getElementType().isa<TF::VariantType>();
|
||||
};
|
||||
return llvm::all_of(op->getOperandTypes(), is_not_variant) &&
|
||||
llvm::all_of(op->getResultTypes(), is_not_variant);
|
||||
};
|
||||
|
||||
for (Operation *op : ops_in_block) {
|
||||
if (auto tf_op = llvm::dyn_cast<TF::TensorListFromTensorOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListFromTensor(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListReserveOp>(op)) {
|
||||
if (!(tf_op.element_dtype().isF16() || tf_op.element_dtype().isF32() ||
|
||||
tf_op.element_dtype().isF64() ||
|
||||
tf_op.element_dtype().isInteger(1) ||
|
||||
tf_op.element_dtype().isInteger(8) ||
|
||||
tf_op.element_dtype().isInteger(16) ||
|
||||
tf_op.element_dtype().isInteger(32) ||
|
||||
tf_op.element_dtype().isInteger(64))) {
|
||||
return tf_op.emitError(
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer "
|
||||
"or 16-bit/32-bit/64-bit "
|
||||
"float type during TF Lite transformation pass");
|
||||
}
|
||||
auto c = ConvertTFTensorListReserve(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::EmptyTensorListOp>(op)) {
|
||||
auto c = ConvertTFEmptyTensorList(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListGetItemOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListGetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListSetItemOp>(op)) {
|
||||
auto c = ConvertTFTensorListSetItem(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(tf_op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListStackOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListStack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListPushBackOp>(op)) {
|
||||
auto c = ConvertTFTensorListPushBack(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::TensorListLengthOp>(op)) {
|
||||
auto c = TFL::ConvertTFTensorListLength(context);
|
||||
rewriter->setInsertionPoint(op);
|
||||
c.matchAndRewrite(op, *rewriter);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::WhileOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
UpdateWhileFunctionType(tf_op);
|
||||
} else if (auto tf_op = llvm::dyn_cast<TF::IdentityOp>(op)) {
|
||||
if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context));
|
||||
tf_op.getResult()->setType(tf_op.getOperand()->getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
ConversionTarget target(*context);
|
||||
target.addDynamicallyLegalDialect<TF::TensorFlowDialect>(
|
||||
llvm::Optional<ConversionTarget::DynamicLegalityCallbackFn>(is_legal));
|
||||
target.addIllegalOp<TF::EmptyTensorListOp, TF::TensorListFromTensorOp,
|
||||
TF::TensorListGetItemOp, TF::TensorListLengthOp,
|
||||
TF::TensorListPushBackOp, TF::TensorListReserveOp,
|
||||
TF::TensorListSetItemOp, TF::TensorListStackOp>();
|
||||
// TODO(hinsu): Use TFLite constant op for constants.
|
||||
target.addLegalOp<ConstantOp>();
|
||||
target.addLegalOp<FuncOp>();
|
||||
target.addLegalOp<ReturnOp>();
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertEmptyTensorList, ConvertIdentity,
|
||||
ConvertTensorListFromTensor, ConvertTensorListGetItem,
|
||||
ConvertTensorListLength, ConvertTensorListPushBack,
|
||||
ConvertTensorListReserve, ConvertTensorListSetItem,
|
||||
ConvertTensorListStack, ConvertWhile>(context);
|
||||
return applyFullConversion(func, target, patterns);
|
||||
}
|
||||
|
||||
void LowerStaticTensorListPass::runOnModule() {
|
||||
@ -503,6 +552,8 @@ void LowerStaticTensorListPass::runOnModule() {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList
|
||||
/// pass.
|
||||
std::unique_ptr<OpPassBase<ModuleOp>> TFL::CreateLowerStaticTensorListPass() {
|
||||
|
@ -15,7 +15,6 @@ limitations under the License.
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/StandardOps/Ops.td"
|
||||
include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
|
||||
def CreateTFShapeOp : NativeCodeCall<
|
||||
@ -27,22 +26,10 @@ def CreateTFShapeOp : NativeCodeCall<
|
||||
// into regular tensors. We also assume that each element in the `TensorList` has
|
||||
// a same constant shape.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def ConvertTFTensorListFromTensor : Pat<
|
||||
def ConvertTensorListFromTensor : Pat<
|
||||
(TF_TensorListFromTensorOp $tensor, $element_shape),
|
||||
(replaceWithValue $tensor)>;
|
||||
|
||||
def ConvertTFTensorListStack : Pat<
|
||||
def ConvertTensorListStack : Pat<
|
||||
(TF_TensorListStackOp $input, $element_shape, $num_elements),
|
||||
(replaceWithValue $input)>;
|
||||
|
||||
def ConvertTFTensorListGetItem : Pat<
|
||||
(TF_TensorListGetItemOp $input, $index, $element_shape),
|
||||
(TF_GatherOp $input, $index, (NativeCodeCall<"$_builder.getBoolAttr(true)">))>;
|
||||
|
||||
// TensorListLength is equivalent to the size of the first dimension of the
|
||||
// input tensorlist, rewrite it to a combination of Gather and Shape op.
|
||||
def ConvertTFTensorListLength: Pat<
|
||||
(TF_TensorListLengthOp:$old_value $input),
|
||||
(TF_GatherOp
|
||||
(CreateTFShapeOp $old_value, $input, /*use 32bit*/ConstBoolAttrTrue),
|
||||
(ConstantOp ConstantAttr<I32ElementsAttr, "0">), ConstBoolAttrTrue)>;
|
||||
|
Loading…
Reference in New Issue
Block a user