diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 1a61bc3f517..1ebe912284b 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor, %arg1: tensor) -> ten // CHECK: return %[[RESULT]] : tensor } +// CHECK-LABEL: func @tensorlistWhileRegion +func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { + %cst = constant dense<3> : tensor<1xi32> + %cst_0 = constant dense<0> : tensor + %cst_1 = constant dense<-1> : tensor + %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor>> + // CHECK: "tf.WhileRegion" + %1:2 = "tf.WhileRegion"(%cst_0, %0) ({ + ^bb0(%carg0: tensor, %carg1: tensor): + %cst_2 = constant dense<2> : tensor + %1 = "tf.Less"(%carg0, %cst_2) : (tensor, tensor) -> tensor + "tf.Yield"(%1) : (tensor) -> () + + // verify condition types + // CHECK: ^bb0(%[[CARG0:.*]]: tensor, %[[CARG1:.*]]: tensor<*xf32>): + // CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor, tensor) -> tensor + // CHECK: "tf.Yield"(%[[COND]]) : (tensor) -> () + + }, + { + ^bb0(%barg0: tensor, %barg1: tensor): + %1 = "tf.TensorListLength"(%barg1) : (tensor) -> tensor + "tf.Yield"(%1, %barg1) : (tensor, tensor) -> () + + // verify body types + // CHECK: ^bb0(%[[BARG0:.*]]: tensor, %[[BARG1:.*]]: tensor<*xf32>): + // CHECK-NOT: tensor + // CHECK: %[[LEN:.*]] = "tf.Gather" + // CHECK-NOT: tensor + // CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor, tensor<*xf32>) -> () + + }) {is_stateless = false} : (tensor, tensor>>) -> (tensor, tensor>>) + // make sure the variant types in input/output have been updated + // CHECK: {is_stateless = false} : (tensor, tensor<2x3xf32>) -> (tensor, tensor<*xf32>) + %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor>>, tensor) -> tensor<*xf32> + // CHECK: return %0#1 : tensor<*xf32> + return %2 : tensor<*xf32> +} + func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor) -> tensor { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> %1 = "tf.TensorListResize"(%0, %arg2) : (tensor>>, tensor) -> tensor>> diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index c76a6cfafab..439c44dc77e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -17,7 +17,7 @@ limitations under the License. // converting Tensorlist operations in TensorFlow dialect into operations that // can be legalized to TensorFlow Lite dialect with simple replacements. The // newly created operations are in the TensorFlow dialect if the operation can -// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op +// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op // is used. #include @@ -738,9 +738,17 @@ struct ConvertIdentity : public OpConversionPattern { } }; +// Returns an unranked tensor type with an element of the same type as `value` +// if `type` is a tensor of variant. Otherwise, returns `type` unmodified. +Type VariantToUnrankedTensorType(Type type, Value value) { + if (getElementTypeOrSelf(type).isa()) + return UnrankedTensorType::get(getElementTypeOrSelf(value.getType())); + return type; +} + // Changes the function type of `cond_func` and `body_func` for the given While // op. -static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { +LogicalResult UpdateFunctionTypes(TF::WhileOp op) { auto module = op.getParentOfType(); auto *context = module.getContext(); @@ -756,30 +764,18 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { // tensor type if it's a variant type. SmallVector updated_argument_types; updated_argument_types.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - Type arg_type = func_type.getInput(i); - if (getElementTypeOrSelf(arg_type).isa()) { - arg_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i).getType())); - } - updated_argument_types.push_back(arg_type); - } + for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) + updated_argument_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); - // For each result type in function's results, change it to unranked tensor - // type if it's a variant type. + // Change all DT_VARIANT result types in function results to unranked tensor + // type with element type derived from the corresponding input operand. This + // is correct because while body's inputs and results have the same type. SmallVector updated_result_types; updated_result_types.reserve(num_results); - for (int i = 0; i < num_results; ++i) { - Type result_type = func_type.getResult(i); - if (getElementTypeOrSelf(result_type).isa()) { - // Here update the variant type with the unranked tensor type derived - // from the corresponding input operand. This is correct because while - // body's inputs and results have the same type. - result_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i).getType())); - } - updated_result_types.push_back(result_type); - } + for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) + updated_result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); // Change `func`'s argument type to `unranked_argument_types`. If it // return types contain a `DT_VARIANT`, change it to the unranked type @@ -788,10 +784,9 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { context)); // Change the argument type for the first block. - Block &body_first_bb = func.front(); - for (int i = 0; i < body_first_bb.getNumArguments(); ++i) { - body_first_bb.getArgument(i).setType(updated_argument_types[i]); - } + llvm::for_each(func.getArguments(), [&](BlockArgument &arg) { + arg.setType(updated_argument_types[arg.getArgNumber()]); + }); } return success(); } @@ -804,25 +799,60 @@ struct ConvertWhile : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { llvm::SmallVector result_types; result_types.reserve(op.getNumOperands()); - for (int i = 0, e = operands.size(); i != e; ++i) { - Type result_ty = op.getResult(i).getType(); + // Change all DT_VARIANT result types to unranked tensor type. + for (auto it : llvm::zip(op.getResultTypes(), operands)) + result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); - // If we notice the result type is a DT_VARIANT, we change the - // corresponding result type to unranked tensor type. - if (getElementTypeOrSelf(result_ty).isa()) { - Type element_ty = getElementTypeOrSelf(operands[i].getType()); - result_ty = UnrankedTensorType::get(element_ty); + // Create a new while op with new operands and updated result types. + auto converted = rewriter.create(op.getLoc(), result_types, + operands, op.getAttrs()); + converted.removeAttr("T"); + UpdateFunctionTypes(converted); + + rewriter.replaceOp(op, converted.getResults()); + return success(); + } +}; + +struct ConvertWhileRegion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::WhileRegionOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types; + result_types.reserve(op.getNumOperands()); + // Change all DT_VARIANT result types to unranked tensor type. + for (auto it : llvm::zip(op.getResultTypes(), operands)) + result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); + + // Create a new while op with new operands and updated result types. + auto converted = rewriter.create( + op.getLoc(), result_types, operands, op.getAttrs()); + + // Inline the regions from the old while into the new one, and apply + // signature conversion to inlined region. + for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) { + Region &old_region = *std::get<0>(it); + Region &new_region = *std::get<1>(it); + + Block &entry = old_region.front(); + // Build signature conversion for the region. + TypeConverter::SignatureConversion signature_conversion(operands.size()); + for (auto it : llvm::zip(entry.getArguments(), operands)) { + BlockArgument arg = std::get<0>(it); + signature_conversion.addInputs( + arg.getArgNumber(), + VariantToUnrankedTensorType(arg.getType(), std::get<1>(it))); } - result_types.push_back(result_ty); + + rewriter.inlineRegionBefore(old_region, new_region, new_region.end()); + rewriter.applySignatureConversion(&new_region, signature_conversion); } - // Clone original while op with new operands and updated result types. - auto cloned = rewriter.create(op.getLoc(), result_types, - operands, op.getAttrs()); - cloned.removeAttr("T"); - UpdateFunctionTypes(cloned); - - rewriter.replaceOp(op, cloned.getResults()); + rewriter.replaceOp(op, converted.getResults()); return success(); } }; @@ -871,7 +901,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListGetItem, ConvertTensorListLength, ConvertTensorListPushBack, ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListStack, - ConvertTensorListResize, ConvertWhile>(context); + ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>( + context); return applyPartialConversion(func, target, patterns); }