From 52b8ba5463c31d11d53e9971e9d461ba49b6b4d5 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 15 Jan 2020 18:06:31 -0800 Subject: [PATCH] NFC: Derive tfl-lower-static-tensor-list pass patterns from OpConversionPattern PiperOrigin-RevId: 289976784 Change-Id: I2e062774d1bba9df3e603a7585d614950d8ecfbd --- .../transforms/lower_static_tensor_list.cc | 117 +++++++----------- 1 file changed, 48 insertions(+), 69 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index bc8d9162b78..8c3f00359fc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -162,10 +162,9 @@ TF::SliceOp CreateSliceOpForTensorList(Location loc, Value input_list, start_position, slice_size); } -struct ConvertTensorListSetItem : public ConversionPattern { - explicit ConvertTensorListSetItem(MLIRContext *context) - : ConversionPattern(TF::TensorListSetItemOp::getOperationName(), 1, - context) {} +struct ConvertTensorListSetItem + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; // This function rewrites the original op into a series of slice and concat op // to produce the same result. It first slices the first `$index` rows. Then @@ -180,9 +179,8 @@ struct ConvertTensorListSetItem : public ConversionPattern { // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::TensorListSetItemOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); Location loc = op.getLoc(); Value input = operands[0]; Value index = operands[1]; @@ -235,9 +233,8 @@ struct ConvertTensorListSetItem : public ConversionPattern { // to generate an equivalent raw tensor. Derived classes are required to // override GetNumElements method. template -struct ConvertTensorListInitOp : public ConversionPattern { - explicit ConvertTensorListInitOp(MLIRContext *context) - : ConversionPattern(OpT::getOperationName(), 1, context) {} +struct ConvertTensorListInitOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; // Create and return a 1-d tensor with exactly one element equal to the number // of list elements to initialize the output tensor list with. @@ -248,10 +245,8 @@ struct ConvertTensorListInitOp : public ConversionPattern { // [num_element, element_shape]. All the values in the result tensor will be // initialized to 0. PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + OpT op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OpT op = llvm::cast(operation); - Type dtype = op.element_dtype(); if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || @@ -260,7 +255,7 @@ struct ConvertTensorListInitOp : public ConversionPattern { "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " "transformation pass"); - return matchFailure(); + return ConversionPattern::matchFailure(); } Value element_shape = operands[0]; @@ -376,15 +371,13 @@ struct ConvertEmptyTensorList } }; -struct ConvertTensorListPushBack : public ConversionPattern { - explicit ConvertTensorListPushBack(MLIRContext *context) - : ConversionPattern(TF::TensorListPushBackOp::getOperationName(), 1, - context) {} +struct ConvertTensorListPushBack + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *op, ArrayRef operands, + TF::TensorListPushBackOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - TF::TensorListPushBackOp push_back_op = cast(op); Value input_handle = operands[0]; Value item = operands[1]; @@ -392,21 +385,21 @@ struct ConvertTensorListPushBack : public ConversionPattern { // tensor and it is compatible for the Concat Op. Type expanded_item_type = PrependLeadingDimIfRanked(1, item.getType(), &rewriter); - Value scalar_zero = CreateI32SplatConst(op->getLoc(), &rewriter, {}, 0); + Location loc = op.getLoc(); + Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); auto expanded_item = rewriter.create( - op->getLoc(), expanded_item_type, item, scalar_zero); + loc, expanded_item_type, item, scalar_zero); Type elem_type = getElementTypeOrSelf(item); - auto handle_dtype = - getElementTypeOrSelf(push_back_op.output_handle().getType()) - .cast(); + auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType()) + .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); // Concatenate tensor stored in the input handle with the expanded item to // get a tensor equivalent to the TensorList generated by this op. rewriter.replaceOpWithNewOp( - push_back_op, result_type, scalar_zero, + op, result_type, scalar_zero, ArrayRef({input_handle, expanded_item})); return matchSuccess(); } @@ -422,31 +415,28 @@ struct ConvertTensorListPushBack : public ConversionPattern { // TODO(haoliang): We could simplify this transformation by rewriting to pure // tensorlist ops and a few non-tensorlist ops (such as `SliceOp`). By operating // only on variant types, we could save some ops involved in rewriting this op. -struct ConvertTensorListResize : public ConversionPattern { - explicit ConvertTensorListResize(MLIRContext *context) - : ConversionPattern(TF::TensorListResizeOp::getOperationName(), 1, - context) {} +struct ConvertTensorListResize + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *op, ArrayRef operands, + TF::TensorListResizeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - TF::TensorListResizeOp resize_op = cast(op); Value input_handle = operands[0]; Value size = operands[1]; - Location loc = resize_op.getLoc(); + Location loc = op.getLoc(); Value scalar_zero = CreateI32SplatConst(loc, &rewriter, {}, 0); // Compute the input tensorlist's length and store it in `input_size`. IntegerType shape_dtype = rewriter.getIntegerType(32); auto input_size = rewriter.create( - loc, RankedTensorType::get({}, shape_dtype), op->getOperand(0)); + loc, RankedTensorType::get({}, shape_dtype), op.getOperand(0)); // Infer result type of this op based on TF's shape inference result. Type elem_type = getElementTypeOrSelf(input_handle); - auto handle_dtype = - getElementTypeOrSelf(resize_op.output_handle().getType()) - .cast(); + auto handle_dtype = getElementTypeOrSelf(op.output_handle().getType()) + .cast(); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -471,7 +461,7 @@ struct ConvertTensorListResize : public ConversionPattern { // Constructs `then_branch`, which is executed when `if_cond` evaluates to // true. FuncOp then_branch_op = FuncOp::create(loc, "cond_true", func_type); - CreateCondTrueBranch(resize_op, shape_dtype, result_type, then_branch_op, + CreateCondTrueBranch(op, shape_dtype, result_type, then_branch_op, &rewriter); // Constructs `else_branch`, which is executed when `if_cond` evaluates to @@ -483,7 +473,7 @@ struct ConvertTensorListResize : public ConversionPattern { // Inserts the two blocks' names into the symbol table held by the module. // Using SymbolTable will ensure that the inserted symbol names are // unique. - SymbolTable manager(resize_op.getParentOfType()); + SymbolTable manager(op.getParentOfType()); manager.insert(then_branch_op); manager.insert(else_branch_op); @@ -569,32 +559,28 @@ struct ConvertTensorListResize : public ConversionPattern { } }; -struct ConvertTensorListGetItem : public ConversionPattern { - explicit ConvertTensorListGetItem(MLIRContext *context) - : ConversionPattern(TF::TensorListGetItemOp::getOperationName(), 1, - context) {} +struct ConvertTensorListGetItem + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::TensorListGetItemOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); Value input = operands[0]; Value index = operands[1]; - rewriter.replaceOpWithNewOp( - operation, op.getType(), input, index, rewriter.getBoolAttr(true)); + rewriter.replaceOpWithNewOp(op, op.getType(), input, index, + rewriter.getBoolAttr(true)); return matchSuccess(); } }; -struct ConvertTensorListLength : public ConversionPattern { - explicit ConvertTensorListLength(MLIRContext *context) - : ConversionPattern(TF::TensorListLengthOp::getOperationName(), 1, - context) {} +struct ConvertTensorListLength + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::TensorListLengthOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); Location loc = op.getLoc(); Value input_handle = operands[0]; @@ -608,15 +594,13 @@ struct ConvertTensorListLength : public ConversionPattern { } }; -struct ConvertTensorListStack : public ConversionPattern { - explicit ConvertTensorListStack(MLIRContext *context) - : ConversionPattern(TF::TensorListStackOp::getOperationName(), 1, - context) {} +struct ConvertTensorListStack + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::TensorListStackOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); Location loc = op.getLoc(); Value input = operands[0]; Value element_shape = operands[1]; @@ -649,14 +633,12 @@ struct ConvertTensorListStack : public ConversionPattern { } }; -struct ConvertIdentity : public ConversionPattern { - explicit ConvertIdentity(MLIRContext *context) - : ConversionPattern(TF::IdentityOp::getOperationName(), 1, context) {} +struct ConvertIdentity : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::IdentityOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); Value input = operands[0]; rewriter.replaceOpWithNewOp(op, input.getType(), operands, op.getAttrs()); @@ -722,15 +704,12 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { return success(); } -struct ConvertWhile : public ConversionPattern { - explicit ConvertWhile(MLIRContext *context) - : ConversionPattern(TF::WhileOp::getOperationName(), 1, context) {} +struct ConvertWhile : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; PatternMatchResult matchAndRewrite( - Operation *operation, ArrayRef operands, + TF::WhileOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto op = llvm::cast(operation); - llvm::SmallVector result_types; result_types.reserve(op.getNumOperands()); for (int i = 0, e = operands.size(); i != e; ++i) {