[MLIR] Extend LowerStaticTensorListPass to handle WhileRegion

PiperOrigin-RevId: 321370545
Change-Id: Ieb8de21a584df9770d8806fccfa811c85c2d76ee
This commit is contained in:
Rahul Joshi 2020-07-15 09:02:28 -07:00 committed by TensorFlower Gardener
parent 49369a6652
commit dada5c989e
2 changed files with 113 additions and 43 deletions

View File

@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> ten
// CHECK: return %[[RESULT]] : tensor<i1>
}
// CHECK-LABEL: func @tensorlistWhileRegion
func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>
%cst_1 = constant dense<-1> : tensor<i32>
%0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<3xf32>>>
// CHECK: "tf.WhileRegion"
%1:2 = "tf.WhileRegion"(%cst_0, %0) ({
^bb0(%carg0: tensor<i32>, %carg1: tensor<!tf.variant>):
%cst_2 = constant dense<2> : tensor<i32>
%1 = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%1) : (tensor<i1>) -> ()
// verify condition types
// CHECK: ^bb0(%[[CARG0:.*]]: tensor<i32>, %[[CARG1:.*]]: tensor<*xf32>):
// CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK: "tf.Yield"(%[[COND]]) : (tensor<i1>) -> ()
},
{
^bb0(%barg0: tensor<i32>, %barg1: tensor<!tf.variant>):
%1 = "tf.TensorListLength"(%barg1) : (tensor<!tf.variant>) -> tensor<i32>
"tf.Yield"(%1, %barg1) : (tensor<i32>, tensor<!tf.variant>) -> ()
// verify body types
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>, %[[BARG1:.*]]: tensor<*xf32>):
// CHECK-NOT: tensor<!tf.variant>
// CHECK: %[[LEN:.*]] = "tf.Gather"
// CHECK-NOT: tensor<!tf.variant>
// CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor<i32>, tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
// make sure the variant types in input/output have been updated
// CHECK: {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
%2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
// CHECK: return %0#1 : tensor<*xf32>
return %2 : tensor<*xf32>
}
func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>

View File

@ -17,7 +17,7 @@ limitations under the License.
// converting Tensorlist operations in TensorFlow dialect into operations that
// can be legalized to TensorFlow Lite dialect with simple replacements. The
// newly created operations are in the TensorFlow dialect if the operation can
// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
// is used.
#include <climits>
@ -738,9 +738,17 @@ struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
}
};
// Returns an unranked tensor type with an element of the same type as `value`
// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
Type VariantToUnrankedTensorType(Type type, Value value) {
if (getElementTypeOrSelf(type).isa<TF::VariantType>())
return UnrankedTensorType::get(getElementTypeOrSelf(value.getType()));
return type;
}
// Changes the function type of `cond_func` and `body_func` for the given While
// op.
static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
auto module = op.getParentOfType<ModuleOp>();
auto *context = module.getContext();
@ -756,30 +764,18 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// tensor type if it's a variant type.
SmallVector<Type, 8> updated_argument_types;
updated_argument_types.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
Type arg_type = func_type.getInput(i);
if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
arg_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_argument_types.push_back(arg_type);
}
for (auto it : llvm::zip(func_type.getInputs(), op.getOperands()))
updated_argument_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
// For each result type in function's results, change it to unranked tensor
// type if it's a variant type.
// Change all DT_VARIANT result types in function results to unranked tensor
// type with element type derived from the corresponding input operand. This
// is correct because while body's inputs and results have the same type.
SmallVector<Type, 8> updated_result_types;
updated_result_types.reserve(num_results);
for (int i = 0; i < num_results; ++i) {
Type result_type = func_type.getResult(i);
if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
// Here update the variant type with the unranked tensor type derived
// from the corresponding input operand. This is correct because while
// body's inputs and results have the same type.
result_type = UnrankedTensorType::get(
getElementTypeOrSelf(op.getOperand(i).getType()));
}
updated_result_types.push_back(result_type);
}
for (auto it : llvm::zip(func_type.getResults(), op.getOperands()))
updated_result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
// Change `func`'s argument type to `unranked_argument_types`. If it
// return types contain a `DT_VARIANT`, change it to the unranked type
@ -788,10 +784,9 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
context));
// Change the argument type for the first block.
Block &body_first_bb = func.front();
for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
body_first_bb.getArgument(i).setType(updated_argument_types[i]);
}
llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
arg.setType(updated_argument_types[arg.getArgNumber()]);
});
}
return success();
}
@ -804,25 +799,60 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
for (int i = 0, e = operands.size(); i != e; ++i) {
Type result_ty = op.getResult(i).getType();
// Change all DT_VARIANT result types to unranked tensor type.
for (auto it : llvm::zip(op.getResultTypes(), operands))
result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
// If we notice the result type is a DT_VARIANT, we change the
// corresponding result type to unranked tensor type.
if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
Type element_ty = getElementTypeOrSelf(operands[i].getType());
result_ty = UnrankedTensorType::get(element_ty);
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
converted.removeAttr("T");
UpdateFunctionTypes(converted);
rewriter.replaceOp(op, converted.getResults());
return success();
}
};
struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
TF::WhileRegionOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
// Change all DT_VARIANT result types to unranked tensor type.
for (auto it : llvm::zip(op.getResultTypes(), operands))
result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileRegionOp>(
op.getLoc(), result_types, operands, op.getAttrs());
// Inline the regions from the old while into the new one, and apply
// signature conversion to inlined region.
for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
Region &old_region = *std::get<0>(it);
Region &new_region = *std::get<1>(it);
Block &entry = old_region.front();
// Build signature conversion for the region.
TypeConverter::SignatureConversion signature_conversion(operands.size());
for (auto it : llvm::zip(entry.getArguments(), operands)) {
BlockArgument arg = std::get<0>(it);
signature_conversion.addInputs(
arg.getArgNumber(),
VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
}
result_types.push_back(result_ty);
rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
rewriter.applySignatureConversion(&new_region, signature_conversion);
}
// Clone original while op with new operands and updated result types.
auto cloned = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
cloned.removeAttr("T");
UpdateFunctionTypes(cloned);
rewriter.replaceOp(op, cloned.getResults());
rewriter.replaceOp(op, converted.getResults());
return success();
}
};
@ -871,7 +901,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
ConvertTensorListGetItem, ConvertTensorListLength,
ConvertTensorListPushBack, ConvertTensorListReserve,
ConvertTensorListSetItem, ConvertTensorListStack,
ConvertTensorListResize, ConvertWhile>(context);
ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
context);
return applyPartialConversion(func, target, patterns);
}