Do not add unranked type to the while's non tensor list variant arguments while

lowering tensor lists

If not, non tensor list variant arguments will not match with the result of the
variant operators in the body or cond function.

PiperOrigin-RevId: 352373987
Change-Id: I9a6154f098207241b3f1700c3a1025b6a28a53ed
This commit is contained in:
Jaesung Chung 2021-01-18 02:47:25 -08:00 committed by TensorFlower Gardener
parent 3f837ab7ab
commit a2766ce758
2 changed files with 113 additions and 11 deletions

View File

@ -316,6 +316,50 @@ func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
return %2 : tensor<*xf32>
}
func @otherVariantWhileLoop(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.EmptyTensorMap"() {device = ""} : () -> tensor<!tf.variant>
%3:4 = "tf.While"(%0, %1, %0, %2) {_lower_using_switch_merge = true, _num_original_outputs = 4 : i64, _read_only_resource_inputs = [], body = @otherVariantWhileBody, cond = @otherVariantWhileCond, device = "", is_stateless = true, parallel_iterations = 10 : i64, shape_invariant} : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>)
%4 = "tf.Identity"(%3#3) {device = ""} : (tensor<!tf.variant>) -> tensor<!tf.variant>
%5 = "tf.TensorMapSize"(%4) {device = ""} : (tensor<!tf.variant>) -> tensor<i32>
%6 = "tf.AddV2"(%arg0, %5) {device = ""} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %6 : tensor<1xi32>
}
// Make sure the non TensorList variant types in input/output have remained.
// CHECK-LABEL: otherVariantWhileLoop
// CHECK: "tf.While"
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>)
func @otherVariantWhileBody(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<!tf.variant>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>) {
%0 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.AddV2"(%arg2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%2 = "tf.TensorMapInsert"(%arg3, %arg2, %arg2) {device = "", key_dtype = i32, value_dtype = i32} : (tensor<!tf.variant>, tensor<i32>, tensor<i32>) -> tensor<!tf.variant>
%3 = "tf.AddV2"(%arg0, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i32>
return %3, %arg1, %1, %2 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<!tf.variant>
}
// Verify `body` function's signature.
// CHECK-LABEL: func @otherVariantWhileBody
// CHECK: [[CST:%.*]] = "tf.Const"()
// CHECK-NEXT: [[ADD:%.*]] = "tf.AddV2"(%arg2, [[CST]])
// CHECK-NEXT: [[TENSOR_MAP_INSERT_RESULT:%.*]] = "tf.TensorMapInsert"(%arg3, %arg2, %arg2)
// CHECK-NEXT: [[ADD_2:%.*]] = "tf.AddV2"(%arg0, [[CST]])
// CHECK-NEXT: return [[ADD_2]], %arg1, [[ADD]], [[TENSOR_MAP_INSERT_RESULT]]
func @otherVariantWhileCond(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<!tf.variant>) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Less"(%arg2, %0) {device = ""} : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %1 : tensor<i1>
}
// Verify `cond` function's signature.
// CHECK-LABEL: func @otherVariantWhileCond
// CHECK: [[CST:%.*]] = "tf.Const"()
// CHECK-NEXT: [[LESS:%.*]] = "tf.Less"(%arg2, [[CST]])
// CHECK-NEXT: return [[LESS]]
func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
@ -753,10 +754,42 @@ Type VariantToUnrankedTensorType(Type type, Value value) {
return type;
}
llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {
llvm::SmallSet<int, 4> set;
for (FuncOp func : {op.cond_function(), op.body_function()}) {
if (!func) continue;
for (auto arg_and_idx : llvm::enumerate(func.getArguments())) {
mlir::BlockArgument arg = arg_and_idx.value();
auto variant_ty =
getElementTypeOrSelf(arg.getType()).dyn_cast<TF::VariantType>();
if (!variant_ty) continue;
for (auto &op_operand : arg.getUses()) {
auto op = op_operand.getOwner();
if (llvm::isa<TF::TensorListGetItemOp>(op) ||
llvm::isa<TF::TensorListLengthOp>(op) ||
llvm::isa<TF::TensorListPushBackOp>(op) ||
llvm::isa<TF::TensorListReserveOp>(op) ||
llvm::isa<TF::TensorListSetItemOp>(op) ||
llvm::isa<TF::TensorListStackOp>(op) ||
llvm::isa<TF::TensorListResizeOp>(op)) {
set.insert(arg_and_idx.index());
break;
}
}
}
}
return set;
}
// Changes the function type of `cond_func` and `body_func` for the given While
// op.
LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
LogicalResult UpdateFunctionTypes(TF::WhileOp op,
llvm::SmallSet<int, 4> tensor_list_args) {
int func_index = 0;
for (FuncOp func : {op.cond_function(), op.body_function()}) {
++func_index;
if (!func) continue;
FunctionType func_type = func.getType();
@ -767,18 +800,33 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op) {
// tensor type if it's a variant type.
SmallVector<Type, 8> updated_argument_types;
updated_argument_types.reserve(num_inputs);
for (auto it : llvm::zip(func_type.getInputs(), op.getOperands()))
updated_argument_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
int i = 0;
for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) {
if (tensor_list_args.count(i)) {
updated_argument_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
} else {
updated_argument_types.push_back(std::get<0>(it));
}
++i;
}
// Change all DT_VARIANT result types in function results to unranked tensor
// type with element type derived from the corresponding input operand. This
// is correct because while body's inputs and results have the same type.
SmallVector<Type, 8> updated_result_types;
updated_result_types.reserve(num_results);
for (auto it : llvm::zip(func_type.getResults(), op.getOperands()))
updated_result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
i = 0;
for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) {
// Only update body's results.
if (func_index != 1 && tensor_list_args.count(i)) {
updated_result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
} else {
updated_result_types.push_back(std::get<0>(it));
}
++i;
}
// Change `func`'s argument type to `unranked_argument_types`. If it
// return types contain a `DT_VARIANT`, change it to the unranked type
@ -800,18 +848,28 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
LogicalResult matchAndRewrite(
TF::WhileOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Find all Tensor List arugments.
auto tensor_list_args = GetTensorListArgumentsFromWhileOp(op);
llvm::SmallVector<Type, 8> result_types;
result_types.reserve(op.getNumOperands());
// Change all DT_VARIANT result types to unranked tensor type.
for (auto it : llvm::zip(op.getResultTypes(), operands))
result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
int i = 0;
for (auto it : llvm::zip(op.getResultTypes(), operands)) {
if (tensor_list_args.count(i)) {
result_types.push_back(
VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it)));
} else {
result_types.push_back(std::get<0>(it));
}
++i;
}
// Create a new while op with new operands and updated result types.
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
converted.removeAttr("T");
UpdateFunctionTypes(converted);
UpdateFunctionTypes(converted, tensor_list_args);
rewriter.replaceOp(op, converted.getResults());
return success();