diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
index 90fa8cff5dc..3e8ade180b1 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_head_tail_outside_compilation.mlir
@@ -6,12 +6,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   // CHECK-LABEL: func @single_head_outside_compilation
   func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
     // CHECK: tf_device.launch
-    //
     // CHECK: "tf.A"
     // CHECK-NEXT: tf_device.return
-    //
-    // CHECK: device
-    // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK: "tf_device.cluster"
     // CHECK: "tf.C"
@@ -28,6 +25,88 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
 // -----
 
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  // CHECK-LABEL: func @ops_no_operands
+  func @ops_no_operands() -> () {
+    // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK: %[[A_OUT:.*]] = "tf.A"
+    // CHECK-NEXT: tf_device.return %[[A_OUT]]
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    //
+    // CHECK: "tf_device.cluster"
+    // CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]])
+    // CHECK-NEXT: "tf.C"
+    // CHECK-NEXT: tf_device.return
+    "tf_device.cluster"() ( {
+      %0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
+      %1 = "tf.B"(%0) {} : (tensor<i32>) -> (tensor<i32>)
+      "tf.C"(%1) : (tensor<i32>) -> ()
+      tf_device.return
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
+    return
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  // CHECK-LABEL: func @aliased_output
+  func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
+    // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK: %[[A_OUT:.*]] = "tf.A"
+    // CHECK-NEXT: tf_device.return %[[A_OUT]]
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    //
+    // CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
+    // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
+    // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
+    // CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]]
+    // CHECK-NEXT: {
+    // CHECK-DAG: num_cores_per_replica = 1
+    // CHECK-DAG: step_marker_location = ""
+    // CHECK-DAG: padding_map = []
+    // CHECK-DAG: topology = ""
+    // CHECK-DAG: device_assignment = []
+    //
+    // CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1
+    %0:3 = "tf_device.cluster"() ( {
+      %1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
+      %2 = "tf.B"(%1) {} : (tensor<i32>) -> (tensor<i32>)
+      %3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>)
+      tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32>
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>)
+    return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
+  }
+}
+
+// -----
+
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+  // CHECK-LABEL: func @all_head_computation_ops
+  func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) {
+    // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
+    // CHECK: %[[A_OUT:.*]] = "tf.A"
+    // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
+    // CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
+    // CHECK-NEXT: tf_device.return %[[C_OUT]]
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
+    //
+    // CHECK: "tf_device.cluster"
+    // CHECK-NEXT: tf_device.return
+    //
+    // CHECK: return %[[LAUNCH_OUT]]
+    %0 = "tf_device.cluster"() ( {
+      %1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+      %2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
+      %3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
+      tf_device.return %3 : tensor<i32>
+    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>)
+    return %0 : tensor<i32>
+  }
+}
+
+// -----
+
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
   // CHECK-LABEL: func @multiple_head_outside_compilation
   func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
@@ -36,8 +115,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     // CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
     // CHECK: "tf.C"
     // CHECK-NEXT: tf_device.return %[[B_OUT]]
-    // CHECK: device
-    // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK: "tf_device.cluster"
     // CHECK: "tf.D"(%[[LAUNCH_OUT]])
@@ -83,8 +161,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     // CHECK: %[[A_OUT:.*]] = "tf.A"
     // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
    // CHECK-NEXT: tf_device.return %[[D_OUT]]
-    // CHECK: device
-    // CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
+    // CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
     //
     // CHECK: "tf_device.cluster"
     // CHECK: "tf.B"
@@ -105,15 +182,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
 // -----
 
-module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
   // CHECK-LABEL: func @test_replicated_head_outside_compilation
   func @test_replicated_head_outside_compilation(%arg0 : tensor<i32>) -> () {
     // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
     // CHECK: %[[A_OUT:.*]] = "tf.A"
     // CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
     // CHECK-NEXT: tf_device.return %[[D_OUT]]
-    // CHECK: device
-    // CHECK-SAME: "TPU_REPLICATED_HOST"
+    // CHECK: device = "TPU_REPLICATED_HOST"
     //
     // CHECK: "tf_device.cluster"
     // CHECK: "tf.B"
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
index 02d0c3e849b..5a059ce507c 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc
@@ -14,9 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #include <memory>
+#include <string>
 #include <utility>
 
-#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -26,6 +27,7 @@ limitations under the License.
 #include "mlir/IR/Block.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
 #include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
 #include "mlir/IR/Visitors.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
@@ -51,181 +53,84 @@ bool HasOutsideCompilationAttribute(Operation* op) {
   return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
 }
 
-// Returns whether all operands of `op` are from values inside the
-// `input_value_set`.
-bool OpContainsOperandsFromSet(Operation* op,
-                               const llvm::SetVector<Value>& input_value_set) {
-  for (auto operand : op->getOperands())
-    if (input_value_set.count(operand) == 0) return false;
+Operation* GetOpOfValue(Value value) {
+  if (auto block_arg = value.dyn_cast<BlockArgument>())
+    return block_arg.getOwner()->getParentOp();
 
-  return true;
+  return value.getDefiningOp();
 }
 
-void RecordOutsideCompiledOpsAndUsages(
-    Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
-    llvm::SetVector<Value>* outside_compiled_op_usages) {
-  if (HasOutsideCompilationAttribute(op) &&
-      OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
-    outside_compiled_ops->insert(op);
-    outside_compiled_op_usages->insert(op->getResults().begin(),
-                                       op->getResults().end());
-  }
-}
+// Returns a set of ops that are outside compiled and can be extracted to run
+// before the TPU computation. These ops are either connected to the inputs of
+// the TPU computation or to other ops that can be extracted, and have no
+// dependencies on other ops in the TPU computation that cannot be extracted.
+llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
+    tf_device::ClusterOp cluster) {
+  llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
 
-// Traverses the MLIR graph and returns a set of ops that
-// are connected to inputs of TPU computation and outside compiled.
-void ExtractOutsideCompiledOpsConnectedToHead(
-    Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
-    llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
-  llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
-  for (auto& usage : input_value.getUses()) {
-    auto head_operation = usage.getOwner();
-    RecordOutsideCompiledOpsAndUsages(head_operation,
-                                      &parent_outside_compiled_ops_at_head,
-                                      values_used_in_host_cluster);
-  }
+  auto cluster_ops = cluster.GetBody().without_terminator();
+  for (Operation& cluster_op : cluster_ops) {
+    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
+    // An outside compiled op can be extracted if its operands are not from
+    // other ops in the cluster that cannot be extracted.
+    auto result = cluster_op.walk([&](Operation* op) {
+      for (Value operand : op->getOperands()) {
+        Operation* operand_op = GetOpOfValue(operand);
+        if (operand_op->isProperAncestor(cluster) ||
+            cluster_op.isAncestor(operand_op) ||
+            head_outside_compiled_ops.count(operand_op))
+          continue;
 
-  // Traverse the graph and find all outside compiled ops connected from
-  // the `input_value`.
-  while (!parent_outside_compiled_ops_at_head.empty()) {
-    llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
-    for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
-      auto op_results = head_outside_compiled_op->getOpResults();
-      for (auto op_result : op_results) {
-        for (auto& use : op_result.getUses()) {
-          auto connected_op = use.getOwner();
-          RecordOutsideCompiledOpsAndUsages(connected_op,
-                                            &connected_outside_compiled_ops,
-                                            values_used_in_host_cluster);
-        }
+        return WalkResult::interrupt();
       }
-    }
+      return WalkResult::advance();
+    });
 
-    outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
-                                 parent_outside_compiled_ops_at_head.end());
-    std::swap(parent_outside_compiled_ops_at_head,
-              connected_outside_compiled_ops);
-  }
-}
-
-// TODO(hongjunchoi): Also handle ops without inputs that are outside
-// compiled.
-//
-// Returns set of ops that are outside compiled and are directly connected
-// to inputs to the TPU computation.
-llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
-    tf_device::ClusterOp tpu_cluster) {
-  llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
-  llvm::SetVector<Value> values_used_in_cluster;
-  auto& cluster_region = tpu_cluster.body();
-  getUsedValuesDefinedAbove(cluster_region, cluster_region,
-                            values_used_in_cluster);
-
-  auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
-  for (auto input_value : input_value_list)
-    ExtractOutsideCompiledOpsConnectedToHead(
-        input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
-  return outside_compiled_at_head_ops;
-}
-
-// Returns output values of extracted outside compiled cluster at head that
-// are used by the TPU computation.
-llvm::SmallVector<Value, 4> GetHeadExtractedClusterOutputs(
-    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
-  llvm::SmallVector<Value, 4> outputs;
-  outputs.reserve(head_outside_compiled_ops.size());
-
-  for (auto op : head_outside_compiled_ops) {
-    for (Operation* user : op->getUsers()) {
-      if (!head_outside_compiled_ops.count(user)) {
-        outputs.append(op->result_begin(), op->result_end());
-        break;
-      }
-    }
+    if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op);
   }
 
-  return outputs;
+  return head_outside_compiled_ops.takeVector();
 }
 
-// Creates new tf_device.launch op with outside compiled ops extracted
-// from the head of TPU computation.
-llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
-    OpBuilder* builder, tf_device::ClusterOp cluster,
-    const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
-  if (head_outside_compiled_ops.empty())
-    return llvm::Optional<tf_device::LaunchOp>();
-
-  // Create tf_device.launch op to separate all extracted outside compiled ops
-  // before the tf_device.cluster.
-  auto output_values =
-      GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
-
-  llvm::SmallVector<Type, 4> output_return_types;
-  output_return_types.reserve(output_values.size());
-  for (auto output : output_values)
-    output_return_types.emplace_back(output.getType());
-
-  builder->setInsertionPoint(cluster);
-  auto host_launch_op = builder->create<tf_device::LaunchOp>(
-      cluster.getLoc(), builder->getStringAttr(""), output_return_types);
-
-  // Replace all usages of outside compiled ops that are used in TPU
-  // computation with the results of the above created launch op.
-  for (auto output_and_index : llvm::enumerate(output_values)) {
-    auto output_index = output_and_index.index();
-    auto output = output_and_index.value();
-    for (auto& use : output.getUses()) {
-      if (!head_outside_compiled_ops.count(use.getOwner()))
-        use.set(host_launch_op.getResult(output_index));
-    }
+// Parses TPU compilation and execution devices from a TPU cluster and returns
+// the host device for the head and tail computations. If the TPU computation
+// is replicated, kTPUReplicatedHost is returned instead.
+LogicalResult GetHostDeviceForHeadTailComputation(
+    mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
+    std::string* host_device) {
+  auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
+  if (replicate) {
+    *host_device = tensorflow::kTPUReplicatedHost;
+    return success();
   }
 
-  // Create terminator op for the newly created launch op.
-  host_launch_op.body().push_back(new Block());
-  builder->setInsertionPointToEnd(&host_launch_op.GetBody());
-  auto terminator = builder->create<tf_device::ReturnOp>(
-      host_launch_op.getLoc(), output_values);
-
-  // Move all outside compile ops from cluster op to launch op.
-  for (auto outside_compiled_op : head_outside_compiled_ops)
-    outside_compiled_op->moveBefore(terminator);
-
-  return host_launch_op;
-}
-
-// Parses TPU compilation and execution device form tpu cluster and assigns
-// host device to `host_launch` device attribute.
-LogicalResult SetCompilationDeviceToHostLaunch(
-    OpBuilder* builder, mlir::TF::RuntimeDevices devices,
-    tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
-  auto num_cores_per_replica_attr = tpu_cluster.getAttrOfType<IntegerAttr>(
-      tensorflow::kNumCoresPerReplicaAttr);
+  auto num_cores_per_replica_attr =
+      cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
   if (!num_cores_per_replica_attr)
-    return tpu_cluster.emitOpError(
+    return cluster.emitOpError(
         "cluster op missing `num_cores_per_replica` attribute");
 
   if (num_cores_per_replica_attr.getInt() != 1)
-    return tpu_cluster.emitOpError(
+    return cluster.emitOpError(
        "outside compilation is not supported with model parallelism.");
 
   auto topology_attr =
-      tpu_cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
+      cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
   if (!topology_attr)
-    return tpu_cluster.emitOpError("cluster op missing `topology` attribute");
+    return cluster.emitOpError("cluster op missing `topology` attribute");
 
-  auto device_assignment_attr = tpu_cluster.getAttrOfType<ArrayAttr>(
-      tensorflow::kDeviceAssignmentAttr);
+  auto device_assignment_attr =
+      cluster.getAttrOfType<ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
   if (!device_assignment_attr)
-    return tpu_cluster.emitOpError(
-        llvm::formatv("requires attribute '{0}'",
-                      tensorflow::kDeviceAssignmentAttr)
-            .str());
+    return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
+                                             tensorflow::kDeviceAssignmentAttr)
+                                   .str());
 
   auto status_or_device_coodinates =
       tensorflow::GetDeviceCoordinates(device_assignment_attr);
   if (!status_or_device_coodinates.ok())
-    return tpu_cluster.emitError()
+    return cluster.emitError()
           << "error in fetching tpu device coordinates: "
           << status_or_device_coodinates.status().error_message();
@@ -236,37 +141,96 @@ LogicalResult SetCompilationDeviceToHostLaunch(
           /*num_cores_per_replica=*/1, topology_attr.getValue(),
           status_or_device_coodinates.ConsumeValueOrDie());
   if (!status_or_tpu_device_assignment.ok())
-    return tpu_cluster.emitError()
+    return cluster.emitError()
          << "error in fetching TPU compilation/execution devices: "
          << status_or_tpu_device_assignment.status().error_message();
 
   auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
-  host_launch.deviceAttr(
-      builder->getStringAttr(tpu_device_assignment.tpu_devices[0][0].host));
+  *host_device = tpu_device_assignment.tpu_devices[0][0].host;
 
   return success();
 }
 
-// Assigns host device attribute to host launch op or enclosing
-// tf_device.replicate op if TPU computation is replicated.
-LogicalResult HandleHostLaunchDeviceAssignment(
-    OpBuilder* builder, mlir::TF::RuntimeDevices devices,
-    tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
-  auto parent_replicate_op =
-      llvm::dyn_cast_or_null<tf_device::ReplicateOp>(host_launch.getParentOp());
-  // If computation is replicated, then add TPU_REPLICATED_HOST device alias
-  // to the host launch op. This device alias would later be a reference to
-  // host device string in the device map of tf_device.replicate op
-  // during tpu_rewrite pass.
-  if (parent_replicate_op) {
-    host_launch.deviceAttr(
-        builder->getStringAttr(tensorflow::kTPUReplicatedHost));
-  } else {
-    if (failed(SetCompilationDeviceToHostLaunch(builder, devices, tpu_cluster,
-                                                host_launch)))
-      return failure();
+// Moves head outside compiled ops into their own `tf_device.LaunchOp`
+// computation.
+tf_device::LaunchOp CreateHeadComputation(
+    OpBuilder* builder, tf_device::ClusterOp cluster,
+    llvm::ArrayRef<Operation*> head_outside_compiled_ops,
+    llvm::StringRef host_device) {
+  Block* launch_block = new Block;
+  for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
+    head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
+
+  // Find results of ops in the head computation that need to be returned.
+  llvm::SmallVector<Value, 4> launch_results;
+  llvm::SmallVector<Type, 4> launch_result_types;
+  for (Operation& head_outside_compiled_op : *launch_block) {
+    for (Value result : head_outside_compiled_op.getResults()) {
+      bool has_uses_in_cluster = false;
+      for (Operation* user : result.getUsers()) {
+        if (user->getParentRegion() &&
+            cluster.body().isAncestor(user->getParentRegion())) {
+          has_uses_in_cluster = true;
+          break;
+        }
+      }
+      if (has_uses_in_cluster) {
+        launch_results.push_back(result);
+        launch_result_types.push_back(result.getType());
+      }
+    }
   }
 
-  return success();
+  builder->setInsertionPoint(cluster);
+  auto launch = builder->create<tf_device::LaunchOp>(
+      cluster.getLoc(), builder->getStringAttr(host_device),
+      launch_result_types);
+  launch.body().push_back(launch_block);
+
+  builder->setInsertionPointToEnd(&launch.GetBody());
+  builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results);
+
+  for (auto result : llvm::zip(launch_results, launch.getResults()))
+    replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
+                               cluster.body());
+
+  return launch;
+}
+
+// Removes aliased outputs in the cluster from the head computation after the
+// head computation has been extracted.
+void RemoveHeadComputationAliasedOutputs(OpBuilder* builder,
+                                         tf_device::LaunchOp head_computation,
+                                         tf_device::ClusterOp cluster) {
+  llvm::SmallVector<Value, 4> used_old_cluster_results;
+  llvm::SmallVector<Value, 4> new_cluster_results;
+  llvm::SmallVector<Type, 4> new_cluster_result_types;
+  Operation* cluster_terminator = cluster.GetBody().getTerminator();
+  for (auto result :
+       llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
+    Value cluster_terminator_operand = std::get<0>(result);
+    if (cluster_terminator_operand.getDefiningOp() == head_computation) {
+      std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
+    } else {
+      new_cluster_results.push_back(cluster_terminator_operand);
+      new_cluster_result_types.push_back(cluster_terminator_operand.getType());
+      used_old_cluster_results.push_back(std::get<1>(result));
+    }
+  }
+
+  if (new_cluster_results.size() == cluster.getNumResults()) return;
+
+  builder->setInsertionPoint(cluster);
+  auto new_cluster = builder->create<tf_device::ClusterOp>(
+      cluster.getLoc(), new_cluster_result_types,
+      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
+  new_cluster.body().takeBody(cluster.body());
+  new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
+
+  for (auto result :
+       llvm::zip(used_old_cluster_results, new_cluster.getResults()))
+    std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
+
+  cluster.erase();
 }
 
 struct TPUExtractHeadTailOutsideCompilation
@@ -283,22 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
     return signalPassFailure();
 
   OpBuilder builder(&getContext());
-  auto result = module.walk([&](tf_device::ClusterOp cluster) {
-    auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
-    auto host_launch_op = IsolateHeadExtractedOpsToLaunchOp(
-        &builder, cluster, head_outside_compiled_ops);
-    if (host_launch_op) {
-      if (failed(HandleHostLaunchDeviceAssignment(&builder, devices, cluster,
-                                                  *host_launch_op))) {
-        return WalkResult::interrupt();
-      }
-    }
+  llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
+  module.walk(
+      [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
 
-    // TODO(b/155115766): Implement tail outside compiled op extraction.
-    return WalkResult::advance();
-  });
+  for (tf_device::ClusterOp cluster : clusters) {
+    llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
+        FindOutsideCompiledOpsAtHead(cluster);
+    if (head_outside_compiled_ops.empty()) continue;
+    std::string host_device;
+    if (failed(GetHostDeviceForHeadTailComputation(devices, cluster,
+                                                   &host_device)))
+      return signalPassFailure();
 
-  if (result.wasInterrupted()) signalPassFailure();
+    tf_device::LaunchOp head_computation = CreateHeadComputation(
+        &builder, cluster, head_outside_compiled_ops, host_device);
+    RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster);
+
+    // TODO(b/157160906): Implement tail outside compiled op extraction.
+  }
 }
 
 }  // anonymous namespace
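
A minimal before/after sketch of the head extraction this patch implements. The
op names, i32 types, and CPU device string below are illustrative, taken from
the tests in this change; cluster attributes are elided with "..." for brevity.
Given a cluster whose leading outside compiled op has no operands produced by
non-extractable ops in the cluster:

  "tf_device.cluster"() ( {
    %0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
    %1 = "tf.B"(%0) : (tensor<i32>) -> (tensor<i32>)
    "tf.C"(%1) : (tensor<i32>) -> ()
    tf_device.return
  }) {...} : () -> ()

the pass moves tf.A into a tf_device.launch placed before the cluster, assigns
it the host device (or TPU_REPLICATED_HOST when the computation is replicated),
and rewires uses inside the cluster to the launch result:

  %0 = "tf_device.launch"() ( {
    %a = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
    tf_device.return %a : tensor<i32>
  }) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> tensor<i32>
  "tf_device.cluster"() ( {
    %1 = "tf.B"(%0) : (tensor<i32>) -> (tensor<i32>)
    "tf.C"(%1) : (tensor<i32>) -> ()
    tf_device.return
  }) {...} : () -> ()

Cluster results that merely alias a head computation value (see @aliased_output
above) are replaced with the launch result directly, and the cluster is rebuilt
without those outputs.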