NFC: Derive tfl-lower-static-tensor-list pass patterns from OpConversionPattern
PiperOrigin-RevId: 289976784 Change-Id: I2e062774d1bba9df3e603a7585d614950d8ecfbd
This commit is contained in:
parent
fc7e43de1f
commit
52b8ba5463
@ -162,10 +162,9 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list,
|
||||
start_position, slice_size);
|
||||
}
|
||||
|
||||
struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
explicit ConvertTensorListSetItem(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListSetItem
|
||||
: public OpConversionPattern<TF::TensorListSetItemOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
// This function rewrites the original op into a series of slice and concat op
|
||||
// to produce the same result. It first slices the first `$index` rows. Then
|
||||
@ -180,9 +179,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
// 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice
|
||||
// $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>;
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::TensorListSetItemOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListSetItemOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value input = operands[0];
|
||||
Value index = operands[1];
|
||||
@ -235,9 +233,8 @@ struct ConvertTensorListSetItem : public ConversionPattern {
|
||||
// to generate an equivalent raw tensor. Derived classes are required to
|
||||
// override GetNumElements method.
|
||||
template <typename OpT>
|
||||
struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
explicit ConvertTensorListInitOp(MLIRContext *context)
|
||||
: ConversionPattern(OpT::getOperationName(), 1, context) {}
|
||||
struct ConvertTensorListInitOp : public OpConversionPattern<OpT> {
|
||||
using OpConversionPattern<OpT>::OpConversionPattern;
|
||||
|
||||
// Create and return a 1-d tensor with exactly one element equal to the number
|
||||
// of list elements to initialize the output tensor list with.
|
||||
@ -248,10 +245,8 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
// [num_element, element_shape]. All the values in the result tensor will be
|
||||
// initialized to 0.
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
OpT op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
OpT op = llvm::cast<OpT>(operation);
|
||||
|
||||
Type dtype = op.element_dtype();
|
||||
if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() ||
|
||||
dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) ||
|
||||
@ -260,7 +255,7 @@ struct ConvertTensorListInitOp : public ConversionPattern {
|
||||
"requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit "
|
||||
"integer or 16-bit/32-bit/64-bit float type during TF Lite "
|
||||
"transformation pass");
|
||||
return matchFailure();
|
||||
return ConversionPattern::matchFailure();
|
||||
}
|
||||
|
||||
Value element_shape = operands[0];
|
||||
@ -376,15 +371,13 @@ struct ConvertEmptyTensorList
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
explicit ConvertTensorListPushBack(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListPushBack
|
||||
: public OpConversionPattern<TF::TensorListPushBackOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
TF::TensorListPushBackOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TF::TensorListPushBackOp push_back_op = cast<TF::TensorListPushBackOp>(op);
|
||||
Value input_handle = operands[0];
|
||||
Value item = operands[1];
|
||||
|
||||
@ -392,13 +385,13 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
// tensor and it is compatible for the Concat Op.
|
||||
Type expanded_item_type =
|
||||
PrependLeadingDimIfRanked(1, item.getType(), &rewriter);
|
||||
Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0);
|
||||
Location loc = op.getLoc();
|
||||
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
auto expanded_item = rewriter.create<TF::ExpandDimsOp>(
|
||||
op->getLoc(), expanded_item_type, item, scalar_zero);
|
||||
loc, expanded_item_type, item, scalar_zero);
|
||||
|
||||
Type elem_type = getElementTypeOrSelf(item);
|
||||
auto handle_dtype =
|
||||
getElementTypeOrSelf(push_back_op.output_handle().getType())
|
||||
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
|
||||
.cast<TF::VariantType>();
|
||||
Type result_type =
|
||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||
@ -406,7 +399,7 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
// Concatenate tensor stored in the input handle with the expanded item to
|
||||
// get a tensor equivalent to the TensorList generated by this op.
|
||||
rewriter.replaceOpWithNewOp<TF::ConcatOp>(
|
||||
push_back_op, result_type, scalar_zero,
|
||||
op, result_type, scalar_zero,
|
||||
ArrayRef<Value>({input_handle, expanded_item}));
|
||||
return matchSuccess();
|
||||
}
|
||||
@ -422,30 +415,27 @@ struct ConvertTensorListPushBack : public ConversionPattern {
|
||||
// TODO(haoliang): We could simplify this transformation by rewriting to pure
|
||||
// tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating
|
||||
// only on variant types, we could save some ops involved in rewriting this op.
|
||||
struct ConvertTensorListResize : public ConversionPattern {
|
||||
explicit ConvertTensorListResize(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListResizeOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListResize
|
||||
: public OpConversionPattern<TF::TensorListResizeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *op, ArrayRef<Value> operands,
|
||||
TF::TensorListResizeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
TF::TensorListResizeOp resize_op = cast<TF::TensorListResizeOp>(op);
|
||||
Value input_handle = operands[0];
|
||||
Value size = operands[1];
|
||||
|
||||
Location loc = resize_op.getLoc();
|
||||
Location loc = op.getLoc();
|
||||
Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0);
|
||||
|
||||
// Compute the input tensorlist's length and store it in `input_size`.
|
||||
IntegerType shape_dtype = rewriter.getIntegerType(32);
|
||||
auto input_size = rewriter.create<TF::TensorListLengthOp>(
|
||||
loc, RankedTensorType::get({}, shape_dtype), op->getOperand(0));
|
||||
loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0));
|
||||
|
||||
// Infer result type of this op based on TF's shape inference result.
|
||||
Type elem_type = getElementTypeOrSelf(input_handle);
|
||||
auto handle_dtype =
|
||||
getElementTypeOrSelf(resize_op.output_handle().getType())
|
||||
auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType())
|
||||
.cast<TF::VariantType>();
|
||||
Type result_type =
|
||||
GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter);
|
||||
@ -471,7 +461,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
// Constructs `then_branch`, which is executed when `if_cond` evaluates to
|
||||
// true.
|
||||
FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type);
|
||||
CreateCondTrueBranch(resize_op, shape_dtype, result_type, then_branch_op,
|
||||
CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op,
|
||||
&rewriter);
|
||||
|
||||
// Constructs `else_branch`, which is executed when `if_cond` evaluates to
|
||||
@ -483,7 +473,7 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
// Inserts the two blocks' names into the symbol table held by the module.
|
||||
// Using SymbolTable will ensure that the inserted symbol names are
|
||||
// unique.
|
||||
SymbolTable manager(resize_op.getParentOfType<ModuleOp>());
|
||||
SymbolTable manager(op.getParentOfType<ModuleOp>());
|
||||
manager.insert(then_branch_op);
|
||||
manager.insert(else_branch_op);
|
||||
|
||||
@ -569,32 +559,28 @@ struct ConvertTensorListResize : public ConversionPattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTensorListGetItem : public ConversionPattern {
|
||||
explicit ConvertTensorListGetItem(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListGetItem
|
||||
: public OpConversionPattern<TF::TensorListGetItemOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::TensorListGetItemOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListGetItemOp>(operation);
|
||||
Value input = operands[0];
|
||||
Value index = operands[1];
|
||||
rewriter.replaceOpWithNewOp<TF::GatherOp>(
|
||||
operation, op.getType(), input, index, rewriter.getBoolAttr(true));
|
||||
rewriter.replaceOpWithNewOp<TF::GatherOp>(op, op.getType(), input, index,
|
||||
rewriter.getBoolAttr(true));
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTensorListLength : public ConversionPattern {
|
||||
explicit ConvertTensorListLength(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListLength
|
||||
: public OpConversionPattern<TF::TensorListLengthOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::TensorListLengthOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListLengthOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value input_handle = operands[0];
|
||||
|
||||
@ -608,15 +594,13 @@ struct ConvertTensorListLength : public ConversionPattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertTensorListStack : public ConversionPattern {
|
||||
explicit ConvertTensorListStack(MLIRContext *context)
|
||||
: ConversionPattern(TF::TensorListStackOp::getOperationName(), 1,
|
||||
context) {}
|
||||
struct ConvertTensorListStack
|
||||
: public OpConversionPattern<TF::TensorListStackOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::TensorListStackOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::TensorListStackOp>(operation);
|
||||
Location loc = op.getLoc();
|
||||
Value input = operands[0];
|
||||
Value element_shape = operands[1];
|
||||
@ -649,14 +633,12 @@ struct ConvertTensorListStack : public ConversionPattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertIdentity : public ConversionPattern {
|
||||
explicit ConvertIdentity(MLIRContext *context)
|
||||
: ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {}
|
||||
struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::IdentityOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::IdentityOp>(operation);
|
||||
Value input = operands[0];
|
||||
rewriter.replaceOpWithNewOp<TF::IdentityOp>(op, input.getType(), operands,
|
||||
op.getAttrs());
|
||||
@ -722,15 +704,12 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
struct ConvertWhile : public ConversionPattern {
|
||||
explicit ConvertWhile(MLIRContext *context)
|
||||
: ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {}
|
||||
struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(
|
||||
Operation *operation, ArrayRef<Value> operands,
|
||||
TF::WhileOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto op = llvm::cast<TF::WhileOp>(operation);
|
||||
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
result_types.reserve(op.getNumOperands());
|
||||
for (int i = 0, e = operands.size(); i != e; ++i) {
|
||||
|
Loading…
Reference in New Issue
Block a user