From 070b02f4419ccd0896bd636e3e89ddfbec3a47fa Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Wed, 9 Dec 2020 08:58:48 -0800 Subject: [PATCH] Handle return values in MergeControlFlow pass. Returned values used only in the second IfRegion are removed from the merged return value. The remaining returns are merged together for the combined IfRegionOp. Any ops using the return from the first IfRegion are moved after the combined IfRegionOp as needed. PiperOrigin-RevId: 346562251 Change-Id: I249f638272e59cd1d56428a59d4135d57b751df0 --- .../tensorflow/tests/merge_control_flow.mlir | 182 ++++++++++++++++++ .../transforms/merge_control_flow.cc | 95 ++++++++- 2 files changed, 270 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir index 2ab252f75ad..eebfe6d4ad6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir @@ -173,3 +173,185 @@ func @same_predicate_stateless_merge() { }) {cluster_attr = "cluster_attr"} : () -> () return } + +// Check that IfRegions with same predicates and returns are merged. 
+ +// CHECK-LABEL: func @same_predicate_returns_merged +func @same_predicate_returns_merged() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]#0, %[[IF_OUTPUT]]#1) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%1, %2) : (tensor, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} +// Check that IfRegions with same predicates and unused returns are merged. 
+ +// CHECK-LABEL: func @same_predicate_returns_unused +func @same_predicate_returns_unused() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%2) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @same_predicate_dependency +func @same_predicate_dependency() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) 
-> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"(%1) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"(%1) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%2) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// Checks that results from first IfRegion are moved after merged IfRegion op as needed. + +// CHECK-LABEL: func @same_predicate_results_moved +func @same_predicate_results_moved(%arg0: tensor>>) { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.AssignVariableOp(arg0, %[[IF_OUTPUT#0]]) + // CHECK "tf.E"(%[[IF_OUTPUT#1]]) + // CHECK-NEXT "tf.F"(%[[IF_OUTPUT#1]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> (tensor) + %5 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"(%4) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"(%4) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %6 = "tf.E"(%5) : (tensor) -> (tensor) + "tf.F"(%1, %6) : (tensor, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// Check that 3 IfRegions 
with same predicates and no intermediate dependencies are merged. + +// CHECK-LABEL: func @same_predicate_3_ifregions +func @same_predicate_3_ifregions() { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%0) ( { + %2 = "tf.A"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + "tf.IfRegion"(%0) ( { + %2 = "tf.B"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + "tf.IfRegion"(%0) ( { + %2 = "tf.C"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc index 2ed9e3a86a6..92c7df9a338 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -129,6 +130,21 @@ bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination, return safe_to_merge; } +// Returns the indices of `first_if_op` results that must be kept, i.e. those +// with at least one user that is not nested inside `second_if_op`. 
+llvm::SmallVector GetReturnIndicesToKeep(TF::IfRegionOp first_if_op, + TF::IfRegionOp second_if_op) { + llvm::SmallVector return_indices_to_keep; + for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) { + if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) { + return second_if_op->isProperAncestor(op); + })) { + return_indices_to_keep.push_back(index_and_value.index()); + } + } + return return_indices_to_keep; +} + // Move the body excluding the terminators of else and then regions from // 'source' to 'destination'. void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) { @@ -145,31 +161,87 @@ void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) { source_else_body.begin(), std::prev(source_else_body.end())); } -Operation* GetIfInsertionPoint(TF::IfRegionOp source, - TF::IfRegionOp destination) { - // TODO(b/173422484): Pick this insertion point better. - return source.getOperation(); +// Moves all ops that depend on the results of `result_op` after `after_op`. 
+void MoveResultsAfter(Operation* result_op, Operation* after_op) { + std::queue queue; + for (Operation* user : result_op->getUsers()) { + queue.push(user); + } + while (!queue.empty()) { + auto* op = queue.front(); + queue.pop(); + for (Operation* user : op->getUsers()) queue.push(user); + if (op->isBeforeInBlock(after_op)) op->moveAfter(after_op); + after_op = op; + } } -TF::IfRegionOp CreateMergedIf(TF::IfRegionOp source, +TF::IfRegionOp CreateMergedIf(ArrayRef source_return_indices_to_keep, + ArrayRef destination_return_indices_to_keep, + TF::IfRegionOp source, TF::IfRegionOp destination) { llvm::SmallVector merged_return_types; + for (int i : destination_return_indices_to_keep) + merged_return_types.push_back(destination.getResult(i).getType()); + for (int i : source_return_indices_to_keep) + merged_return_types.push_back(source.getResult(i).getType()); OpBuilder builder(destination); // Create new IfRegion with correct merged results. - builder.setInsertionPoint(GetIfInsertionPoint(source, destination)); + builder.setInsertionPoint(source.getOperation()); + auto new_if_op = builder.create( destination.getLoc(), merged_return_types, destination.cond(), destination.is_stateless() && source.is_stateless()); new_if_op.then_branch().push_back(new Block); new_if_op.else_branch().push_back(new Block); + // Replace internal usages of merged if ops. + for (OpResult result : destination.getResults()) { + replaceAllUsesInRegionWith( + result, + destination.then_branch().front().getTerminator()->getOperand( + result.getResultNumber()), + source.then_branch()); + replaceAllUsesInRegionWith( + result, + destination.else_branch().front().getTerminator()->getOperand( + result.getResultNumber()), + source.else_branch()); + } + + MoveResultsAfter(destination.getOperation(), new_if_op.getOperation()); + + // Replace external usages of merged if ops. 
+ int new_return_index = 0; + for (int i : destination_return_indices_to_keep) { + destination.getResult(i).replaceAllUsesWith( + new_if_op.getResult(new_return_index++)); + } + for (int i : source_return_indices_to_keep) { + source.getResult(i).replaceAllUsesWith( + new_if_op.getResult(new_return_index++)); + } + + // Create the Yield ops for both branches with merged results. llvm::SmallVector merged_then_yield_values; + for (int i : destination_return_indices_to_keep) + merged_then_yield_values.push_back( + destination.then_branch().front().getTerminator()->getOperand(i)); + for (int i : source_return_indices_to_keep) + merged_then_yield_values.push_back( + source.then_branch().front().getTerminator()->getOperand(i)); builder.setInsertionPointToEnd(&new_if_op.then_branch().front()); builder.create( destination.then_branch().front().getTerminator()->getLoc(), /*operands=*/merged_then_yield_values); llvm::SmallVector merged_else_yield_values; + for (int i : destination_return_indices_to_keep) + merged_else_yield_values.push_back( + destination.else_branch().front().getTerminator()->getOperand(i)); + for (int i : source_return_indices_to_keep) + merged_else_yield_values.push_back( + source.else_branch().front().getTerminator()->getOperand(i)); builder.setInsertionPointToEnd(&new_if_op.else_branch().front()); builder.create( destination.else_branch().front().getTerminator()->getLoc(), @@ -201,7 +273,16 @@ void OptimizeIfRegions( TF::IfRegionOp if_op = if_ops[i]; if (!SafeToMerge(if_op, first_if_op, side_effect_analysis)) break; - auto new_if_op = CreateMergedIf(if_op, first_if_op); + // For both check if there are uses outside of IfRegion, keep these as + // part of the return and replace the internal uses. 
+ auto first_return_indices_to_keep = + GetReturnIndicesToKeep(first_if_op, if_op); + auto second_return_indices_to_keep = + GetReturnIndicesToKeep(if_op, first_if_op); + + auto new_if_op = + CreateMergedIf(second_return_indices_to_keep, + first_return_indices_to_keep, if_op, first_if_op); first_if_op = new_if_op; }