Add support for tf.Case op in Resource Op Lifting pass.

PiperOrigin-RevId: 314196123
Change-Id: I8eef7aad5b6da7cd69b821d92f940a8f08fef08c
Prakalp Srivastava 2020-06-01 13:49:26 -07:00 committed by TensorFlower Gardener
parent e1a3f9a9e7
commit bfc2553173
2 changed files with 127 additions and 50 deletions
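For context, the Resource Op Lifting pass hoists resource variable reads and writes out of a tf_device.cluster so that the cluster body operates only on tensor values. The following is a minimal sketch of the effect on a tf.Case, assuming each branch simply reads the variable; the names %v, %i, @br0, and @br1 are hypothetical and not taken from this change.

Before lifting, the resource handle is passed into the cluster and into the tf.Case:

  // %i is the branch index, %v is a resource variable handle.
  %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %r = "tf_device.cluster"() ( {
    %x = "tf.Case"(%i, %v) {branches = [@br0, @br1]}
      : (tensor<i32>, tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
    tf_device.return %x : tensor<f32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<f32>

After lifting, the read is hoisted out of the cluster and the branch functions take the data value directly:

  %v = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<f32>>>
  %read = "tf.ReadVariableOp"(%v) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
  %r = "tf_device.cluster"() ( {
    %x = "tf.Case"(%i, %read) {branches = [@br0, @br1]}
      : (tensor<i32>, tensor<f32>) -> tensor<f32>
    tf_device.return %x : tensor<f32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<f32>

The test below exercises the more involved situation where branches also write to variables: the updated value is returned from the cluster and assigned back to the variable outside it.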


@@ -406,6 +406,61 @@ func @while_cond(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
// -----
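// Tests that the pass lifts resource reads from case branches.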
// CHECK: func @cluster_with_case(%[[ARG0:.*]]: tensor<i32>) -> tensor<4xf32>
func @cluster_with_case(%arg0: tensor<i32>) -> tensor<4xf32> {
// CHECK: %[[VH0:.*]] = "tf.VarHandleOp"()
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK: %[[VH1:.*]] = "tf.VarHandleOp"()
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
// CHECK-DAG: %[[READ0:.*]] = "tf.ReadVariableOp"(%[[VH0]])
// CHECK-DAG: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[VH1]])
// CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"()
%2 = "tf_device.cluster"() ( {
// CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]])
%3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]}
: (tensor<i32>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>)
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0)
%4 = "tf.ReadVariableOp"(%3#0) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
%5 = "tf.AddV2"(%4, %3#1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: tf_device.return %[[ADD]], %[[CASE]]#1
tf_device.return %5 : tensor<4xf32>
// CHECK: {cluster_attr = "cluster_attr"} : () -> (tensor<4xf32>, tensor<4xf32>)
}) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32>
// CHECK: "tf.AssignVariableOp"(%[[VH0]], %[[CLUSTER]]#1)
// CHECK: return %[[CLUSTER]]#0
return %2 : tensor<4xf32>
}
// CHECK: func @branch_0(%[[TARG0:.*]]: tensor<4xf32>, %[[TARG1:.*]]: tensor<4xf32>)
func @branch_0(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
// CHECK-NEXT: %[[CONST:.*]] = "tf.Const"()
%constant = "tf.Const"() {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32>
"tf.AssignVariableOp"(%arg0, %constant) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
// CHECK-NEXT: return %[[CONST]], %[[CONST]]
return %arg0, %constant : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
}
// CHECK: func @branch_1(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
func @branch_1(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
%id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
%read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
"tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
// CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
}
// CHECK: func @branch_2(%[[EARG0:.*]]: tensor<4xf32>, %[[EARG1:.*]]: tensor<4xf32>)
func @branch_2(%arg0: tensor<*x!tf.resource<tensor<4xf32>>>, %arg1: tensor<*x!tf.resource<tensor<4xf32>>>)
-> (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) {
%id = "tf.Identity"(%arg1) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<*x!tf.resource<tensor<4xf32>>>
%read = "tf.ReadVariableOp"(%id) : (tensor<*x!tf.resource<tensor<4xf32>>>) -> tensor<4xf32>
"tf.AssignVariableOp"(%arg0, %read) : (tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>) -> ()
// CHECK-NEXT: return %[[EARG1]], %[[EARG1]]
return %arg0, %read : tensor<*x!tf.resource<tensor<4xf32>>>, tensor<4xf32>
}
// -----
// Tests that the pass lifts resource reads from if branches.
// CHECK: func @cluster_with_if(%[[ARG0:.*]]: tensor<i1>) -> tensor<4xf32>
@@ -524,7 +579,7 @@ func @cluster_with_if(%arg0: tensor<i1>) -> tensor<4xf32> {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource<tensor<4xf32>>>
%2 = "tf_device.cluster"() ( {
// expected-error @+1 {{unsupported tf.IfOp output: resource does not alias a single input.}}
// expected-error @+1 {{unsupported output: resource does not alias a single input}}
%3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else,
is_stateless = false}
: (tensor<i1>, tensor<*x!tf.resource<tensor<4xf32>>>, tensor<*x!tf.resource<tensor<4xf32>>>)


@@ -668,67 +668,74 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) {
return success();
}
// Lifts loads/stores from an IfOp's branches.
LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
FuncOp else_branch) {
// Lifts loads/stores from an IfOp or CaseOp's branches.
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef<FuncOp> branches) {
// Remove identity nodes to avoid aliasing.
RemoveIdentity(&then_branch.front());
RemoveIdentity(&else_branch.front());
for (auto func : branches) RemoveIdentity(&func.front());
// Sanity check: branch return of resources should be aliases of inputs. If
// so, replace the output uses with the input so that we can remove these
// outputs.
for (auto entry : llvm::enumerate(
llvm::zip(then_branch.front().getTerminator()->getOperands(),
else_branch.front().getTerminator()->getOperands()))) {
auto then_retval = std::get<0>(entry.value());
auto else_retval = std::get<1>(entry.value());
assert(then_retval.getType() == else_retval.getType());
if (!getElementTypeOrSelf(then_retval.getType()).isa<TF::ResourceType>()) {
for (OpResult result : op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>())
continue;
unsigned result_index = result.getResultNumber();
constexpr unsigned kUnassigned = -1;
unsigned common_aliasing_arg_num = kUnassigned;
for (auto func : branches) {
auto retval = func.front().getTerminator()->getOperand(result_index);
assert(result.getType() == retval.getType());
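// Every branch must forward one of the op's resource operands for this
// result, and all branches must forward the same one; otherwise the output
// cannot be replaced and removed.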
auto aliasing_arg = retval.dyn_cast<BlockArgument>();
if (!aliasing_arg ||
(common_aliasing_arg_num != kUnassigned &&
aliasing_arg.getArgNumber() != common_aliasing_arg_num))
return op.emitOpError("unsupported output: ")
<< "resource does not alias a single input";
common_aliasing_arg_num = aliasing_arg.getArgNumber();
}
auto then_aliasing_arg = then_retval.dyn_cast<BlockArgument>();
auto else_aliasing_arg = else_retval.dyn_cast<BlockArgument>();
if (!then_aliasing_arg || !else_aliasing_arg ||
then_aliasing_arg.getArgNumber() != else_aliasing_arg.getArgNumber()) {
return if_op.emitOpError("unsupported tf.IfOp output: ")
<< "resource does not alias a single input.";
}
if_op.getResult(entry.index())
.replaceAllUsesWith(
if_op.getOperand(then_aliasing_arg.getArgNumber() + 1));
assert(common_aliasing_arg_num != kUnassigned);
result.replaceAllUsesWith(op.getOperand(common_aliasing_arg_num + 1));
}
// Erase the resource outputs from the branches.
int64_t non_resource_results = 0;
llvm::SmallVector<int64_t, 4> old_to_new_output_indices;
bool output_removed = false;
for (auto result : if_op.getResults()) {
if (!getElementTypeOrSelf(result.getType()).isa<TF::ResourceType>()) {
for (auto result : op.getResults()) {
if (!getElementTypeOrSelf(result.getType())
.template isa<TF::ResourceType>()) {
old_to_new_output_indices.push_back(non_resource_results++);
continue;
}
old_to_new_output_indices.push_back(-1);
then_branch.front().getTerminator()->eraseOperand(non_resource_results);
else_branch.front().getTerminator()->eraseOperand(non_resource_results);
for (auto func : branches)
func.front().getTerminator()->eraseOperand(non_resource_results);
output_removed = true;
}
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> then_use_info;
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> else_use_info;
if (failed(FindResourceArgUseInfo(then_branch, &then_use_info)) ||
failed(FindResourceArgUseInfo(else_branch, &else_use_info))) {
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> resource_arg_uses;
if (failed(FindResourceArgUseInfo(branches.front(), &resource_arg_uses)))
return failure();
for (auto func : branches.drop_front()) {
llvm::SmallDenseMap<int64_t, ResourceArgUseInfo> branch_use_info;
if (failed(FindResourceArgUseInfo(func, &branch_use_info)))
return failure();
// A resource is considered used as long as it is used in either branch.
resource_arg_uses =
MergeArgResourceUseInfo(resource_arg_uses, branch_use_info);
}
// A resource is considered used as long as it is used in either branch.
auto resource_arg_uses =
MergeArgResourceUseInfo(then_use_info, else_use_info);
if (resource_arg_uses.empty() && !output_removed) return success();
// Remove unused resources in functions.
llvm::SmallDenseMap<int64_t, Type> remaining_resource_data_types;
RemoveUnusedResourceArgumentsAndForwardedRetvals(
resource_arg_uses, then_branch, /*old_to_new_arg_indices=*/nullptr,
resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr,
&remaining_resource_data_types);
RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses,
else_branch);
for (auto func : branches.drop_front())
RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func);
// Forward resource inputs updated in any branch to the outputs of all
// branches. First prepare the mapping from arg to new update output.
llvm::SmallDenseMap<int64_t, int64_t> resource_arg_to_new_output;
@@ -746,10 +753,11 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
new_output_index;
}
}
// Append resource updates to the return ops: now they are just forwarded
// input resources, but will be replaced by the data value in
// LiftArgRetResourcesForFunction().
for (auto branch : {then_branch, else_branch}) {
for (auto branch : branches) {
auto new_retvals =
llvm::to_vector<4>(branch.front().getTerminator()->getOperands());
for (const auto& entry : resource_arg_to_new_output) {
@@ -766,16 +774,17 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
});
}
// Recreate the if op.
OpBuilder builder(if_op);
// Recreate the op without resource operands.
OpBuilder builder(op);
// Now use the filtered original operands, which will be replaced by
// AddLoadsStoresOutsideControlFlowOp().
auto new_operands =
FilterRange<Value, OperandRange>(if_op.input(), resource_arg_uses);
new_operands.insert(new_operands.begin(), if_op.cond());
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(),
new_operands, if_op.getAttrs());
FilterRange<Value, OperandRange>(op.input(), resource_arg_uses);
new_operands.insert(new_operands.begin(), op.getOperand(0));
FuncOp first_func = branches.front();
auto new_op =
builder.create<CaseOrIfOp>(op.getLoc(), first_func.getType().getResults(),
new_operands, op.getAttrs());
// Prepare for AddLoadsStoresOutsideControlFlowOp()
llvm::SmallDenseMap<int64_t, std::pair<Type, int64_t>>
arg_data_type_and_updated_output_index;
@@ -787,16 +796,16 @@ LogicalResult HandleIfOP(TF::IfOp if_op, FuncOp then_branch,
arg_data_type_and_updated_output_index[entry.getFirst() + 1] = {
entry.getSecond(), update_index};
}
AddLoadsStoresOutsideControlFlowOp(new_if,
AddLoadsStoresOutsideControlFlowOp(new_op,
arg_data_type_and_updated_output_index);
// Replace uses.
for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) {
if (old_to_new_output_indices[i] >= 0) {
if_op.getResult(i).replaceAllUsesWith(
new_if.getResult(old_to_new_output_indices[i]));
op.getResult(i).replaceAllUsesWith(
new_op.getResult(old_to_new_output_indices[i]));
}
}
if_op.erase();
op.erase();
return success();
}
@@ -985,7 +994,20 @@ LogicalResult HoistForFunctionalControlFlow(
lifted_partitioned_call_callees);
HoistForFunctionalControlFlow(&else_branch.front(), module,
lifted_partitioned_call_callees);
if (failed(HandleIfOP(if_op, then_branch, else_branch))) return failure();
if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch})))
return failure();
} else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
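// Gather the branch functions referenced by the case op's branches attribute.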
SmallVector<FuncOp, 4> branch_functions;
branch_functions.reserve(case_op.branches().size());
for (const Attribute& branch : case_op.branches()) {
FuncOp func =
module.lookupSymbol<FuncOp>(branch.cast<FlatSymbolRefAttr>());
// Recursively handle the nested control flow.
HoistForFunctionalControlFlow(&func.front(), module,
lifted_partitioned_call_callees);
branch_functions.push_back(func);
}
if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure();
} else if (auto call_op = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!call_op.f().isa<FlatSymbolRefAttr>()) {
return call_op.emitOpError(