diff --git a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir index 2ab252f75ad..eebfe6d4ad6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/merge_control_flow.mlir @@ -173,3 +173,185 @@ func @same_predicate_stateless_merge() { }) {cluster_attr = "cluster_attr"} : () -> () return } + +// Check that IfRegions with same predicates and returns are merged. + +// CHECK-LABEL: func @same_predicate_returns_merged +func @same_predicate_returns_merged() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]#0, %[[IF_OUTPUT]]#1) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%1, %2) : (tensor, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} +// Check that IfRegions with same predicates and unused returns. + +// CHECK-LABEL: func @same_predicate_returns_unused +func @same_predicate_returns_unused() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%2) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// CHECK-LABEL: func @same_predicate_dependency +func @same_predicate_dependency() { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]] = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.E"(%[[IF_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %2 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"(%1) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"(%1) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.E"(%2) : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// Checks that results from first IfRegion are moved after merged IfRegion op as needed. + +// CHECK-LABEL: func @same_predicate_results_moved +func @same_predicate_results_moved(%arg0: tensor>>) { + // CHECK: tf_device.cluster + // CHECK: %[[IF_OUTPUT:[0-9]*]]:2 = "tf.IfRegion" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK-NEXT: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK-NEXT: "tf.Yield"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" + // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]]) + // CHECK-NOT: "tf.IfRegion" + // CHECK "tf.AssignVariableOp(arg0, %[[IF_OUTPUT#0]]) + // CHECK "tf.E"(%[[IF_OUTPUT#1]]) + // CHECK-NEXT "tf.F"(%[[IF_OUTPUT#1]]) + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + %1 = "tf.IfRegion"(%0) ( { + %3 = "tf.A"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.C"() : () -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + "tf.AssignVariableOp"(%arg0, %1) : (tensor>>, tensor) -> () + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> (tensor) + %5 = "tf.IfRegion"(%0) ( { + %3 = "tf.B"(%4) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }, { + %3 = "tf.D"(%4) : (tensor) -> (tensor) + "tf.Yield"(%3) : (tensor) -> () + }) { is_stateless = true } : (tensor) -> (tensor) + %6 = "tf.E"(%5) : (tensor) -> (tensor) + "tf.F"(%1, %6) : (tensor, tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} + +// Check that 3 IfRegions with same predicates and no intermediate dependencies are merged. + +// CHECK-LABEL: func @same_predicate_3_ifregions +func @same_predicate_3_ifregions() { + // CHECK: tf_device.cluster + // CHECK: "tf.IfRegion" + // CHECK-NOT: "tf.IfRegion" + "tf_device.cluster"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.IfRegion"(%0) ( { + %2 = "tf.A"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + "tf.IfRegion"(%0) ( { + %2 = "tf.B"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + "tf.IfRegion"(%0) ( { + %2 = "tf.C"() : () -> (tensor) + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = true } : (tensor) -> () + tf_device.return + }) {cluster_attr = "cluster_attr"} : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc index 2ed9e3a86a6..92c7df9a338 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/merge_control_flow.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -129,6 +130,21 @@ bool SafeToMerge(TF::IfRegionOp source, TF::IfRegionOp destination, return safe_to_merge; } +// Checks whether a return indice should be kep for `first_if_op` by checking +// for results in `second_if_op`. +llvm::SmallVector GetReturnIndicesToKeep(TF::IfRegionOp first_if_op, + TF::IfRegionOp second_if_op) { + llvm::SmallVector return_indices_to_keep; + for (auto& index_and_value : llvm::enumerate(first_if_op.getResults())) { + if (!llvm::all_of(index_and_value.value().getUsers(), [&](Operation* op) { + return second_if_op->isProperAncestor(op); + })) { + return_indices_to_keep.push_back(index_and_value.index()); + } + } + return return_indices_to_keep; +} + // Move the body excluding the terminators of else and then regions from // 'source' to 'destination'. void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) { @@ -145,31 +161,87 @@ void MoveBranches(TF::IfRegionOp source, TF::IfRegionOp destination) { source_else_body.begin(), std::prev(source_else_body.end())); } -Operation* GetIfInsertionPoint(TF::IfRegionOp source, - TF::IfRegionOp destination) { - // TODO(b/173422484): Pick this insertion point better. - return source.getOperation(); +// Move all ops that depends on the results from `result_op` after `after_op`. +void MoveResultsAfter(Operation* result_op, Operation* after_op) { + std::queue queue; + for (Operation* user : result_op->getUsers()) { + queue.push(user); + } + while (!queue.empty()) { + auto* op = queue.front(); + queue.pop(); + for (Operation* user : op->getUsers()) queue.push(user); + if (op->isBeforeInBlock(after_op)) op->moveAfter(after_op); + after_op = op; + } } -TF::IfRegionOp CreateMergedIf(TF::IfRegionOp source, +TF::IfRegionOp CreateMergedIf(ArrayRef source_return_indices_to_keep, + ArrayRef destination_return_indices_to_keep, + TF::IfRegionOp source, TF::IfRegionOp destination) { llvm::SmallVector merged_return_types; + for (int i : destination_return_indices_to_keep) + merged_return_types.push_back(destination.getResult(i).getType()); + for (int i : source_return_indices_to_keep) + merged_return_types.push_back(source.getResult(i).getType()); OpBuilder builder(destination); // Create new IfRegion with correct merged results. - builder.setInsertionPoint(GetIfInsertionPoint(source, destination)); + builder.setInsertionPoint(source.getOperation()); + auto new_if_op = builder.create( destination.getLoc(), merged_return_types, destination.cond(), destination.is_stateless() && source.is_stateless()); new_if_op.then_branch().push_back(new Block); new_if_op.else_branch().push_back(new Block); + // Replace internal usages of merged if ops. + for (OpResult result : destination.getResults()) { + replaceAllUsesInRegionWith( + result, + destination.then_branch().front().getTerminator()->getOperand( + result.getResultNumber()), + source.then_branch()); + replaceAllUsesInRegionWith( + result, + destination.else_branch().front().getTerminator()->getOperand( + result.getResultNumber()), + source.else_branch()); + } + + MoveResultsAfter(destination.getOperation(), new_if_op.getOperation()); + + // Replace external usages of merged if ops. + int new_return_index = 0; + for (int i : destination_return_indices_to_keep) { + destination.getResult(i).replaceAllUsesWith( + new_if_op.getResult(new_return_index++)); + } + for (int i : source_return_indices_to_keep) { + source.getResult(i).replaceAllUsesWith( + new_if_op.getResult(new_return_index++)); + } + + // Create the Yield ops for both branches with merged results. llvm::SmallVector merged_then_yield_values; + for (int i : destination_return_indices_to_keep) + merged_then_yield_values.push_back( + destination.then_branch().front().getTerminator()->getOperand(i)); + for (int i : source_return_indices_to_keep) + merged_then_yield_values.push_back( + source.then_branch().front().getTerminator()->getOperand(i)); builder.setInsertionPointToEnd(&new_if_op.then_branch().front()); builder.create( destination.then_branch().front().getTerminator()->getLoc(), /*operands=*/merged_then_yield_values); llvm::SmallVector merged_else_yield_values; + for (int i : destination_return_indices_to_keep) + merged_else_yield_values.push_back( + destination.else_branch().front().getTerminator()->getOperand(i)); + for (int i : source_return_indices_to_keep) + merged_else_yield_values.push_back( + source.else_branch().front().getTerminator()->getOperand(i)); builder.setInsertionPointToEnd(&new_if_op.else_branch().front()); builder.create( destination.else_branch().front().getTerminator()->getLoc(), @@ -201,7 +273,16 @@ void OptimizeIfRegions( TF::IfRegionOp if_op = if_ops[i]; if (!SafeToMerge(if_op, first_if_op, side_effect_analysis)) break; - auto new_if_op = CreateMergedIf(if_op, first_if_op); + // For both check if there are uses outside of IfRegion, keep these as + // part of the return and replace the internal uses. + auto first_return_indices_to_keep = + GetReturnIndicesToKeep(first_if_op, if_op); + auto second_return_indices_to_keep = + GetReturnIndicesToKeep(if_op, first_if_op); + + auto new_if_op = + CreateMergedIf(second_return_indices_to_keep, + first_return_indices_to_keep, if_op, first_if_op); first_if_op = new_if_op; }