From 6e195c5d4e7aa40539bbe6ae878ba94de46a56cf Mon Sep 17 00:00:00 2001
From: Ken Franko <kfranko@google.com>
Date: Tue, 1 Sep 2020 09:01:57 -0700
Subject: [PATCH] Handle WhileRegionOp in extract outside compilation.

This is handled similarly to IfRegionOp. The main difference is that
WhileRegionOp has a single body region plus a condition region, and the
loop condition is communicated to the host from inside the condition
region rather than before the control flow op.
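For illustration, a rough sketch of the rewrite for an outside compiled op
in a while body (simplified from the tests in this change; types, operands,
and send/recv keys are elided or shortened):

  // Device cluster after extraction: the cond region now sends its result
  // to the host before yielding it.
  "tf.WhileRegion"(%b, %a) ({
    %cond = "tf.H"(...) : (...) -> tensor<i1>
    "tf.XlaSendToHost"(%cond) {key = "while_condition_channel_cluster1_0"}
    "tf.Yield"(%cond)
  }, {
    // The outside compiled op is replaced by a tf._XlaHostComputeMlir op.
    %r = "tf._XlaHostComputeMlir"(...)
    "tf.Yield"(..., %r)
  })

  // Host launch: a mirrored tf.WhileRegion whose cond region receives the
  // condition computed on device, and whose body wraps the extracted op in
  // the usual recv/send pair.
  "tf.WhileRegion"() ({
    %c = "tf._XlaRecvAtHost"(%program_key) {key = "while_condition_channel_cluster1_0"}
    "tf.Yield"(%c)
  }, {
    %in:2 = "tf._XlaRecvAtHost"(%program_key)
    %d = "tf.D"(%in#0, %in#1)
    "tf._XlaSendFromHost"(%d, %program_key)
    "tf.Yield"()
  })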
PiperOrigin-RevId: 329516973
Change-Id: Ic7457d50be5174c962effba1bb16866c19601a68
---
 .../tpu_extract_outside_compilation.mlir      | 249 ++++++++++++++++++
 .../tpu_extract_outside_compilation.cc        | 205 ++++++++++----
 2 files changed, 400 insertions(+), 54 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
index 2271bca7382..5b8f1887e13 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir
@@ -819,4 +819,253 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
     return %1 : tensor<?xi32>
   }
+
+  // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body.
+
+  // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body
+  func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT: "tf_device.launch"
+    // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT: "tf.WhileRegion"
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME: device_ordinal = 0
+    // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+    // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+    // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+    // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: "tf.Yield"
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+    // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+    // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+    // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+    // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+    // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<f32>)
+        %4 = "tf.B"() : () -> (tensor<i32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.WhileRegion"(%4, %3) ({
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
+          "tf.Yield"(%7) : (tensor<i1>) -> ()
+        }, {
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+          %9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<f32>) -> tensor<f32>
+          "tf.Yield"(%8, %9) : (tensor<i32>, tensor<f32>) -> ()
+        }) {is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
+  // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond.
+
+  // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond
+  func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT: "tf_device.launch"
+    // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT: "tf.WhileRegion"
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
+    // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME: device_ordinal = 0
+    // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+    // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
+    // CHECK-NEXT: "tf.Yield"
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+    // CHECK: "tf._XlaHostComputeMlir"
+    // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+    // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+    // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+    // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+    // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<f32>)
+        %4 = "tf.B"() : () -> (tensor<i32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.WhileRegion"(%4, %3) ({
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<f32>) -> tensor<i32>
+          %8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
+          "tf.Yield"(%8) : (tensor<i1>) -> ()
+        }, {
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+          %8 = "tf.D"(%arg1, %arg2) : (tensor<i32>, tensor<f32>) -> tensor<f32>
+          "tf.Yield"(%7, %8) : (tensor<i32>, tensor<f32>) -> ()
+        }) {is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
+  // Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body.
+
+  // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body
+  func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT: "tf_device.launch"
+    // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT: "tf.WhileRegion"
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME: device_ordinal = 0
+    // CHECK-SAME: key = "while_condition_channel_cluster2_0"
+    // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+    // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+    // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: "tf.Yield"
+    // CHECK: "tf_device.launch"
+    // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT: "tf.WhileRegion"
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
+    // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME: device_ordinal = 0
+    // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+    // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
+    // CHECK-NEXT: "tf.Yield"
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+    // CHECK: "tf._XlaHostComputeMlir"
+    // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+    // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+    // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+    // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+    // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<f32>)
+        %4 = "tf.B"() : () -> (tensor<i32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.WhileRegion"(%4, %3) ({
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<f32>) -> tensor<i32>
+          %8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
+          "tf.Yield"(%8) : (tensor<i1>) -> ()
+        }, {
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+          %8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor<i32>, tensor<f32>) -> tensor<f32>
+          "tf.Yield"(%7, %8) : (tensor<i32>, tensor<f32>) -> ()
+        }) {is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
+
+  // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op
+  // nested in a tf.WhileRegion.
+
+  // CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if
+  func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+    %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+    // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+    // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+    // CHECK-NEXT: "tf_device.launch"
+    // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+    // CHECK-NEXT: "tf.WhileRegion"
+    // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-SAME: device_ordinal = 0
+    // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+    // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+    // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
+    // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+    // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+    // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+    // CHECK-NEXT: "tf.Yield"
+    // CHECK: "tf_device.cluster"
+    // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+    // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+    // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+    // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+    // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+    // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+    // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+    // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
+    // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
+    // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+    // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
+    %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+      %2 = "tf_device.cluster"() ( {
+        %3 = "tf.A"() : () -> (tensor<f32>)
+        %4 = "tf.B"() : () -> (tensor<i32>)
+        %6 = "tf.G"() : () -> (tensor<i1>)
+
+        "tf.WhileRegion"(%4, %3) ({
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
+          "tf.Yield"(%7) : (tensor<i1>) -> ()
+        }, {
+        ^bb0(%arg1: tensor<i32>, %arg2: tensor<f32>):
+          %8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+          %10 = "tf.IfRegion"(%6) ({
+            %9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor<f32>
+            "tf.Yield"(%9) : (tensor<f32>) -> ()
+          }, {
+            "tf.Yield"(%arg2) : (tensor<f32>) -> ()
+          }) {is_stateless = false} : (tensor<i1>) -> tensor<f32>
+          "tf.Yield"(%8, %10) : (tensor<i32>, tensor<f32>) -> ()
+        }) {is_stateless = false} : (tensor<i32>, tensor<f32>) -> (tensor<i32>, tensor<f32>)
+
+        %5 = "tf.E"() : () -> tensor<?xi32>
+        tf_device.return %5 : tensor<?xi32>
+      }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+      tf_device.return %2 : tensor<?xi32>
+    }
+
+    return %1 : tensor<?xi32>
+  }
 
 }
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
index b141a7dc792..98d62b77975 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc
@@ -88,22 +88,30 @@ struct TPUExtractOutsideCompilation
 };
 
 // Holds information about control flow operations that wrap outside compiled
-// op. Currently only tf.If op is supported.
+// op. Currently only tf.IfRegion and tf.WhileRegion ops are supported.
 class ControlFlowStackInfo {
  public:
-  enum ControlFlowBranchType { kIfThen, kIfElse };
+  enum ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody };
 
   explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op)
       : callsite_op_(wrapping_op) {
-    // Only tf.IfRegion op is supported for now.
-    auto control_flow_op = llvm::cast<TF::IfRegionOp>(callsite_op_);
-    assert(control_flow_op);
-
-    auto parent_region = nested_op->getParentRegion();
-    if (&control_flow_op.then_branch() == parent_region) {
-      type_ = ControlFlowBranchType::kIfThen;
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.then_branch() == parent_region) {
+        type_ = ControlFlowBranchType::kIfThen;
+      } else {
+        type_ = ControlFlowBranchType::kIfElse;
+      }
+    } else if (auto control_flow_op =
+                   llvm::dyn_cast<TF::WhileRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.cond() == parent_region) {
+        type_ = ControlFlowBranchType::kWhileCond;
+      } else {
+        type_ = ControlFlowBranchType::kWhileBody;
+      }
     } else {
-      type_ = ControlFlowBranchType::kIfElse;
+      assert(false);
     }
   }
@@ -133,7 +141,7 @@ llvm::SmallVector<ControlFlowStackInfo, 4> GetControlFlowStackForOp(
   Operation* op_in_stack = op;
   while (op_in_stack != tpu_cluster.getOperation()) {
     auto parent_op = op_in_stack->getParentOp();
-    if (llvm::isa<TF::IfRegionOp>(parent_op)) {
+    if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(parent_op)) {
       controlflow_stack.insert(controlflow_stack.begin(),
                                ControlFlowStackInfo(parent_op, op_in_stack));
     }
@@ -166,7 +174,7 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless,
 
 // Replicates tf.IfRegion op to host side computation.
 Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
-                       llvm::StringRef outside_cluster_name, ModuleOp module,
+                       llvm::StringRef outside_cluster_name,
                        Value compilation_key, OpBuilder* builder,
                        int* send_recv_counter) {
   // Create XlaSendToHostOp to send predicate value from device to host.
@@ -200,6 +208,63 @@ Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
                        if_callsite_op.getLoc(), builder);
 }
 
+// Creates a WhileRegionOp with an empty body region that contains only a
+// yield op. The cond region is created and populated by the caller.
+TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, APInt parallel_iterations,
+                                  Location loc, OpBuilder* builder) {
+  auto host_side_while = builder->create<TF::WhileRegionOp>(
+      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
+      is_stateless, parallel_iterations);
+
+  // Create empty body region.
+  auto& body = host_side_while.body();
+  body.push_back(new Block);
+  builder->setInsertionPointToEnd(&body.front());
+  builder->create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
+  return host_side_while;
+}
+
+// Replicates tf.WhileRegion op to host side computation.
+Operation* ReplicateWhile(const ControlFlowStackInfo& controlflow_info,
+                          llvm::StringRef outside_cluster_name,
+                          Value compilation_key, OpBuilder* builder,
+                          int* send_recv_counter) {
+  // Create XlaSendToHostOp to send cond region output from device to host.
+  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
+  auto while_callsite_op =
+      llvm::cast<TF::WhileRegionOp>(controlflow_info.GetCallSiteOp());
+  builder->setInsertionPoint(while_callsite_op.cond().front().getTerminator());
+
+  const auto condition_send_recv_key =
+      llvm::formatv("while_condition_channel_{0}_{1}", outside_cluster_name,
+                    *send_recv_counter)
+          .str();
+  *send_recv_counter += 1;
+  auto condition =
+      while_callsite_op.cond().front().getTerminator()->getOperand(0);
+  builder->create<TF::XlaSendToHostOp>(while_callsite_op.getLoc(), condition,
+                                       condition_send_recv_key);
+  builder->restoreInsertionPoint(insert_point);
+
+  auto host_side_while = CloneEmptyWhile(
+      while_callsite_op.is_stateless(), while_callsite_op.parallel_iterations(),
+      while_callsite_op.getLoc(), builder);
+
+  // Create cond region and yield the condition received from the device.
+  auto& cond = host_side_while.cond();
+  cond.push_back(new Block);
+  builder->setInsertionPointToEnd(&cond.front());
+  auto recv_condition_at_host = builder->create<TF::_XlaRecvAtHostOp>(
+      while_callsite_op.getLoc(), llvm::ArrayRef<Type>{condition.getType()},
+      /*dynamic_key=*/compilation_key,
+      builder->getStringAttr(condition_send_recv_key),
+      /*device_ordinal=*/builder->getI64IntegerAttr(0));
+  builder->create<TF::YieldOp>(while_callsite_op.getLoc(),
+                               recv_condition_at_host.getResults());
+
+  return host_side_while;
+}
+
 // TODO(b/157054714): Use a better abstraction instead of
 // _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
 // Creates a compilation key as placeholder. A placeholder compilation cache key
@@ -216,7 +281,7 @@ Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) {
 
 // Replicates the control flow operations that wraps outside compiled ops to
 // `destination_block`.
-Block* ReplicateControlFlowStack(
+Operation* ReplicateControlFlowStack(
     llvm::StringRef outside_cluster_name,
     const llvm::SmallVectorImpl<ControlFlowStackInfo>& stack_info,
     tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key,
@@ -227,32 +292,60 @@ Block* ReplicateControlFlowStack(
   for (const auto& controlflow_stack_info : stack_info) {
     // Create control flow op given provided insertion point and
     // ControlFlowStackInfo.
-    previous_replicated_controlflow_op =
-        ReplicateIf(controlflow_stack_info, outside_cluster_name, module,
-                    compilation_key, &builder, send_recv_counter);
-    auto if_op = llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-    auto type = controlflow_stack_info.GetBranchType();
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(
+            controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateIf(controlflow_stack_info, outside_cluster_name,
+                      compilation_key, &builder, send_recv_counter);
+      auto if_op =
+          llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
 
-    // Update the insertion point to proper region inside the newly created
-    // control flow op.
-    if (type == ControlFlowStackInfo::kIfThen) {
-      builder.setInsertionPoint(&if_op.then_branch().front().front());
-    } else {
-      builder.setInsertionPoint(&if_op.else_branch().front().front());
+      // Update the insertion point to proper region inside the newly created
+      // control flow op.
+      if (type == ControlFlowStackInfo::kIfThen) {
+        builder.setInsertionPoint(&if_op.then_branch().front().front());
+      } else {
+        builder.setInsertionPoint(&if_op.else_branch().front().front());
+      }
+    } else if (auto control_flow_op = llvm::dyn_cast<TF::WhileRegionOp>(
+                   controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateWhile(controlflow_stack_info, outside_cluster_name,
+                         compilation_key, &builder, send_recv_counter);
+      auto while_op =
+          llvm::cast<TF::WhileRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
+      if (type == ControlFlowStackInfo::kWhileCond) {
+        builder.setInsertionPoint(&while_op.cond().front().front());
+      } else {
+        builder.setInsertionPoint(&while_op.body().front().front());
+      }
+    }
   }
 
-  // Return the inner most branch at which outside compiled op is located.
-  // This block will later be used as insertion point to create send/recv ops.
-  auto inner_most_controlflow_stack = stack_info.back();
-  auto inner_most_if =
-      llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-  if (inner_most_controlflow_stack.GetBranchType() ==
-      ControlFlowStackInfo::kIfThen) {
-    return &inner_most_if.then_branch().front();
-  } else {
-    return &inner_most_if.else_branch().front();
+  // Return the operation that should be used as the insertion point to create
+  // send/recv ops.
+  if (auto inner_most_if =
+          llvm::dyn_cast<TF::IfRegionOp>(previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kIfThen) {
+      return inner_most_if.then_branch().front().getTerminator();
+    } else {
+      return inner_most_if.else_branch().front().getTerminator();
+    }
+  } else if (auto inner_most_while = llvm::dyn_cast<TF::WhileRegionOp>(
+                 previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kWhileCond) {
+      return &inner_most_while.cond().front().front();
+    } else {
+      return inner_most_while.body().front().getTerminator();
+    }
   }
+  return destination_block->getTerminator();
 }
 
 // Collects and clusters ops in `block` with the same `_xla_outside_compilation`
@@ -279,18 +372,17 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
   return failure(walk_result.wasInterrupted());
 }
 
-// Moves `cluster_ops` to associated `block`.
-void MoveOutsideClusterOpsToBlock(Block& block,
-                                  llvm::ArrayRef<Operation*> cluster_ops,
-                                  MLIRContext* context) {
-  Operation* terminator = block.getTerminator();
+// Moves `cluster_ops` before `op`.
+void MoveOutsideClusterOpsBeforeOp(Operation* op,
+                                   llvm::ArrayRef<Operation*> cluster_ops,
+                                   MLIRContext* context) {
   for (Operation* cluster_op : cluster_ops) {
     // Remove `_xla_outside_compilation` and `device` attribute from ops in the
     // cluster as that information will be present in the `launch_op`.
     cluster_op->removeAttr(
         Identifier::get(kXlaOutsideCompilationAttr, context));
     cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
-    cluster_op->moveBefore(terminator);
+    cluster_op->moveBefore(op);
   }
 }
 
@@ -330,11 +422,18 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
         // in `host_cluster_ops`.
         for (Value v : op->getOperands()) {
           Operation* defining_op = v.getDefiningOp();
-          if (!defining_op) continue;
-          bool is_external = llvm::none_of(
-              host_cluster_ops,
-              [&](Operation* cluster_op) { return defining_op == cluster_op; });
-
+          bool is_external = false;
+          if (defining_op) {
+            is_external =
+                llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) {
+                  return defining_op == cluster_op;
+                });
+          } else {
+            if (auto block_arg = v.dyn_cast<BlockArgument>()) {
+              if (block_arg.getParentRegion() == cluster_op_parent_region)
+                is_external = true;
+            }
+          }
           if (is_external) external_values.insert(v);
         }
       } else {
@@ -432,24 +531,23 @@ void MoveOutsideCompiledOps(
         CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder);
   }
 
-  Block* block_to_move_host_cluster = nullptr;
+  Operation* insertion_op = nullptr;
   if (controlflow_stack.empty()) {
-    block_to_move_host_cluster = &host_launch_op.GetBody();
+    insertion_op = host_launch_op.GetBody().getTerminator();
   } else {
     int send_recv_counter = 0;
-    block_to_move_host_cluster = ReplicateControlFlowStack(
+    insertion_op = ReplicateControlFlowStack(
        outside_cluster_name, controlflow_stack, tpu_cluster, module,
        compilation_key, &host_launch_op.GetBody(), &send_recv_counter);
   }
 
   MLIRContext* context = host_launch_op.getContext();
   if (external_inputs.empty() && external_outputs.empty()) {
-    MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                                 context);
+    MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
     return;
   }
 
-  OpBuilder builder(block_to_move_host_cluster->getTerminator());
+  OpBuilder builder(insertion_op);
   llvm::SmallVector<Type, 4> host_output_types;
   for (const auto& external_input : external_inputs)
     host_output_types.push_back(external_input.getType());
@@ -470,10 +568,9 @@ void MoveOutsideCompiledOps(
   auto host_compute = CreateHostCompute(
       &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs,
       args_communication_key, retvals_communication_key);
-  MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                               context);
+  MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
 
-  builder.setInsertionPoint(block_to_move_host_cluster->getTerminator());
+  builder.setInsertionPoint(insertion_op);
   builder.create<TF::_XlaSendFromHostOp>(
       tpu_cluster.getLoc(), external_outputs,
       /*dynamic_key=*/compilation_key,