Handle WhileRegionOp in extract outside compilation.

Handled in a similar way to IfRegionOp. The main difference is that WhileRegionOp has a single body region rather than separate branches, and the loop condition is communicated from within the first (cond) region instead of before the control flow op.
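
A minimal sketch of the condition handshake (the op and key names follow the tests below; the loop-carried values, types, and surrounding cluster/launch ops are illustrative only):

// Device side: the cond region computes the predicate, sends it to the host
// over a named channel, and still yields it as the loop condition.
"tf.WhileRegion"(%init) ({
^bb0(%arg: tensor<i32>):
  %pred = "tf.H"(%arg) : (tensor<i32>) -> tensor<i1>
  "tf.XlaSendToHost"(%pred) {key = "while_condition_channel_cluster1_0"} : (tensor<i1>) -> ()
  "tf.Yield"(%pred) : (tensor<i1>) -> ()
}, {
^bb0(%arg: tensor<i32>):
  %next = "tf.C"(%arg) : (tensor<i32>) -> tensor<i32>
  "tf.Yield"(%next) : (tensor<i32>) -> ()
}) {is_stateless = false} : (tensor<i32>) -> tensor<i32>

// Host side: the cloned loop carries no values; its cond region only receives
// that predicate and yields it, so both loops iterate in lock step.
%key = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<3x!tf.string>
"tf.WhileRegion"() ({
  %pred = "tf._XlaRecvAtHost"(%key) {key = "while_condition_channel_cluster1_0", device_ordinal = 0 : i64} : (tensor<3x!tf.string>) -> tensor<i1>
  "tf.Yield"(%pred) : (tensor<i1>) -> ()
}, {
  // The outside compiled ops are moved here, before this yield.
  "tf.Yield"() : () -> ()
}) {is_stateless = false} : () -> ()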

PiperOrigin-RevId: 329516973
Change-Id: Ic7457d50be5174c962effba1bb16866c19601a68
Ken Franko, 2020-09-01 09:01:57 -07:00, committed by TensorFlower Gardener
parent 0037c1305a
commit 6e195c5d4e
2 changed files with 400 additions and 54 deletions
tensorflow/compiler/mlir/tensorflow


@@ -819,4 +819,253 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body
func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
// CHECK-NEXT: "tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "while_condition_channel_cluster1_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
"tf.Yield"(%8, %9) : (tensor<i32>, tensor<?xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond
func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
// CHECK-NEXT: "tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "while_condition_channel_cluster1_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK: "tf.Yield"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK "tf.XlaHostCompute"
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: "tf.D"
// CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<i32>
%8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%8) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%8 = "tf.D"(%arg1, %arg2) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
"tf.Yield"(%7, %8) : (tensor<i32>, tensor<?xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body
func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
// CHECK-NEXT: "tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "while_condition_channel_cluster2_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.launch"
// CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
// CHECK-NEXT: "tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
// CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
// CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "while_condition_channel_cluster1_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
// CHECK: "tf.Yield"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK "tf.XlaHostCompute"
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<i32>
%8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%8) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
"tf.Yield"(%7, %8) : (tensor<i32>, tensor<?xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
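// Note: in the test above, the device loop's condition is sent to the host
// twice (two tf.XlaSendToHost ops), because each outside compiled cluster
// ("cluster1" in the cond, "cluster2" in the body) gets its own host-side
// launch with its own cloned tf.WhileRegion, and each clone reads its own
// while_condition_channel key.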
// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op
// nested in a tf.WhileRegion.
// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if
func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
// CHECK-NEXT: "tf.WhileRegion"
// CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-SAME: device_ordinal = 0
// CHECK-SAME: key = "while_condition_channel_cluster1_0"
// CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
// CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
// CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
// CHECK-NEXT: "tf.Yield"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
// CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
// CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
// CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
// CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
// CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
// CHECK-NEXT "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xf32>)
%4 = "tf.B"() : () -> (tensor<i32>)
%6 = "tf.G"() : () -> (tensor<i1>)
"tf.WhileRegion"(%4, %3) ({
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
"tf.Yield"(%7) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
%8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
%10 = "tf.IfRegion"(%6) ({
%9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor<?xf32>
"tf.Yield"(%9) : (tensor<?xf32>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<?xf32>) -> ()
}) { is_stateless = false} : (tensor<i1>) -> tensor<?xf32>
"tf.Yield"(%8, %10) : (tensor<i32>, tensor<?xf32>) -> ()
}) { is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
}


@@ -88,22 +88,30 @@ struct TPUExtractOutsideCompilation
};
// Holds information about control flow operations that wrap outside compiled
-// op. Currently only tf.If op is supported.
+// op. Currently only tf.IfRegion and tf.WhileRegion ops are supported.
class ControlFlowStackInfo {
public:
-  enum ControlFlowBranchType { kIfThen, kIfElse };
+  enum ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody };
explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op)
: callsite_op_(wrapping_op) {
-    // Only tf.IfRegion op is supported for now.
-    auto control_flow_op = llvm::cast<TF::IfRegionOp>(callsite_op_);
-    assert(control_flow_op);
-    auto parent_region = nested_op->getParentRegion();
-    if (&control_flow_op.then_branch() == parent_region) {
-      type_ = ControlFlowBranchType::kIfThen;
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.then_branch() == parent_region) {
+        type_ = ControlFlowBranchType::kIfThen;
+      } else {
+        type_ = ControlFlowBranchType::kIfElse;
+      }
+    } else if (auto control_flow_op =
+                   llvm::dyn_cast<TF::WhileRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.cond() == parent_region) {
+        type_ = ControlFlowBranchType::kWhileCond;
+      } else {
+        type_ = ControlFlowBranchType::kWhileBody;
+      }
     } else {
-      type_ = ControlFlowBranchType::kIfElse;
+      assert(false);
}
}
@@ -133,7 +141,7 @@ llvm::SmallVector<ControlFlowStackInfo, 4> GetControlFlowStackForOp(
Operation* op_in_stack = op;
while (op_in_stack != tpu_cluster.getOperation()) {
auto parent_op = op_in_stack->getParentOp();
-    if (llvm::isa<TF::IfRegionOp>(parent_op)) {
+    if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(parent_op)) {
controlflow_stack.insert(controlflow_stack.begin(),
ControlFlowStackInfo(parent_op, op_in_stack));
}
@@ -166,7 +174,7 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless,
// Replicates tf.IfRegion op to host side computation.
Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
-                       llvm::StringRef outside_cluster_name, ModuleOp module,
+                       llvm::StringRef outside_cluster_name,
Value compilation_key, OpBuilder* builder,
int* send_recv_counter) {
// Create XlaSendToHostOp to send predicate value from device to host.
@@ -200,6 +208,63 @@ Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
if_callsite_op.getLoc(), builder);
}
// Creates a WhileRegionOp with an empty cond region and a body region that
// contains only a yield op.
TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, APInt parallel_iterations,
Location loc, OpBuilder* builder) {
auto host_side_while = builder->create<TF::WhileRegionOp>(
loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
is_stateless, parallel_iterations);
// Create an empty body region terminated by a yield op.
auto& body = host_side_while.body();
body.push_back(new Block);
builder->setInsertionPointToEnd(&body.front());
builder->create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
return host_side_while;
}
// Replicates tf.WhileRegion op to host side computation.
Operation* ReplicateWhile(const ControlFlowStackInfo& controlflow_info,
llvm::StringRef outside_cluster_name,
Value compilation_key, OpBuilder* builder,
int* send_recv_counter) {
// Create XlaSendToHostOp to send cond region output from device to host.
OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
auto while_callsite_op =
llvm::cast<TF::WhileRegionOp>(controlflow_info.GetCallSiteOp());
builder->setInsertionPoint(while_callsite_op.cond().front().getTerminator());
const auto condition_send_recv_key =
llvm::formatv("while_condition_channel_{0}_{1}", outside_cluster_name,
*send_recv_counter)
.str();
*send_recv_counter += 1;
auto condition =
while_callsite_op.cond().front().getTerminator()->getOperand(0);
builder->create<TF::XlaSendToHostOp>(while_callsite_op.getLoc(), condition,
condition_send_recv_key);
builder->restoreInsertionPoint(insert_point);
auto host_side_while = CloneEmptyWhile(
while_callsite_op.is_stateless(), while_callsite_op.parallel_iterations(),
while_callsite_op.getLoc(), builder);
// Create cond region and yield the condition from the device.
auto& cond = host_side_while.cond();
cond.push_back(new Block);
builder->setInsertionPointToEnd(&cond.front());
auto recv_condition_at_host = builder->create<TF::_XlaRecvAtHostOp>(
while_callsite_op.getLoc(), llvm::ArrayRef<Type>{condition.getType()},
/*dynamic_key=*/compilation_key,
builder->getStringAttr(condition_send_recv_key),
/*device_ordinal=*/builder->getI64IntegerAttr(0));
builder->create<TF::YieldOp>(while_callsite_op.getLoc(),
recv_condition_at_host.getResults());
return host_side_while;
}
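// Note: the host-side clone built by CloneEmptyWhile carries no loop
// variables (its inputs and outputs are empty); only the condition crosses
// this channel. Operands that the outside compiled ops themselves need are
// communicated separately by the send/recv ops added in
// MoveOutsideCompiledOps.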
// TODO(b/157054714): Use a better abstraction instead of
// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
// Creates a compilation key as placeholder. A placeholder compilation cache key
@@ -216,7 +281,7 @@ Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) {
// Replicates the control flow operations that wraps outside compiled ops to
// `destination_block`.
-Block* ReplicateControlFlowStack(
+Operation* ReplicateControlFlowStack(
llvm::StringRef outside_cluster_name,
const llvm::SmallVectorImpl<ControlFlowStackInfo>& stack_info,
tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key,
@@ -227,32 +292,60 @@ Block* ReplicateControlFlowStack(
for (const auto& controlflow_stack_info : stack_info) {
// Create control flow op given provided insertion point and
// ControlFlowStackInfo.
-    previous_replicated_controlflow_op =
-        ReplicateIf(controlflow_stack_info, outside_cluster_name, module,
-                    compilation_key, &builder, send_recv_counter);
-    auto if_op = llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-    auto type = controlflow_stack_info.GetBranchType();
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(
+            controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateIf(controlflow_stack_info, outside_cluster_name,
+                      compilation_key, &builder, send_recv_counter);
+      auto if_op =
+          llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
-    // Update the insertion point to proper region inside the newly created
-    // control flow op.
-    if (type == ControlFlowStackInfo::kIfThen) {
-      builder.setInsertionPoint(&if_op.then_branch().front().front());
-    } else {
-      builder.setInsertionPoint(&if_op.else_branch().front().front());
+      // Update the insertion point to proper region inside the newly created
+      // control flow op.
+      if (type == ControlFlowStackInfo::kIfThen) {
+        builder.setInsertionPoint(&if_op.then_branch().front().front());
+      } else {
+        builder.setInsertionPoint(&if_op.else_branch().front().front());
+      }
+    } else if (auto control_flow_op = llvm::dyn_cast<TF::WhileRegionOp>(
+                   controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateWhile(controlflow_stack_info, outside_cluster_name,
+                         compilation_key, &builder, send_recv_counter);
+      auto while_op =
+          llvm::cast<TF::WhileRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
+      if (type == ControlFlowStackInfo::kWhileCond) {
+        builder.setInsertionPoint(&while_op.cond().front().front());
+      } else {
+        builder.setInsertionPoint(&while_op.body().front().front());
+      }
+    }
  }
-  // Return the inner most branch at which outside compiled op is located.
-  // This block will later be used as insertion point to create send/recv ops.
-  auto inner_most_controlflow_stack = stack_info.back();
-  auto inner_most_if =
-      llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-  if (inner_most_controlflow_stack.GetBranchType() ==
-      ControlFlowStackInfo::kIfThen) {
-    return &inner_most_if.then_branch().front();
-  } else {
-    return &inner_most_if.else_branch().front();
+  // Return the operation which should be used as the insertion point to
+  // create send/recv ops.
+  if (auto inner_most_if =
+          llvm::dyn_cast<TF::IfRegionOp>(previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kIfThen) {
+      return inner_most_if.then_branch().front().getTerminator();
+    } else {
+      return inner_most_if.else_branch().front().getTerminator();
+    }
+  } else if (auto inner_most_while = llvm::dyn_cast<TF::WhileRegionOp>(
+                 previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kWhileCond) {
+      return &inner_most_while.cond().front().front();
+    } else {
+      return inner_most_while.body().front().getTerminator();
+    }
+  }
+  return destination_block->getTerminator();
}
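// Note the asymmetry above: for a cluster in a while cond, the insertion
// point is the first op of the host-side cond region, so the moved ops run
// before the condition recv that ReplicateWhile placed there; for a while
// body (and for if branches) it is the region's yield terminator.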
// Collects and clusters ops in `block` with the same `_xla_outside_compilation`
@@ -279,18 +372,17 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
return failure(walk_result.wasInterrupted());
}
-// Moves `cluster_ops` to associated `block`.
-void MoveOutsideClusterOpsToBlock(Block& block,
-                                  llvm::ArrayRef<Operation*> cluster_ops,
-                                  MLIRContext* context) {
-  Operation* terminator = block.getTerminator();
+// Moves `cluster_ops` before `op`.
+void MoveOutsideClusterOpsBeforeOp(Operation* op,
+                                   llvm::ArrayRef<Operation*> cluster_ops,
+                                   MLIRContext* context) {
  for (Operation* cluster_op : cluster_ops) {
    // Remove `_xla_outside_compilation` and `device` attribute from ops in the
    // cluster as that information will be present in the `launch_op`.
    cluster_op->removeAttr(
        Identifier::get(kXlaOutsideCompilationAttr, context));
    cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
-    cluster_op->moveBefore(terminator);
+    cluster_op->moveBefore(op);
}
}
@@ -330,11 +422,18 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
// in `host_cluster_ops`.
for (Value v : op->getOperands()) {
Operation* defining_op = v.getDefiningOp();
-      if (!defining_op) continue;
-      bool is_external = llvm::none_of(
-          host_cluster_ops,
-          [&](Operation* cluster_op) { return defining_op == cluster_op; });
+      bool is_external = false;
+      if (defining_op) {
+        is_external =
+            llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) {
+              return defining_op == cluster_op;
+            });
+      } else {
+        if (auto block_arg = v.dyn_cast<BlockArgument>()) {
+          if (block_arg.getParentRegion() == cluster_op_parent_region)
+            is_external = true;
+        }
+      }
if (is_external) external_values.insert(v);
}
} else {
@@ -432,24 +531,23 @@ void MoveOutsideCompiledOps(
CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder);
}
-  Block* block_to_move_host_cluster = nullptr;
+  Operation* insertion_op = nullptr;
  if (controlflow_stack.empty()) {
-    block_to_move_host_cluster = &host_launch_op.GetBody();
+    insertion_op = host_launch_op.GetBody().getTerminator();
  } else {
    int send_recv_counter = 0;
-    block_to_move_host_cluster = ReplicateControlFlowStack(
+    insertion_op = ReplicateControlFlowStack(
        outside_cluster_name, controlflow_stack, tpu_cluster, module,
        compilation_key, &host_launch_op.GetBody(), &send_recv_counter);
  }
  MLIRContext* context = host_launch_op.getContext();
  if (external_inputs.empty() && external_outputs.empty()) {
-    MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                                 context);
+    MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
    return;
  }
-  OpBuilder builder(block_to_move_host_cluster->getTerminator());
+  OpBuilder builder(insertion_op);
llvm::SmallVector<Type, 4> host_output_types;
for (const auto& external_input : external_inputs)
host_output_types.push_back(external_input.getType());
@@ -470,10 +568,9 @@ void MoveOutsideCompiledOps(
auto host_compute = CreateHostCompute(
&builder, tpu_cluster, cluster_ops, external_inputs, external_outputs,
args_communication_key, retvals_communication_key);
-  MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                               context);
+  MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
-  builder.setInsertionPoint(block_to_move_host_cluster->getTerminator());
+  builder.setInsertionPoint(insertion_op);
builder.create<TF::_XlaSendFromHostOp>(
tpu_cluster.getLoc(), external_outputs,
/*dynamic_key=*/compilation_key,