diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir index dd31b7d06ef..d5fb821b5e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir @@ -157,4 +157,232 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } return } + + // CHECK-LABEL: func @tail_single_outside_compiled_op + func @tail_single_outside_compiled_op() { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.C"() : () -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_single_outside_compiled_op_user + func @tail_single_outside_compiled_op_user() -> tensor { + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // 
CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"() : () -> () + tf_device.return %b : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @tail_multiple_outside_compiled_ops + func @tail_multiple_outside_compiled_ops(%arg0: tensor) { + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: tf_device.return %[[B_OUT]], %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: "tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1) + // CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + "tf_device.cluster"() ( { + %a = "tf.A"() : () -> tensor + %b = "tf.B"(%arg0) : (tensor) -> tensor + %c = "tf.C"(%arg0, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + "tf.D"(%c, %arg0, %b) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @tail_aliased_output + func @tail_aliased_output() -> (tensor, tensor, tensor, tensor, tensor) { + // CHECK-NEXT: 
%[[A_OUT:.*]] = "tf.A" + %a = "tf.A"() : () -> tensor + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + %b = "tf.B"() : () -> tensor + // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C" + // CHECK-NEXT: %[[E_OUT:.*]] = "tf.E" + // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster:5 = "tf_device.cluster"() ( { + %c = "tf.C"() : () -> tensor + %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %e = "tf.E"() : () -> tensor + tf_device.return %a, %b, %c, %d, %e : tensor, tensor, tensor, tensor, tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor, tensor, tensor, tensor, tensor) + // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1 + return %cluster#0, %cluster#1, %cluster#2, %cluster#3, %cluster#4 : tensor, tensor, tensor, tensor, tensor + } + + // CHECK-LABEL: func @tail_replicated_outside_compilation + func @tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: 
"tf_device.launch"() + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) : (tensor) -> tensor + %b = "tf.B"(%a, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } + + // CHECK-LABEL: func @head_tail_no_extraction_middle_outside_compiled_ops + func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor) { + // CHECK-NOT: "tf_device.launch" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: "tf.A" + // CHECK-NEXT: "tf.B" + // CHECK-NEXT: "tf.C" + // CHECK-NEXT: tf_device.return + "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) : (tensor) -> tensor + %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + "tf.C"(%b) : (tensor) -> () + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + return + } + + // CHECK-LABEL: func @head_tail_simple_extraction + func @head_tail_simple_extraction(%arg0: tensor) -> tensor { + // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch" + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[B_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK: %[[TAIL_LAUNCH_OUT:.*]] = 
"tf_device.launch" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" + %cluster = "tf_device.cluster"() ( { + %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"(%a) : (tensor) -> tensor + %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + tf_device.return %c : tensor + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor + // CHECK: return %[[TAIL_LAUNCH_OUT]] + return %cluster : tensor + } + + // CHECK-LABEL: func @head_tail_replicated_outside_compilation + func @head_tail_replicated_outside_compilation(%arg0: tensor, %arg1: tensor) { + // CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor) + // + // CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() + // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) + // CHECK-NEXT: tf_device.return %[[A_OUT]] + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + // + // CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster" + // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" + // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]]) + // CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]]) + // CHECK-NEXT: tf_device.return %[[C_OUT]] + // CHECK-NEXT: { + // CHECK-DAG: num_cores_per_replica = 1 + // CHECK-DAG: step_marker_location = "" + // CHECK-DAG: padding_map = [] + // CHECK-DAG: topology = "" + // CHECK-DAG: device_assignment = [] + // + // CHECK-NEXT: "tf_device.launch"() + // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]]) + // CHECK-NEXT: tf_device.return + // CHECK-NEXT: device = "TPU_REPLICATED_HOST" + tf_device.replicate([%arg0, %arg1] as %ri : tensor) {n = 2 : i32} { + "tf_device.cluster"() ( { + %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor) -> tensor + %b = "tf.B"() : () -> tensor + %c = "tf.C"(%ri, %b) 
{_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> tensor + %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor, tensor, tensor) -> tensor + %e = "tf.E"(%c, %a) : (tensor, tensor) -> tensor + tf_device.return + }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> () + tf_device.return + } + return + } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 95183e04223..688f21c1d52 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -78,9 +78,11 @@ bool OpInBlock(Operation* op, Block* block) { } // Wraps block in a Launch. External uses of ops in the block will be return -// values of the Launch and remapped to the Launch results. +// values of the Launch and remapped to the Launch results. If `before` is set +// to true, the Launch is created before `op`. Otherwise the Launch is created +// after `op`. tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, - Block* launch_block, + bool before, Block* launch_block, llvm::StringRef host_device) { // Find results and result types of ops in block that needs to returned. llvm::SmallVector launch_results; @@ -100,7 +102,7 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, } } - builder->setInsertionPoint(op); + before ? 
builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op); auto launch = builder->create( op->getLoc(), builder->getStringAttr(host_device), launch_result_types); launch.body().push_back(launch_block); @@ -178,22 +180,21 @@ llvm::SmallVector FindOutsideCompiledOpsAtHead( Region* cluster_region = &cluster.body(); llvm::SmallSetVector head_outside_compiled_ops; - auto walk_operands = [&](Operation* op) { - for (Value operand : op->getOperands()) { - Operation* operand_op = GetOpOfValue(operand); - if (head_outside_compiled_ops.count(operand_op)) continue; - - if (operand_op->getParentRegion() == cluster_region) - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }; auto cluster_ops = cluster.GetBody().without_terminator(); for (Operation& cluster_op : cluster_ops) { if (!HasOutsideCompilationAttribute(&cluster_op)) continue; // An outside compiled op can be extracted if its operands are not from // other ops in the cluster that cannot be extracted. - auto walk_result = cluster_op.walk(walk_operands); + auto walk_result = cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (head_outside_compiled_ops.count(operand_op)) continue; + + if (operand_op->getParentRegion() == cluster_region) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); if (!walk_result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op); @@ -211,8 +212,8 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, for (Operation* head_outside_compiled_op : head_outside_compiled_ops) head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); - tf_device::LaunchOp launch = - CreateLaunchForBlock(builder, cluster, launch_block, host_device); + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/true, launch_block, host_device); for (auto result : 
llvm::zip(launch.GetBody().getTerminator()->getOperands(), launch.getResults())) @@ -220,6 +221,160 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster, cluster.body()); } +// Extracts and moves outside compiled ops that have no dependencies in the +// cluster to before the cluster. +mlir::LogicalResult LiftHeadOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + tf_device::ClusterOp cluster, std::string* host_device, + bool* cluster_updated) { + llvm::SmallVector head_outside_compiled_ops = + FindOutsideCompiledOpsAtHead(cluster); + if (head_outside_compiled_ops.empty()) return success(); + if (failed( + GetHostDeviceForHeadTailComputation(devices, cluster, host_device))) + return failure(); + + CreateHeadComputation(builder, cluster, head_outside_compiled_ops, + *host_device); + + *cluster_updated = true; + return success(); +} + +// Fills `tail_outside_compiled_ops` with ops that are outside compiled and +// can be extracted to after the TPU computation, and `cluster_results` with new +// results of the cluster. These ops are either connected to the output of the +// TPU computation or other ops that can be extracted, and have no results used +// by other ops in the TPU computation that cannot be extracted. 
+void FindOutsideCompiledOpsAtTailAndClusterResults( + tf_device::ClusterOp cluster, + llvm::SmallVectorImpl* tail_outside_compiled_ops, + llvm::SmallVectorImpl* cluster_results) { + Region* cluster_region = &cluster.body(); + llvm::SmallSetVector tail_outside_compiled_ops_set; + Operation* terminator = cluster.GetBody().getTerminator(); + llvm::SmallSetVector cluster_results_set; + cluster_results_set.insert(terminator->getOperands().begin(), + terminator->getOperands().end()); + + auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator()); + for (Operation& cluster_op : cluster_ops) { + if (!HasOutsideCompilationAttribute(&cluster_op)) continue; + + llvm::SmallVector results_to_forward; + bool can_be_extracted = + llvm::all_of(cluster_op.getUsers(), [&](Operation* op) { + return op == terminator || tail_outside_compiled_ops_set.count(op); + }); + if (!can_be_extracted) continue; + + // Collect operands of cluster op that are generated within the cluster. + // These values should be returned by the cluster. + cluster_op.walk([&](Operation* op) { + for (Value operand : op->getOperands()) { + Operation* operand_op = GetOpOfValue(operand); + if (operand_op->getParentRegion() == cluster_region) + cluster_results_set.insert(operand); + } + }); + + // Remove results of op to be extracted as there are no uses in the cluster. + for (Value result : cluster_op.getResults()) + cluster_results_set.remove(result); + tail_outside_compiled_ops_set.insert(&cluster_op); + } + + *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector(); + *cluster_results = cluster_results_set.takeVector(); +} + +// Moves tail outside compiled ops into its own `tf_device.LaunchOp` +// computation after the cluster. 
+void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef tail_outside_compiled_ops, + llvm::StringRef host_device) { + Block* launch_block = new Block; + for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) + tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin()); + + tf_device::LaunchOp launch = CreateLaunchForBlock( + builder, cluster, /*before=*/false, launch_block, host_device); + + auto operand_not_in_launch = [&](OpOperand& operand) { + return !launch.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), + launch.getResults())) + std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_launch); +} + +// Updates cluster with updated cluster results after extracting tail outside +// compiled ops. +tf_device::ClusterOp UpdateClusterResults( + OpBuilder* builder, tf_device::ClusterOp cluster, + llvm::ArrayRef new_cluster_results) { + Operation* old_terminator = cluster.GetBody().getTerminator(); + builder->setInsertionPoint(old_terminator); + builder->create(old_terminator->getLoc(), + new_cluster_results); + old_terminator->erase(); + + builder->setInsertionPoint(cluster); + llvm::SmallVector new_cluster_result_types; + new_cluster_result_types.reserve(new_cluster_results.size()); + for (const auto& new_cluster_result : new_cluster_results) + new_cluster_result_types.push_back(new_cluster_result.getType()); + + auto new_cluster = builder->create( + cluster.getLoc(), new_cluster_result_types, + /*operands=*/llvm::ArrayRef{}, cluster.getAttrs()); + new_cluster.body().takeBody(cluster.body()); + + auto operand_not_in_cluster = [&](OpOperand& operand) { + return !new_cluster.getOperation()->isProperAncestor(operand.getOwner()); + }; + for (auto result : + llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(), + new_cluster.getResults())) + 
std::get<0>(result).replaceUsesWithIf(std::get<1>(result), + operand_not_in_cluster); + + cluster.erase(); + return new_cluster; +} + +// Extracts and moves outside compiled ops that do not create dependencies in the +// cluster to after the cluster. +mlir::LogicalResult LiftTailOutsideCompiledOps( + OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, + std::string host_device, tf_device::ClusterOp* cluster, + bool* cluster_updated) { + llvm::SmallVector tail_outside_compiled_ops; + llvm::SmallVector cluster_results; + FindOutsideCompiledOpsAtTailAndClusterResults( + *cluster, &tail_outside_compiled_ops, &cluster_results); + if (tail_outside_compiled_ops.empty()) return success(); + + if (host_device.empty()) + if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster, + &host_device))) + return failure(); + + // Forward all results of cluster first. These results will be remapped once + // a new cluster is formed. + cluster->replaceAllUsesWith( + cluster->GetBody().getTerminator()->getOperands()); + + CreateTailComputation(builder, *cluster, tail_outside_compiled_ops, + host_device); + + *cluster = UpdateClusterResults(builder, *cluster, cluster_results); + + *cluster_updated = true; + return success(); +} + // Removes aliased outputs in cluster from ops outside of cluster. void RemoveClusterAliasedOutputs(OpBuilder* builder, tf_device::ClusterOp cluster) { @@ -256,26 +411,6 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder, cluster.erase(); } -// Extracts and move outside compiled ops that have no dependencies in the -// cluster to before the cluster. 
-mlir::LogicalResult LiftHeadOutsideCompiledOps( - OpBuilder* builder, const mlir::TF::RuntimeDevices& devices, - tf_device::ClusterOp cluster) { - llvm::SmallVector head_outside_compiled_ops = - FindOutsideCompiledOpsAtHead(cluster); - if (head_outside_compiled_ops.empty()) return success(); - std::string host_device; - if (failed( - GetHostDeviceForHeadTailComputation(devices, cluster, &host_device))) - return failure(); - - CreateHeadComputation(builder, cluster, head_outside_compiled_ops, - host_device); - - RemoveClusterAliasedOutputs(builder, cluster); - return success(); -} - struct TPUExtractHeadTailOutsideCompilation : public PassWrapper> { @@ -295,10 +430,14 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() { [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); for (tf_device::ClusterOp cluster : clusters) { - if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster))) + std::string host_device; + bool cluster_updated = false; + if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster, + &host_device, &cluster_updated)) || + failed(LiftTailOutsideCompiledOps(&builder, devices, host_device, + &cluster, &cluster_updated))) return signalPassFailure(); - - // TODO(b/157160906): Implement tail outside compiled op extraction. + if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster); } }