[MLIR] Extend LowerStaticTensorListPass to handle WhileRegion

PiperOrigin-RevId: 321370545
Change-Id: Ieb8de21a584df9770d8806fccfa811c85c2d76ee

commit dada5c989e
parent 49369a6652
@@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> tensor<i1> {
   // CHECK: return %[[RESULT]] : tensor<i1>
 }
 
+// CHECK-LABEL: func @tensorlistWhileRegion
+func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
+  %cst = constant dense<3> : tensor<1xi32>
+  %cst_0 = constant dense<0> : tensor<i32>
+  %cst_1 = constant dense<-1> : tensor<i32>
+  %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<3xf32>>>
+  // CHECK: "tf.WhileRegion"
+  %1:2 = "tf.WhileRegion"(%cst_0, %0) ({
+    ^bb0(%carg0: tensor<i32>, %carg1: tensor<!tf.variant>):
+      %cst_2 = constant dense<2> : tensor<i32>
+      %1 = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+      "tf.Yield"(%1) : (tensor<i1>) -> ()
+
+    // verify condition types
+    // CHECK: ^bb0(%[[CARG0:.*]]: tensor<i32>, %[[CARG1:.*]]: tensor<*xf32>):
+    // CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    // CHECK: "tf.Yield"(%[[COND]]) : (tensor<i1>) -> ()
+
+  },
+  {
+    ^bb0(%barg0: tensor<i32>, %barg1: tensor<!tf.variant>):
+      %1 = "tf.TensorListLength"(%barg1) : (tensor<!tf.variant>) -> tensor<i32>
+      "tf.Yield"(%1, %barg1) : (tensor<i32>, tensor<!tf.variant>) -> ()
+
+    // verify body types
+    // CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>, %[[BARG1:.*]]: tensor<*xf32>):
+    // CHECK-NOT: tensor<!tf.variant>
+    // CHECK: %[[LEN:.*]] = "tf.Gather"
+    // CHECK-NOT: tensor<!tf.variant>
+    // CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor<i32>, tensor<*xf32>) -> ()
+
+  }) {is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
+  // make sure the variant types in input/output have been updated
+  // CHECK: {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
+  %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
+  // CHECK: return %0#1 : tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
 func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
   %1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>
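For readers who don't want to interleave the CHECK directives by hand, the IR this new test expects after the pass looks roughly like the sketch below. It is reassembled from the CHECK lines above, not an additional test: value names are schematic, the constants are omitted, and the tf.Gather operands are elided.

  func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
    // The tf.TensorListFromTensor is gone; %arg0 feeds the loop directly.
    %0:2 = "tf.WhileRegion"(%cst_0, %arg0) ({
    ^bb0(%carg0: tensor<i32>, %carg1: tensor<*xf32>):
      %cond = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
      "tf.Yield"(%cond) : (tensor<i1>) -> ()
    }, {
    ^bb0(%barg0: tensor<i32>, %barg1: tensor<*xf32>):
      // tf.TensorListLength is lowered away; the CHECK lines only pin the
      // tf.Gather it turns into.
      %len = "tf.Gather"(...)
      "tf.Yield"(%len, %barg1) : (tensor<i32>, tensor<*xf32>) -> ()
    }) {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
    return %0#1 : tensor<*xf32>
  }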
@@ -17,7 +17,7 @@ limitations under the License.
 // converting Tensorlist operations in TensorFlow dialect into operations that
 // can be legalized to TensorFlow Lite dialect with simple replacements. The
 // newly created operations are in the TensorFlow dialect if the operation can
 // be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op
 // is used.
 
 #include <climits>
@@ -738,9 +738,17 @@ struct ConvertIdentity : public OpConversionPattern<TF::IdentityOp> {
   }
 };
 
+// Returns an unranked tensor type with an element of the same type as `value`
+// if `type` is a tensor of variant. Otherwise, returns `type` unmodified.
+Type VariantToUnrankedTensorType(Type type, Value value) {
+  if (getElementTypeOrSelf(type).isa<TF::VariantType>())
+    return UnrankedTensorType::get(getElementTypeOrSelf(value.getType()));
+  return type;
+}
+
 // Changes the function type of `cond_func` and `body_func` for the given While
 // op.
-static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
+LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
   auto module = op.getParentOfType<ModuleOp>();
   auto *context = module.getContext();
 
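To make the new helper concrete, here is how it maps types once the pass has already rewritten the corresponding operand (illustrative values, not part of the diff):

  // type = tensor<!tf.variant<tensor<3xf32>>>, value : tensor<2x3xf32>
  //   => tensor<*xf32>   (variant replaced by an unranked tensor of the operand's element type)
  // type = tensor<i32>, value : tensor<i32>
  //   => tensor<i32>     (non-variant types are returned unmodified)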
@@ -756,30 +764,18 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
   // tensor type if it's a variant type.
   SmallVector<Type, 8> updated_argument_types;
   updated_argument_types.reserve(num_inputs);
-  for (int i = 0; i < num_inputs; ++i) {
-    Type arg_type = func_type.getInput(i);
-    if (getElementTypeOrSelf(arg_type).isa<TF::VariantType>()) {
-      arg_type = UnrankedTensorType::get(
-          getElementTypeOrSelf(op.getOperand(i).getType()));
-    }
-    updated_argument_types.push_back(arg_type);
-  }
+  for (auto it : llvm::zip(func_type.getInputs(), op.getOperands()))
+    updated_argument_types.push_back(
+        VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
 
-  // For each result type in function's results, change it to unranked tensor
-  // type if it's a variant type.
+  // Change all DT_VARIANT result types in function results to unranked tensor
+  // type with element type derived from the corresponding input operand. This
+  // is correct because while body's inputs and results have the same type.
   SmallVector<Type, 8> updated_result_types;
   updated_result_types.reserve(num_results);
-  for (int i = 0; i < num_results; ++i) {
-    Type result_type = func_type.getResult(i);
-    if (getElementTypeOrSelf(result_type).isa<TF::VariantType>()) {
-      // Here update the variant type with the unranked tensor type derived
-      // from the corresponding input operand. This is correct because while
-      // body's inputs and results have the same type.
-      result_type = UnrankedTensorType::get(
-          getElementTypeOrSelf(op.getOperand(i).getType()));
-    }
-    updated_result_types.push_back(result_type);
-  }
+  for (auto it : llvm::zip(func_type.getResults(), op.getOperands()))
+    updated_result_types.push_back(
+        VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
 
   // Change `func`'s argument type to `unranked_argument_types`. If it
   // return types contain a `DT_VARIANT`, change it to the unranked type
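Worked through on hypothetical cond/body functions of type (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) whose While operands have already been converted to (tensor<i32>, tensor<2x3xf32>), both of the rewritten loops now produce:

  // updated_argument_types = [tensor<i32>, tensor<*xf32>]
  // updated_result_types   = [tensor<i32>, tensor<*xf32>]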
@@ -788,10 +784,9 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
                                          context));
 
     // Change the argument type for the first block.
-    Block &body_first_bb = func.front();
-    for (int i = 0; i < body_first_bb.getNumArguments(); ++i) {
-      body_first_bb.getArgument(i).setType(updated_argument_types[i]);
-    }
+    llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
+      arg.setType(updated_argument_types[arg.getArgNumber()]);
+    });
   }
   return success();
 }
@@ -804,25 +799,60 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
       ConversionPatternRewriter &rewriter) const override {
     llvm::SmallVector<Type, 8> result_types;
     result_types.reserve(op.getNumOperands());
-    for (int i = 0, e = operands.size(); i != e; ++i) {
-      Type result_ty = op.getResult(i).getType();
-
-      // If we notice the result type is a DT_VARIANT, we change the
-      // corresponding result type to unranked tensor type.
-      if (getElementTypeOrSelf(result_ty).isa<TF::VariantType>()) {
-        Type element_ty = getElementTypeOrSelf(operands[i].getType());
-        result_ty = UnrankedTensorType::get(element_ty);
-      }
-      result_types.push_back(result_ty);
-    }
-
-    // Clone original while op with new operands and updated result types.
-    auto cloned = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
-                                               operands, op.getAttrs());
-    cloned.removeAttr("T");
-    UpdateFunctionTypes(cloned);
-
-    rewriter.replaceOp(op, cloned.getResults());
+    // Change all DT_VARIANT result types to unranked tensor type.
+    for (auto it : llvm::zip(op.getResultTypes(), operands))
+      result_types.push_back(
+          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+
+    // Create a new while op with new operands and updated result types.
+    auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
+                                                  operands, op.getAttrs());
+    converted.removeAttr("T");
+    UpdateFunctionTypes(converted);
+
+    rewriter.replaceOp(op, converted.getResults());
+    return success();
+  }
+};
+
+struct ConvertWhileRegion : public OpConversionPattern<TF::WhileRegionOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      TF::WhileRegionOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Type, 8> result_types;
+    result_types.reserve(op.getNumOperands());
+    // Change all DT_VARIANT result types to unranked tensor type.
+    for (auto it : llvm::zip(op.getResultTypes(), operands))
+      result_types.push_back(
+          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+
+    // Create a new while op with new operands and updated result types.
+    auto converted = rewriter.create<TF::WhileRegionOp>(
+        op.getLoc(), result_types, operands, op.getAttrs());
+
+    // Inline the regions from the old while into the new one, and apply
+    // signature conversion to inlined region.
+    for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) {
+      Region &old_region = *std::get<0>(it);
+      Region &new_region = *std::get<1>(it);
+
+      Block &entry = old_region.front();
+      // Build signature conversion for the region.
+      TypeConverter::SignatureConversion signature_conversion(operands.size());
+      for (auto it : llvm::zip(entry.getArguments(), operands)) {
+        BlockArgument arg = std::get<0>(it);
+        signature_conversion.addInputs(
+            arg.getArgNumber(),
+            VariantToUnrankedTensorType(arg.getType(), std::get<1>(it)));
+      }
+
+      rewriter.inlineRegionBefore(old_region, new_region, new_region.end());
+      rewriter.applySignatureConversion(&new_region, signature_conversion);
+    }
+
+    rewriter.replaceOp(op, converted.getResults());
     return success();
   }
 };
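Unlike ConvertWhile, there are no cond/body functions to retype here: the regions of the old op are inlined into the newly created tf.WhileRegion, and each entry block is rewritten through a SignatureConversion. For the tensorlistWhileRegion test above, the conversion built for both regions works out to (illustrative, derived from the test's types):

  // block argument 0 : tensor<i32>         -> tensor<i32>    (unchanged)
  // block argument 1 : tensor<!tf.variant> -> tensor<*xf32>  (element type taken from the
  //                                                           converted operand tensor<2x3xf32>)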
@@ -871,7 +901,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction(
                   ConvertTensorListGetItem, ConvertTensorListLength,
                   ConvertTensorListPushBack, ConvertTensorListReserve,
                   ConvertTensorListSetItem, ConvertTensorListStack,
-                  ConvertTensorListResize, ConvertWhile>(context);
+                  ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>(
+      context);
   return applyPartialConversion(func, target, patterns);
 }
 