NFC: Derive tfl-lower-static-tensor-list pass patterns from OpConversionPattern

PiperOrigin-RevId: 289976784
Change-Id: I2e062774d1bba9df3e603a7585d614950d8ecfbd
This commit is contained in:
Smit Hinsu 2020-01-15 18:06:31 -08:00 committed by TensorFlower Gardener
parent fc7e43de1f
commit 52b8ba5463

View File

@ -162,10 +162,9 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size);
}
struct ConvertTensorListSetItem : public ConversionPattern {
explicit ConvertTensorListSetItem(MLIRContext *context)
: ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1,
context) {}
struct ConvertTensorListSetItem
: public OpConversionPattern<TF::TensorListSetItemOp> {
using OpConversionPattern::OpConversionPattern;
// This function rewrites the original op into a series of slice and concat op
// to produce the same result. It first slices the first `$index` rows. Then
@ -180,9 +179,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::TensorListSetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
Location loc = op.getLoc();
Value input = operands[0];
Value index = operands[1];
@ -235,9 +233,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method.
template <typename OpT>
struct ConvertTensorListInitOp : public ConversionPattern {
explicit ConvertTensorListInitOp(MLIRContext *context)
: ConversionPattern(OpT::getOperationName(), 1, context) {}
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
using OpConversionPattern<OpT>::OpConversionPattern;
// Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with.
@ -248,10 +245,8 @@ struct ConvertTensorListInitOp : public ConversionPattern {
// [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0.
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
OpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
OpT op = llvm::cast<OpT>(operation);
Type dtype = op.element_dtype();
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
@ -260,7 +255,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass");
return matchFailure();
return ConversionPattern::matchFailure();
}
Value element_shape = operands[0];
@ -376,15 +371,13 @@ struct ConvertEmptyTensorList
}
};
struct ConvertTensorListPushBack : public ConversionPattern {
explicit ConvertTensorListPushBack(MLIRContext *context)
: ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1,
context) {}
struct ConvertTensorListPushBack
: public OpConversionPattern<TF::TensorListPushBackOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
TF::TensorListPushBackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
Value input_handle = operands[0];
Value item = operands[1];
@ -392,21 +385,21 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// tensor and it is compatible for the Concat Op.
Type expanded_item_type =
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero);
loc, expanded_item_type, item, scalar_zero);
Type elem_type = getElementTypeOrSelf(item);
auto handle_dtype =
getElementTypeOrSelf(push_back_op.output_handle().getType())
.cast<TF::VariantType>();
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
// Concatenate tensor stored in the input handle with the expanded item to
// get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
push_back_op, result_type, scalar_zero,
op, result_type, scalar_zero,
ArrayRef<Value>({input_handle, expanded_item}));
return matchSuccess();
}
@ -422,31 +415,28 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// TODO(haoliang): We could simplify this transformation by rewriting to pure
// tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating
// only on variant types, we could save some ops involved in rewriting this op.
struct ConvertTensorListResize : public ConversionPattern {
explicit ConvertTensorListResize(MLIRContext *context)
: ConversionPattern(TF::TensorListResizeOp::getOperationName(), 1,
context) {}
struct ConvertTensorListResize
: public OpConversionPattern<TF::TensorListResizeOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
TF::TensorListResizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
Value input_handle = operands[0];
Value size = operands[1];
Location loc = resize_op.getLoc();
Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Compute the input tensorlist's length and store it in `input_size`.
IntegerType shape_dtype = rewriter.getIntegerType(32);
auto input_size = rewriter.create<TF::TensorListLengthOp>(
loc, RankedTensorType::get({}, shape_dtype), op->getOperand(0));
loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
// Infer result type of this op based on TF's shape inference result.
Type elem_type = getElementTypeOrSelf(input_handle);
auto handle_dtype =
getElementTypeOrSelf(resize_op.output_handle().getType())
.cast<TF::VariantType>();
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
.cast<TF::VariantType>();
Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -471,7 +461,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Constructs `then_branch`, which is executed when `if_cond` evaluates to
// true.
FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type);
CreateCondTrueBranch(resize_op, shape_dtype, result_type, then_branch_op,
CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
&rewriter);
// Constructs `else_branch`, which is executed when `if_cond` evaluates to
@ -483,7 +473,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Inserts the two blocks' names into the symbol table held by the module.
// Using SymbolTable will ensure that the inserted symbol names are
// unique.
SymbolTable manager(resize_op.getParentOfType<ModuleOp>());
SymbolTable manager(op.getParentOfType<ModuleOp>());
manager.insert(then_branch_op);
manager.insert(else_branch_op);
@ -569,32 +559,28 @@ struct ConvertTensorListResize : public ConversionPattern {
}
};
struct ConvertTensorListGetItem : public ConversionPattern {
explicit ConvertTensorListGetItem(MLIRContext *context)
: ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1,
context) {}
struct ConvertTensorListGetItem
: public OpConversionPattern<TF::TensorListGetItemOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::TensorListGetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
Value input = operands[0];
Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>(
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
rewriter.getBoolAttr(true));
return matchSuccess();
}
};
struct ConvertTensorListLength : public ConversionPattern {
explicit ConvertTensorListLength(MLIRContext *context)
: ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1,
context) {}
struct ConvertTensorListLength
: public OpConversionPattern<TF::TensorListLengthOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::TensorListLengthOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
Location loc = op.getLoc();
Value input_handle = operands[0];
@ -608,15 +594,13 @@ struct ConvertTensorListLength : public ConversionPattern {
}
};
struct ConvertTensorListStack : public ConversionPattern {
explicit ConvertTensorListStack(MLIRContext *context)
: ConversionPattern(TF::TensorListStackOp::getOperationName(), 1,
context) {}
struct ConvertTensorListStack
: public OpConversionPattern<TF::TensorListStackOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::TensorListStackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListStackOp>(operation);
Location loc = op.getLoc();
Value input = operands[0];
Value element_shape = operands[1];
@ -649,14 +633,12 @@ struct ConvertTensorListStack : public ConversionPattern {
}
};
struct ConvertIdentity : public ConversionPattern {
explicit ConvertIdentity(MLIRContext *context)
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::IdentityOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs());
@ -722,15 +704,12 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
return success();
}
struct ConvertWhile : public ConversionPattern {
explicit ConvertWhile(MLIRContext *context)
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
using OpConversionPattern::OpConversionPattern;
PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands,
TF::WhileOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::WhileOp>(operation);
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) {