NFC: Derive tfl-lower-static-tensor-list pass patterns from OpConversionPattern

PiperOrigin-RevId: 289976784
Change-Id: I2e062774d1bba9df3e603a7585d614950d8ecfbd
This commit is contained in:
Smit Hinsu 2020-01-15 18:06:31 -08:00 committed by TensorFlower Gardener
parent fc7e43de1f
commit 52b8ba5463

View File

@ -162,10 +162,9 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
start_position, slice_size); start_position, slice_size);
} }
struct ConvertTensorListSetItem : public ConversionPattern { struct ConvertTensorListSetItem
explicit ConvertTensorListSetItem(MLIRContext *context) : public OpConversionPattern<TF::TensorListSetItemOp> {
: ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
// This function rewrites the original op into a series of slice and concat op // This function rewrites the original op into a series of slice and concat op
// to produce the same result. It first slices the first `$index` rows. Then // to produce the same result. It first slices the first `$index` rows. Then
@ -180,9 +179,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::TensorListSetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = operands[0]; Value input = operands[0];
Value index = operands[1]; Value index = operands[1];
@ -235,9 +233,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
// to generate an equivalent raw tensor. Derived classes are required to // to generate an equivalent raw tensor. Derived classes are required to
// override GetNumElements method. // override GetNumElements method.
template <typename OpT> template <typename OpT>
struct ConvertTensorListInitOp : public ConversionPattern { struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
explicit ConvertTensorListInitOp(MLIRContext *context) using OpConversionPattern<OpT>::OpConversionPattern;
: ConversionPattern(OpT::getOperationName(), 1, context) {}
// Create and return a 1-d tensor with exactly one element equal to the number // Create and return a 1-d tensor with exactly one element equal to the number
// of list elements to initialize the output tensor list with. // of list elements to initialize the output tensor list with.
@ -248,10 +245,8 @@ struct ConvertTensorListInitOp : public ConversionPattern {
// [num_element, element_shape]. All the values in the result tensor will be // [num_element, element_shape]. All the values in the result tensor will be
// initialized to 0. // initialized to 0.
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, OpT op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
OpT op = llvm::cast<OpT>(operation);
Type dtype = op.element_dtype(); Type dtype = op.element_dtype();
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
@ -260,7 +255,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
"integer or 16-bit/32-bit/64-bit float type during TF Lite " "integer or 16-bit/32-bit/64-bit float type during TF Lite "
"transformation pass"); "transformation pass");
return matchFailure(); return ConversionPattern::matchFailure();
} }
Value element_shape = operands[0]; Value element_shape = operands[0];
@ -376,15 +371,13 @@ struct ConvertEmptyTensorList
} }
}; };
struct ConvertTensorListPushBack : public ConversionPattern { struct ConvertTensorListPushBack
explicit ConvertTensorListPushBack(MLIRContext *context) : public OpConversionPattern<TF::TensorListPushBackOp> {
: ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands, TF::TensorListPushBackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
Value input_handle = operands[0]; Value input_handle = operands[0];
Value item = operands[1]; Value item = operands[1];
@ -392,21 +385,21 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// tensor and it is compatible for the Concat Op. // tensor and it is compatible for the Concat Op.
Type expanded_item_type = Type expanded_item_type =
PrependLeadingDimIfRanked(1, item.getType(), &rewriter); PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0); Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
auto expanded_item = rewriter.create<TF::ExpandDimsOp>( auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
op->getLoc(), expanded_item_type, item, scalar_zero); loc, expanded_item_type, item, scalar_zero);
Type elem_type = getElementTypeOrSelf(item); Type elem_type = getElementTypeOrSelf(item);
auto handle_dtype = auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
getElementTypeOrSelf(push_back_op.output_handle().getType()) .cast<TF::VariantType>();
.cast<TF::VariantType>();
Type result_type = Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
// Concatenate tensor stored in the input handle with the expanded item to // Concatenate tensor stored in the input handle with the expanded item to
// get a tensor equivalent to the TensorList generated by this op. // get a tensor equivalent to the TensorList generated by this op.
rewriter.replaceOpWithNewOp<TF::ConcatOp>( rewriter.replaceOpWithNewOp<TF::ConcatOp>(
push_back_op, result_type, scalar_zero, op, result_type, scalar_zero,
ArrayRef<Value>({input_handle, expanded_item})); ArrayRef<Value>({input_handle, expanded_item}));
return matchSuccess(); return matchSuccess();
} }
@ -422,31 +415,28 @@ struct ConvertTensorListPushBack : public ConversionPattern {
// TODO(haoliang): We could simplify this transformation by rewriting to pure // TODO(haoliang): We could simplify this transformation by rewriting to pure
// tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating // tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating
// only on variant types, we could save some ops involved in rewriting this op. // only on variant types, we could save some ops involved in rewriting this op.
struct ConvertTensorListResize : public ConversionPattern { struct ConvertTensorListResize
explicit ConvertTensorListResize(MLIRContext *context) : public OpConversionPattern<TF::TensorListResizeOp> {
: ConversionPattern(TF::TensorListResizeOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands, TF::TensorListResizeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
Value input_handle = operands[0]; Value input_handle = operands[0];
Value size = operands[1]; Value size = operands[1];
Location loc = resize_op.getLoc(); Location loc = op.getLoc();
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
// Compute the input tensorlist's length and store it in `input_size`. // Compute the input tensorlist's length and store it in `input_size`.
IntegerType shape_dtype = rewriter.getIntegerType(32); IntegerType shape_dtype = rewriter.getIntegerType(32);
auto input_size = rewriter.create<TF::TensorListLengthOp>( auto input_size = rewriter.create<TF::TensorListLengthOp>(
loc, RankedTensorType::get({}, shape_dtype), op->getOperand(0)); loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
// Infer result type of this op based on TF's shape inference result. // Infer result type of this op based on TF's shape inference result.
Type elem_type = getElementTypeOrSelf(input_handle); Type elem_type = getElementTypeOrSelf(input_handle);
auto handle_dtype = auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
getElementTypeOrSelf(resize_op.output_handle().getType()) .cast<TF::VariantType>();
.cast<TF::VariantType>();
Type result_type = Type result_type =
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
@ -471,7 +461,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Constructs `then_branch`, which is executed when `if_cond` evaluates to // Constructs `then_branch`, which is executed when `if_cond` evaluates to
// true. // true.
FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type); FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type);
CreateCondTrueBranch(resize_op, shape_dtype, result_type, then_branch_op, CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
&rewriter); &rewriter);
// Constructs `else_branch`, which is executed when `if_cond` evaluates to // Constructs `else_branch`, which is executed when `if_cond` evaluates to
@ -483,7 +473,7 @@ struct ConvertTensorListResize : public ConversionPattern {
// Inserts the two blocks' names into the symbol table held by the module. // Inserts the two blocks' names into the symbol table held by the module.
// Using SymbolTable will ensure that the inserted symbol names are // Using SymbolTable will ensure that the inserted symbol names are
// unique. // unique.
SymbolTable manager(resize_op.getParentOfType<ModuleOp>()); SymbolTable manager(op.getParentOfType<ModuleOp>());
manager.insert(then_branch_op); manager.insert(then_branch_op);
manager.insert(else_branch_op); manager.insert(else_branch_op);
@ -569,32 +559,28 @@ struct ConvertTensorListResize : public ConversionPattern {
} }
}; };
struct ConvertTensorListGetItem : public ConversionPattern { struct ConvertTensorListGetItem
explicit ConvertTensorListGetItem(MLIRContext *context) : public OpConversionPattern<TF::TensorListGetItemOp> {
: ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::TensorListGetItemOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
Value input = operands[0]; Value input = operands[0];
Value index = operands[1]; Value index = operands[1];
rewriter.replaceOpWithNewOp<TF::GatherOp>( rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
operation, op.getType(), input, index, rewriter.getBoolAttr(true)); rewriter.getBoolAttr(true));
return matchSuccess(); return matchSuccess();
} }
}; };
struct ConvertTensorListLength : public ConversionPattern { struct ConvertTensorListLength
explicit ConvertTensorListLength(MLIRContext *context) : public OpConversionPattern<TF::TensorListLengthOp> {
: ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::TensorListLengthOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input_handle = operands[0]; Value input_handle = operands[0];
@ -608,15 +594,13 @@ struct ConvertTensorListLength : public ConversionPattern {
} }
}; };
struct ConvertTensorListStack : public ConversionPattern { struct ConvertTensorListStack
explicit ConvertTensorListStack(MLIRContext *context) : public OpConversionPattern<TF::TensorListStackOp> {
: ConversionPattern(TF::TensorListStackOp::getOperationName(), 1, using OpConversionPattern::OpConversionPattern;
context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::TensorListStackOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::TensorListStackOp>(operation);
Location loc = op.getLoc(); Location loc = op.getLoc();
Value input = operands[0]; Value input = operands[0];
Value element_shape = operands[1]; Value element_shape = operands[1];
@ -649,14 +633,12 @@ struct ConvertTensorListStack : public ConversionPattern {
} }
}; };
struct ConvertIdentity : public ConversionPattern { struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
explicit ConvertIdentity(MLIRContext *context) using OpConversionPattern::OpConversionPattern;
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::IdentityOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::IdentityOp>(operation);
Value input = operands[0]; Value input = operands[0];
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands, rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
op.getAttrs()); op.getAttrs());
@ -722,15 +704,12 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
return success(); return success();
} }
struct ConvertWhile : public ConversionPattern { struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
explicit ConvertWhile(MLIRContext *context) using OpConversionPattern::OpConversionPattern;
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
PatternMatchResult matchAndRewrite( PatternMatchResult matchAndRewrite(
Operation *operation, ArrayRef<Value> operands, TF::WhileOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
auto op = llvm::cast<TF::WhileOp>(operation);
llvm::SmallVector<Type, 8> result_types; llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands()); result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) { for (int i = 0, e = operands.size(); i != e; ++i) {