Handle return values in MergeControlFlow pass.

Returned values used only in the second IfRegion are removed from the merged return value.  The remaining returns are merged together for the combined IfRegionOp.  Any ops using the return from the first IfRegion are moved after the combined IfRegionOp as needed.

PiperOrigin-RevId: 346562251
Change-Id: I249f638272e59cd1d56428a59d4135d57b751df0
This commit is contained in:
Ken Franko 2020-12-09 08:58:48 -08:00 committed by TensorFlower Gardener
parent f9ab869fcf
commit 070b02f441
2 changed files with 270 additions and 7 deletions

View File

@ -173,3 +173,185 @@ func @same_predicate_stateless_merge() {
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// Check that IfRegions with same predicates and returns are merged.
// Both original results are consumed by "tf.E" outside the IfRegions, so the
// merged IfRegion must yield two values: the first op's value followed by the
// second op's value.
// CHECK-LABEL: func @same_predicate_returns_merged
func @same_predicate_returns_merged() {
  // CHECK:      tf_device.cluster
  // CHECK:        %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion"
  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK-NEXT:     "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
  // CHECK-NEXT:     "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  // The consumer must be rewired to the two results of the merged op.
  // CHECK:        "tf.E"(%[[IF_OUTPUT]]#0, %[[IF_OUTPUT]]#1)
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %1 = "tf.IfRegion"(%0) ( {
      %3 = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }, {
      %3 = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
    %2 = "tf.IfRegion"(%0) ( {
      %3 = "tf.B"() : () -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }, {
      %3 = "tf.D"() : () -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
    "tf.E"(%1, %2) : (tensor<f32>, tensor<i32>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
// Check that IfRegions with same predicates are merged and that a return
// value with no users is dropped from the merged IfRegion.
// The first IfRegion's result has no users, so only the second IfRegion's
// value is yielded by the merged op.
// CHECK-LABEL: func @same_predicate_returns_unused
func @same_predicate_returns_unused() {
  // CHECK:      tf_device.cluster
  // CHECK:        %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion"
  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK-NEXT:     "tf.Yield"(%[[B_OUTPUT]])
  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
  // CHECK-NEXT:     "tf.Yield"(%[[D_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  // The consumer must read the single result of the merged op.
  // CHECK:        "tf.E"(%[[IF_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %1 = "tf.IfRegion"(%0) ( {
      %3 = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }, {
      %3 = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
    %2 = "tf.IfRegion"(%0) ( {
      %3 = "tf.B"() : () -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }, {
      %3 = "tf.D"() : () -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
    "tf.E"(%2) : (tensor<i32>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
// Check that IfRegions with same predicates are merged when the second
// IfRegion consumes the result of the first one.
// The first result is used only inside the second IfRegion, so it is dropped
// from the merged return values and its in-region uses are replaced by the
// value each branch of the first IfRegion yielded.
// CHECK-LABEL: func @same_predicate_dependency
func @same_predicate_dependency() {
  // CHECK:      tf_device.cluster
  // CHECK:        %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion"
  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]])
  // CHECK-NEXT:     "tf.Yield"(%[[B_OUTPUT]])
  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]])
  // CHECK-NEXT:     "tf.Yield"(%[[D_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  // The consumer must read the single result of the merged op.
  // CHECK:        "tf.E"(%[[IF_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %1 = "tf.IfRegion"(%0) ( {
      %3 = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }, {
      %3 = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
    %2 = "tf.IfRegion"(%0) ( {
      %3 = "tf.B"(%1) : (tensor<f32>) -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }, {
      %3 = "tf.D"(%1) : (tensor<f32>) -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
    "tf.E"(%2) : (tensor<i32>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
// Checks that users of the first IfRegion's results are moved after the merged
// IfRegion op as needed.
// "tf.AssignVariableOp" sits between the two IfRegions and consumes the first
// result; the merged op is created at the second IfRegion's position, so that
// user has to end up after it.
// CHECK-LABEL: func @same_predicate_results_moved
func @same_predicate_results_moved(%arg0: tensor<!tf.resource<tensor<f32>>>) {
  // CHECK:      tf_device.cluster
  // CHECK:        %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion"
  // CHECK:          %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK-NEXT:     %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK-NEXT:     "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK:          %[[C_OUTPUT:[0-9]*]] = "tf.C"
  // CHECK-NEXT:     %[[D_OUTPUT:[0-9]*]] = "tf.D"
  // CHECK-NEXT:     "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  // All users must appear after the merged op. Their relative order is left
  // unconstrained (beyond "tf.E" preceding its consumer "tf.F") because it
  // depends on the order in which the pass visits users.
  // CHECK-DAG:    "tf.AssignVariableOp"(%arg0, %[[IF_OUTPUT]]#0)
  // CHECK-DAG:    %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[IF_OUTPUT]]#1)
  // CHECK-DAG:    "tf.F"(%[[IF_OUTPUT]]#0, %[[E_OUTPUT]])
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    %1 = "tf.IfRegion"(%0) ( {
      %3 = "tf.A"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }, {
      %3 = "tf.C"() : () -> (tensor<f32>)
      "tf.Yield"(%3) : (tensor<f32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<f32>)
    "tf.AssignVariableOp"(%arg0, %1) : (tensor<!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
    %4 = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> (tensor<f32>)
    %5 = "tf.IfRegion"(%0) ( {
      %3 = "tf.B"(%4) : (tensor<f32>) -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }, {
      %3 = "tf.D"(%4) : (tensor<f32>) -> (tensor<i32>)
      "tf.Yield"(%3) : (tensor<i32>) -> ()
    }) { is_stateless = true } : (tensor<i1>) -> (tensor<i32>)
    %6 = "tf.E"(%5) : (tensor<i32>) -> (tensor<f32>)
    "tf.F"(%1, %6) : (tensor<f32>, tensor<f32>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
// Verifies that three consecutive IfRegions sharing one predicate, with no
// dependencies between them, collapse into a single IfRegion.
// CHECK-LABEL: func @same_predicate_3_ifregions
func @same_predicate_3_ifregions() {
  // CHECK:      tf_device.cluster
  // CHECK:        "tf.IfRegion"
  // CHECK-NOT:    "tf.IfRegion"
  "tf_device.cluster"() ( {
    %pred = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
    // First conditional: only the then-branch does any work.
    "tf.IfRegion"(%pred) ( {
      %a = "tf.A"() : () -> tensor<f32>
      "tf.Yield"() : () -> ()
    }, {
      "tf.Yield"() : () -> ()
    }) {is_stateless = true} : (tensor<i1>) -> ()
    // Second conditional on the same predicate.
    "tf.IfRegion"(%pred) ( {
      %b = "tf.B"() : () -> tensor<f32>
      "tf.Yield"() : () -> ()
    }, {
      "tf.Yield"() : () -> ()
    }) {is_stateless = true} : (tensor<i1>) -> ()
    // Third conditional on the same predicate.
    "tf.IfRegion"(%pred) ( {
      %c = "tf.C"() : () -> tensor<f32>
      "tf.Yield"() : () -> ()
    }, {
      "tf.Yield"() : () -> ()
    }) {is_stateless = true} : (tensor<i1>) -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <memory>
#include <queue>
#include <string>
#include <utility>
@ -129,6 +130,21 @@ bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination,
return safe_to_merge;
}
// Returns the indices of the results of `first_if_op` that have at least one
// user which is not nested inside `second_if_op`. Results whose users all live
// inside `second_if_op` (including results with no users at all) are left out.
llvm::SmallVector<int, 4> GetReturnIndicesToKeep(TF::IfRegionOp first_if_op,
                                                 TF::IfRegionOp second_if_op) {
  // A user is "external" when `second_if_op` is not one of its proper
  // ancestors.
  auto is_external_user = [&](Operation* user) {
    return !second_if_op->isProperAncestor(user);
  };

  llvm::SmallVector<int, 4> kept_indices;
  for (OpResult result : first_if_op.getResults()) {
    if (llvm::any_of(result.getUsers(), is_external_user)) {
      kept_indices.push_back(result.getResultNumber());
    }
  }
  return kept_indices;
}
// Move the body excluding the terminators of else and then regions from
// 'source' to 'destination'.
void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
@ -145,31 +161,87 @@ void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) {
source_else_body.begin(), std::prev(source_else_body.end()));
}
Operation* GetIfInsertionPoint(TF::IfRegionOp source,
TF::IfRegionOp destination) {
// TODO(b/173422484): Pick this insertion point better.
return source.getOperation();
// Moves the ops that transitively depend on the results of `result_op` so that
// they are placed after `after_op`.
//
// Users are visited breadth-first starting from the direct users of
// `result_op`. Each visited op that currently precedes the running anchor is
// moved directly after it; the anchor then advances to the visited op whether
// or not it was moved. An op reachable through several use chains is visited
// once per chain.
//
// NOTE(review): the `isBeforeInBlock` query presumably requires every visited
// user to live in the same block as `after_op` -- confirm against callers.
void MoveResultsAfter(Operation* result_op, Operation* after_op) {
  std::queue<Operation*> worklist;
  auto enqueue_users = [&worklist](Operation* producer) {
    for (Operation* consumer : producer->getUsers()) {
      worklist.push(consumer);
    }
  };

  enqueue_users(result_op);
  Operation* anchor = after_op;
  while (!worklist.empty()) {
    Operation* current = worklist.front();
    worklist.pop();
    // Schedule the consumers of `current` before deciding whether to move it.
    enqueue_users(current);
    if (current->isBeforeInBlock(anchor)) {
      current->moveAfter(anchor);
    }
    anchor = current;
  }
}
TF::IfRegionOp CreateMergedIf(TF::IfRegionOp source,
TF::IfRegionOp CreateMergedIf(ArrayRef<int> source_return_indices_to_keep,
ArrayRef<int> destination_return_indices_to_keep,
TF::IfRegionOp source,
TF::IfRegionOp destination) {
llvm::SmallVector<Type, 4> merged_return_types;
for (int i : destination_return_indices_to_keep)
merged_return_types.push_back(destination.getResult(i).getType());
for (int i : source_return_indices_to_keep)
merged_return_types.push_back(source.getResult(i).getType());
OpBuilder builder(destination);
// Create new IfRegion with correct merged results.
builder.setInsertionPoint(GetIfInsertionPoint(source, destination));
builder.setInsertionPoint(source.getOperation());
auto new_if_op = builder.create<TF::IfRegionOp>(
destination.getLoc(), merged_return_types, destination.cond(),
destination.is_stateless() && source.is_stateless());
new_if_op.then_branch().push_back(new Block);
new_if_op.else_branch().push_back(new Block);
// Replace internal usages of merged if ops.
for (OpResult result : destination.getResults()) {
replaceAllUsesInRegionWith(
result,
destination.then_branch().front().getTerminator()->getOperand(
result.getResultNumber()),
source.then_branch());
replaceAllUsesInRegionWith(
result,
destination.else_branch().front().getTerminator()->getOperand(
result.getResultNumber()),
source.else_branch());
}
MoveResultsAfter(destination.getOperation(), new_if_op.getOperation());
// Replace external usages of merged if ops.
int new_return_index = 0;
for (int i : destination_return_indices_to_keep) {
destination.getResult(i).replaceAllUsesWith(
new_if_op.getResult(new_return_index++));
}
for (int i : source_return_indices_to_keep) {
source.getResult(i).replaceAllUsesWith(
new_if_op.getResult(new_return_index++));
}
// Create the Yield ops for both branches with merged results.
llvm::SmallVector<Value, 4> merged_then_yield_values;
for (int i : destination_return_indices_to_keep)
merged_then_yield_values.push_back(
destination.then_branch().front().getTerminator()->getOperand(i));
for (int i : source_return_indices_to_keep)
merged_then_yield_values.push_back(
source.then_branch().front().getTerminator()->getOperand(i));
builder.setInsertionPointToEnd(&new_if_op.then_branch().front());
builder.create<TF::YieldOp>(
destination.then_branch().front().getTerminator()->getLoc(),
/*operands=*/merged_then_yield_values);
llvm::SmallVector<Value, 4> merged_else_yield_values;
for (int i : destination_return_indices_to_keep)
merged_else_yield_values.push_back(
destination.else_branch().front().getTerminator()->getOperand(i));
for (int i : source_return_indices_to_keep)
merged_else_yield_values.push_back(
source.else_branch().front().getTerminator()->getOperand(i));
builder.setInsertionPointToEnd(&new_if_op.else_branch().front());
builder.create<TF::YieldOp>(
destination.else_branch().front().getTerminator()->getLoc(),
@ -201,7 +273,16 @@ void OptimizeIfRegions(
TF::IfRegionOp if_op = if_ops[i];
if (!SafeToMerge(if_op, first_if_op, side_effect_analysis)) break;
auto new_if_op = CreateMergedIf(if_op, first_if_op);
// For both check if there are uses outside of IfRegion, keep these as
// part of the return and replace the internal uses.
auto first_return_indices_to_keep =
GetReturnIndicesToKeep(first_if_op, if_op);
auto second_return_indices_to_keep =
GetReturnIndicesToKeep(if_op, first_if_op);
auto new_if_op =
CreateMergedIf(second_return_indices_to_keep,
first_return_indices_to_keep, if_op, first_if_op);
first_if_op = new_if_op;
}