From 00f027a374c71bd1a1976e8ac8f0d9c333c5e172 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Fri, 29 May 2020 17:12:33 -0700 Subject: [PATCH] Update TPUExtractHeadTailOutsideCompilation pass to support tail outside compiled computation extraction. This extends the pass to lift ops to after the cluster if they are outside compiled and have no dependencies, other than other ops that are to be lifted. PiperOrigin-RevId: 313877620 Change-Id: Ia5d068d74383206dc6ffaed06f429b7b93767271 --- ...extract_head_tail_outside_compilation.mlir | 228 ++++++++++++++++++ ...u_extract_head_tail_outside_compilation.cc | 217 ++++++++++++++--- 2 files changed, 406 insertions(+), 39 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index dd31b7d06ef..d5fb821b5e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -157,4 +157,232 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } return } + + // CHECK-LABEL: func @tail_single_outside_compiled_op + func @tail_single_outside_compiled_op() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_single_outside_compiled_op_user + func @tail_single_outside_compiled_op_user() -> tensor { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"() : () -> () + tf_device.return %b : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @tail_multiple_outside_compiled_ops + func @tail_multiple_outside_compiled_ops(%arg0: tensor) { + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: tf_device.return %[[B_OUT]], %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1) + // CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%arg0) : (tensor) -> tensor + %c = "tf.C"(%arg0, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + "tf.D"(%c, %arg0, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_aliased_output + func @tail_aliased_output() -> (tensor, tensor, tensor, tensor, tensor) { + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + %a = "tf.A"() : () -> tensor + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + %b = "tf.B"() : () -> tensor + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster:5 = "tf_device.cluster"() ( { + %c = "tf.C"() : () -> tensor + %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %e = "tf.E"() : () -> tensor + tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) + // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 + return %cluster#0, %cluster#1, %cluster#2, %cluster#3, %cluster#4 : tensor, tensor, tensor, tensor, tensor + } + + // CHECK-LABEL: func @tail_replicated_outside_compilation + func @tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: "tf_device.launch"() + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) : (tensor) -> tensor + %b = "tf.B"(%a, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } + + // CHECK-LABEL: func @head_tail_no_extraction_middle_outside_compiled_ops + func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { + // CHECK-NOT: "tf_device.launch" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) : (tensor) -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"(%b) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_tail_simple_extraction + func @head_tail_simple_extraction(%arg0: tensor) -> tensor { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"(%a) : (tensor) -> tensor + %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + tf_device.return %c : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[TAIL_LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @head_tail_replicated_outside_compilation + func @head_tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) + // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: "tf_device.launch"() + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"() : () -> tensor + %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor + %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 95183e04223..688f21c1d52 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -78,9 +78,11 @@ bool OpInBlock(Operation* op, Block* block) { } // Wraps block in a Launch. External uses of ops in the block will be return -// values of the Launch and remapped to the Launch results. +// values of the Launch and remapped to the Launch results. If `before` is set +// to true, the Launch is created before `op`. Otherwise the Launch is created +// after `op`. tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, - Block* launch_block, + bool before, Block* launch_block, llvm::StringRef host_device) { // Find results and result types of ops in block that needs to returned. llvm::SmallVector launch_results; @@ -100,7 +102,7 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, } } - builder->setInsertionPoint(op); + before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op); auto launch = builder->create( op->getLoc(), builder->getStringAttr(host_device), launch_result_types); launch.body().push_back(launch_block); @@ -178,22 +180,21 @@ llvm::SmallVector FindOutsideCompiledOpsAtHead( Region* cluster_region = &cluster.body(); llvm::SmallSetVector head_outside_compiled_ops; - auto walk_operands = [&](Operation* op) { - for (Value operand : op->getOperands()) { - Operation* operand_op = GetOpOfValue(operand); - if (head_outside_compiled_ops.count(operand_op)) continue; - - if (operand_op->getParentRegion() == cluster_region) - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }; auto cluster_ops = cluster.GetBody().without_terminator(); for (Operation& cluster_op : cluster_ops) { if (!HasOutsideCompilationAttribute(&cluster_op)) continue; // An outside compiled op can be extracted if its operands are not from // other ops in the cluster that cannot be extracted. - auto walk_result = cluster_op.walk(walk_operands); + auto walk_result = cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (head_outside_compiled_ops.count(operand_op)) continue; + + if (operand_op->getParentRegion() == cluster_region) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); if (!walk_result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op); @@ -211,8 +212,8 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, for (Operation* head_outside_compiled_op : head_outside_compiled_ops) head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); - tf_device::LaunchOp launch = - CreateLaunchForBlock(builder, cluster, launch_block, host_device); + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/true, launch_block, host_device); for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), launch.getResults())) @@ -220,6 +221,160 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, cluster.body()); } +// Extracts and move outside compiled ops that have no dependencies in the +// cluster to before the cluster. +mlir::LogicalResult LiftHeadOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + tf_device::ClusterOp cluster, std::string* host_device, + bool* cluster_updated) { + llvm::SmallVector head_outside_compiled_ops = + FindOutsideCompiledOpsAtHead(cluster); + if (head_outside_compiled_ops.empty()) return success(); + if (failed( + GetHostDeviceForHeadTailComputation(devices, cluster, host_device))) + return failure(); + + CreateHeadComputation(builder, cluster, head_outside_compiled_ops, + *host_device); + + *cluster_updated = true; + return success(); +} + +// Fills `tail_outside_compiled_ops` with ops that are outside compiled and +// can be extracted to after the TPU computation, and `cluster_results` with new +// results of the cluster. These ops are either connected to the output of the +// TPU computation or other ops that can be extracted, and have no results used +// by other ops in the TPU computation that cannot be extracted. +void FindOutsideCompiledOpsAtTailAndClusterResults( + tf_device::ClusterOp cluster, + llvm::SmallVectorImpl* tail_outside_compiled_ops, + llvm::SmallVectorImpl* cluster_results) { + Region* cluster_region = &cluster.body(); + llvm::SmallSetVector tail_outside_compiled_ops_set; + Operation* terminator = cluster.GetBody().getTerminator(); + llvm::SmallSetVector cluster_results_set; + cluster_results_set.insert(terminator->getOperands().begin(), + terminator->getOperands().end()); + + auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator()); + for (Operation& cluster_op : cluster_ops) { + if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + + llvm::SmallVector results_to_forward; + bool can_be_extracted = + llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { + return op == terminator || tail_outside_compiled_ops_set.count(op); + }); + if (!can_be_extracted) continue; + + // Collect operands of cluster op that are generated within the cluster. + // These values should be returned by the cluster. + cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (operand_op->getParentRegion() == cluster_region) + cluster_results_set.insert(operand); + } + }); + + // Remove results of op to be extracted as there are no uses in the cluster. + for (Value result : cluster_op.getResults()) + cluster_results_set.remove(result); + tail_outside_compiled_ops_set.insert(&cluster_op); + } + + *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector(); + *cluster_results = cluster_results_set.takeVector(); +} + +// Moves tail outside compiled ops into its own `tf_device.LaunchOp` +// computation after the cluster. +void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef tail_outside_compiled_ops, + llvm::StringRef host_device) { + Block* launch_block = new Block; + for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) + tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin()); + + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/false, launch_block, host_device); + + auto operand_not_in_launch = [&](OpOperand& operand) { + return !launch.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), + launch.getResults())) + std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_launch); +} + +// Updates cluster with updated cluster results after extracting tail outside +// compiled ops. +tf_device::ClusterOp UpdateClusterResults( + OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef new_cluster_results) { + Operation* old_terminator = cluster.GetBody().getTerminator(); + builder->setInsertionPoint(old_terminator); + builder->create(old_terminator->getLoc(), + new_cluster_results); + old_terminator->erase(); + + builder->setInsertionPoint(cluster); + llvm::SmallVector new_cluster_result_types; + new_cluster_result_types.reserve(new_cluster_results.size()); + for (const auto& new_cluster_result : new_cluster_results) + new_cluster_result_types.push_back(new_cluster_result.getType()); + + auto new_cluster = builder->create( + cluster.getLoc(), new_cluster_result_types, + /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); + + auto operand_not_in_cluster = [&](OpOperand& operand) { + return !new_cluster.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : + llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(), + new_cluster.getResults())) + std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_cluster); + + cluster.erase(); + return new_cluster; +} + +// Extracts and move outside compiled ops that do not create dependencies in the +// cluster to after the cluster. +mlir::LogicalResult LiftTailOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + std::string host_device, tf_device::ClusterOp* cluster, + bool* cluster_updated) { + llvm::SmallVector tail_outside_compiled_ops; + llvm::SmallVector cluster_results; + FindOutsideCompiledOpsAtTailAndClusterResults( + *cluster, &tail_outside_compiled_ops, &cluster_results); + if (tail_outside_compiled_ops.empty()) return success(); + + if (host_device.empty()) + if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster, + &host_device))) + return failure(); + + // Forward all results of cluster first. These results will be remapped once + // a new cluster is formed. + cluster->replaceAllUsesWith( + cluster->GetBody().getTerminator()->getOperands()); + + CreateTailComputation(builder, *cluster, tail_outside_compiled_ops, + host_device); + + *cluster = UpdateClusterResults(builder, *cluster, cluster_results); + + *cluster_updated = true; + return success(); +} + // Removes aliased outputs in cluster from ops outside of cluster. void RemoveClusterAliasedOutputs(OpBuilder* builder, tf_device::ClusterOp cluster) { @@ -256,26 +411,6 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder, cluster.erase(); } -// Extracts and move outside compiled ops that have no dependencies in the -// cluster to before the cluster. -mlir::LogicalResult LiftHeadOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - tf_device::ClusterOp cluster) { - llvm::SmallVector head_outside_compiled_ops = - FindOutsideCompiledOpsAtHead(cluster); - if (head_outside_compiled_ops.empty()) return success(); - std::string host_device; - if (failed( - GetHostDeviceForHeadTailComputation(devices, cluster, &host_device))) - return failure(); - - CreateHeadComputation(builder, cluster, head_outside_compiled_ops, - host_device); - - RemoveClusterAliasedOutputs(builder, cluster); - return success(); -} - struct TPUExtractHeadTailOutsideCompilation : public PassWrapper> { @@ -295,10 +430,14 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); for (tf_device::ClusterOp cluster : clusters) { - if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster))) + std::string host_device; + bool cluster_updated = false; + if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, + &host_device, &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, + &cluster, &cluster_updated))) return signalPassFailure(); - - // TODO(b/157160906): Implement tail outside compiled op extraction. + if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); } }