Do not add unranked types to the While op's non-TensorList variant arguments when lowering tensor lists.

Otherwise, the non-TensorList variant arguments will not match the results of the variant operators in the body or cond function.

PiperOrigin-RevId: 352373987
Change-Id: I9a6154f098207241b3f1700c3a1025b6a28a53ed
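As a minimal illustration of the mismatch this change avoids (the function below is hypothetical and not part of the patch): if a non-TensorList variant argument, such as a TensorMap handle, were rewritten to an unranked tensor type, the body function's declared types would no longer agree with the tensor<!tf.variant> values produced by the variant ops that use it.

// Hypothetical while-body carrying a TensorMap (a non-TensorList variant).
// Rewriting %arg1's type to an unranked tensor type would make the declared
// result type disagree with the tensor<!tf.variant> value returned by
// tf.TensorMapInsert, so such arguments must keep their variant type.
func @illustrative_body(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>) {
  %0 = "tf.TensorMapInsert"(%arg1, %arg0, %arg0) {key_dtype = i32, value_dtype = i32} : (tensor<!tf.variant>, tensor<i32>, tensor<i32>) -> tensor<!tf.variant>
  return %arg0, %0 : tensor<i32>, tensor<!tf.variant>
}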
parent 3f837ab7ab
commit a2766ce758
@@ -316,6 +316,50 @@ func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
   return %2 : tensor<*xf32>
 }
 
+func @otherVariantWhileLoop(%arg0: tensor<1xi32>) -> tensor<1xi32> {
+  %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  %2 = "tf.EmptyTensorMap"() {device = ""} : () -> tensor<!tf.variant>
+  %3:4 = "tf.While"(%0, %1, %0, %2) {_lower_using_switch_merge = true, _num_original_outputs = 4 : i64, _read_only_resource_inputs = [], body = @otherVariantWhileBody, cond = @otherVariantWhileCond, device = "", is_stateless = true, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>)
+  %4 = "tf.Identity"(%3#3) {device = ""} : (tensor<!tf.variant>) -> tensor<!tf.variant>
+  %5 = "tf.TensorMapSize"(%4) {device = ""} : (tensor<!tf.variant>) -> tensor<i32>
+  %6 = "tf.AddV2"(%arg0, %5) {device = ""} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
+  return %6 : tensor<1xi32>
+}
+
+// Make sure the non TensorList variant types in input/output have remained.
+// CHECK-LABEL: otherVariantWhileLoop
+// CHECK: "tf.While"
+// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>)
+
+func @otherVariantWhileBody(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) {
+  %0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.AddV2"(%arg2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %2 = "tf.TensorMapInsert"(%arg3, %arg2, %arg2) {device = "", key_dtype = i32, value_dtype = i32} : (tensor<!tf.variant>, tensor<i32>, tensor<i32>) -> tensor<!tf.variant>
+  %3 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %3, %arg1, %1, %2 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>
+}
+
+// Verify `body` function's signature.
+// CHECK-LABEL: func @otherVariantWhileBody
+// CHECK: [[CST:%.*]] = "tf.Const"()
+// CHECK-NEXT: [[ADD:%.*]] = "tf.AddV2"(%arg2, [[CST]])
+// CHECK-NEXT: [[TENSOR_MAP_INSERT_RESULT:%.*]] = "tf.TensorMapInsert"(%arg3, %arg2, %arg2)
+// CHECK-NEXT: [[ADD_2:%.*]] = "tf.AddV2"(%arg0, [[CST]])
+// CHECK-NEXT: return [[ADD_2]], %arg1, [[ADD]], [[TENSOR_MAP_INSERT_RESULT]]
+
+func @otherVariantWhileCond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<!tf.variant>) -> tensor<i1> {
+  %0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
+  %1 = "tf.Less"(%arg2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  return %1 : tensor<i1>
+}
+
+// Verify `cond` function's signature.
+// CHECK-LABEL: func @otherVariantWhileCond
+// CHECK: [[CST:%.*]] = "tf.Const"()
+// CHECK-NEXT: [[LESS:%.*]] = "tf.Less"(%arg2, [[CST]])
+// CHECK-NEXT: return [[LESS]]
+
 func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
   %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
   %1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>
@@ -27,6 +27,7 @@ limitations under the License.
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Casting.h"
@@ -753,10 +754,42 @@ Type VariantToUnrankedTensorType(Type type, Value value) {
   return type;
 }
 
+llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {
+  llvm::SmallSet<int, 4> set;
+  for (FuncOp func : {op.cond_function(), op.body_function()}) {
+    if (!func) continue;
+
+    for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
+      mlir::BlockArgument arg = arg_and_idx.value();
+      auto variant_ty =
+          getElementTypeOrSelf(arg.getType()).dyn_cast<TF::VariantType>();
+      if (!variant_ty) continue;
+
+      for (auto &op_operand : arg.getUses()) {
+        auto op = op_operand.getOwner();
+        if (llvm::isa<TF::TensorListGetItemOp>(op) ||
+            llvm::isa<TF::TensorListLengthOp>(op) ||
+            llvm::isa<TF::TensorListPushBackOp>(op) ||
+            llvm::isa<TF::TensorListReserveOp>(op) ||
+            llvm::isa<TF::TensorListSetItemOp>(op) ||
+            llvm::isa<TF::TensorListStackOp>(op) ||
+            llvm::isa<TF::TensorListResizeOp>(op)) {
+          set.insert(arg_and_idx.index());
+          break;
+        }
+      }
+    }
+  }
+  return set;
+}
+
 // Changes the function type of `cond_func` and `body_func` for the given While
 // op.
-LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
+LogicalResult UpdateFunctionTypes(TF::WhileOp op,
+                                  llvm::SmallSet<int, 4> tensor_list_args) {
+  int func_index = 0;
   for (FuncOp func : {op.cond_function(), op.body_function()}) {
+    ++func_index;
     if (!func) continue;
 
     FunctionType func_type = func.getType();
@@ -767,18 +800,33 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
     // tensor type if it's a variant type.
     SmallVector<Type, 8> updated_argument_types;
     updated_argument_types.reserve(num_inputs);
-    for (auto it : llvm::zip(func_type.getInputs(), op.getOperands()))
-      updated_argument_types.push_back(
-          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+    int i = 0;
+    for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) {
+      if (tensor_list_args.count(i)) {
+        updated_argument_types.push_back(
+            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+      } else {
+        updated_argument_types.push_back(std::get<0>(it));
+      }
+      ++i;
+    }
 
     // Change all DT_VARIANT result types in function results to unranked tensor
     // type with element type derived from the corresponding input operand. This
     // is correct because while body's inputs and results have the same type.
     SmallVector<Type, 8> updated_result_types;
     updated_result_types.reserve(num_results);
-    for (auto it : llvm::zip(func_type.getResults(), op.getOperands()))
-      updated_result_types.push_back(
-          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+    i = 0;
+    for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) {
+      // Only update body's results.
+      if (func_index != 1 && tensor_list_args.count(i)) {
+        updated_result_types.push_back(
+            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+      } else {
+        updated_result_types.push_back(std::get<0>(it));
+      }
+      ++i;
+    }
 
     // Change `func`'s argument type to `unranked_argument_types`. If it
     // return types contain a `DT_VARIANT`, change it to the unranked type
@@ -800,18 +848,28 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
   LogicalResult matchAndRewrite(
       TF::WhileOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const override {
+    // Find all TensorList arguments.
+    auto tensor_list_args = GetTensorListArgumentsFromWhileOp(op);
+
     llvm::SmallVector<Type, 8> result_types;
     result_types.reserve(op.getNumOperands());
     // Change all DT_VARIANT result types to unranked tensor type.
-    for (auto it : llvm::zip(op.getResultTypes(), operands))
-      result_types.push_back(
-          VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+    int i = 0;
+    for (auto it : llvm::zip(op.getResultTypes(), operands)) {
+      if (tensor_list_args.count(i)) {
+        result_types.push_back(
+            VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
+      } else {
+        result_types.push_back(std::get<0>(it));
+      }
+      ++i;
+    }
 
     // Create a new while op with new operands and updated result types.
     auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
                                                   operands, op.getAttrs());
     converted.removeAttr("T");
-    UpdateFunctionTypes(converted);
+    UpdateFunctionTypes(converted, tensor_list_args);
 
     rewriter.replaceOp(op, converted.getResults());
     return success();