Handle WhileRegionOp in extract outside compilation.
Handled in a similar way as IfRegionOp. The main difference is that WhileRegionOp has a single body and the condition is communicated in the first block (the cond region), not before the control flow op.

PiperOrigin-RevId: 329516973
Change-Id: Ic7457d50be5174c962effba1bb16866c19601a68
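For intuition, a minimal sketch of the rewrite, assembled from the tests in this change rather than taken from actual pass output (operand names, types, and the elided attributes are illustrative):

    // Device side: the cond region computes the predicate and, after this pass,
    // also sends it to the host before yielding it.
    "tf.WhileRegion"(%b, %a) ({
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
      %pred = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
      "tf.XlaSendToHost"(%pred) {key = "while_condition_channel_cluster1_0"} : (tensor<i1>) -> ()
      "tf.Yield"(%pred) : (tensor<i1>) -> ()
    }, {
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
      %c = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
      // The outside compiled op is replaced by a host-compute placeholder.
      %d = "tf._XlaHostComputeMlir"(%arg1, %arg2) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
      "tf.Yield"(%c, %d) : (tensor<i32>, tensor<?xf32>) -> ()
    }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)

    // Host side: a zero-operand mirror loop; its cond region receives the
    // device's predicate, and the extracted ops run in its body.
    "tf.WhileRegion"() ({
      %pred = "tf._XlaRecvAtHost"(%key) {key = "while_condition_channel_cluster1_0", device_ordinal = 0 : i64} : (tensor<3x!tf.string>) -> tensor<i1>
      "tf.Yield"(%pred) : (tensor<i1>) -> ()
    }, {
      %in:2 = "tf._XlaRecvAtHost"(%key) : (tensor<3x!tf.string>) -> (tensor<i32>, tensor<?xf32>)
      %out = "tf.D"(%in#0, %in#1) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
      "tf._XlaSendFromHost"(%out, %key) : (tensor<?xf32>, tensor<3x!tf.string>) -> ()
      "tf.Yield"() : () -> ()
    }) {is_stateless = false} : () -> ()

The host loop carries no loop variables of its own; it runs in lockstep with the device loop and stops when the received predicate is false.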
parent 0037c1305a
commit 6e195c5d4e
tensorflow/compiler/mlir/tensorflow
@@ -819,4 +819,253 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
   return %1 : tensor<?xi32>
 }
 
+// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op body.
+
+// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_body
+func @outside_compiled_ops_inside_tf_while_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+  // CHECK-NEXT: "tf.WhileRegion"
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-SAME: device_ordinal = 0
+  // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+  // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+  // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: "tf.Yield"
+  // CHECK: "tf_device.cluster"
+  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+  // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+  // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+  // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+  // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    %2 = "tf_device.cluster"() ( {
+      %3 = "tf.A"() : () -> (tensor<?xf32>)
+      %4 = "tf.B"() : () -> (tensor<i32>)
+      %6 = "tf.G"() : () -> (tensor<i1>)
+
+      "tf.WhileRegion"(%4, %3) ({
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
+        "tf.Yield"(%7) : (tensor<i1>) -> ()
+      }, {
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+        %9 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+        "tf.Yield"(%8, %9) : (tensor<i32>, tensor<?xf32>) -> ()
+      }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
+
+      %5 = "tf.E"() : () -> tensor<?xi32>
+      tf_device.return %5 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    tf_device.return %2 : tensor<?xi32>
+  }
+
+  return %1 : tensor<?xi32>
+}
+
+// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond.
+
+// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond
+func @outside_compiled_ops_inside_tf_while_cond(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+  // CHECK-NEXT: "tf.WhileRegion"
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
+  // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-SAME: device_ordinal = 0
+  // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+  // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
+  // CHECK-NEXT: "tf.Yield"
+  // CHECK: "tf_device.cluster"
+  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+  // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+  // CHECK: "tf._XlaHostComputeMlir"
+  // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+  // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[D_OUTPUT]])
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    %2 = "tf_device.cluster"() ( {
+      %3 = "tf.A"() : () -> (tensor<?xf32>)
+      %4 = "tf.B"() : () -> (tensor<i32>)
+      %6 = "tf.G"() : () -> (tensor<i1>)
+
+      "tf.WhileRegion"(%4, %3) ({
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<i32>
+        %8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
+        "tf.Yield"(%8) : (tensor<i1>) -> ()
+      }, {
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+        %8 = "tf.D"(%arg1, %arg2) : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+        "tf.Yield"(%7, %8) : (tensor<i32>, tensor<?xf32>) -> ()
+      }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
+
+      %5 = "tf.E"() : () -> tensor<?xi32>
+      tf_device.return %5 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    tf_device.return %2 : tensor<?xi32>
+  }
+
+  return %1 : tensor<?xi32>
+}
+
+// Tests extraction of a single outside compiled cluster inside a tf.WhileRegion op cond and body.
+
+// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_cond_body
+func @outside_compiled_ops_inside_tf_while_cond_body(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+  // CHECK-NEXT: "tf.WhileRegion"
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-SAME: device_ordinal = 0
+  // CHECK-SAME: key = "while_condition_channel_cluster2_0"
+  // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+  // CHECK: %[[BODY_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: "tf.Yield"
+  // CHECK: "tf_device.launch"
+  // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+  // CHECK-NEXT: "tf.WhileRegion"
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT1:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[COND_RECV_OUTPUT1]]#0, %[[COND_RECV_OUTPUT1]]#1)
+  // CHECK-NEXT: "tf._XlaSendFromHost"(%[[I_OUTPUT]], %[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-SAME: device_ordinal = 0
+  // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+  // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT2]])
+  // CHECK-NEXT: "tf.Yield"
+  // CHECK: "tf_device.cluster"
+  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+  // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+  // CHECK: "tf._XlaHostComputeMlir"
+  // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+  // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+  // CHECK-NEXT: "tf.Yield"(%[[C_OUTPUT]], %[[HOST_COMPUTE_OUTPUT]])
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    %2 = "tf_device.cluster"() ( {
+      %3 = "tf.A"() : () -> (tensor<?xf32>)
+      %4 = "tf.B"() : () -> (tensor<i32>)
+      %6 = "tf.G"() : () -> (tensor<i1>)
+
+      "tf.WhileRegion"(%4, %3) ({
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.I"(%arg1, %arg2) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<?xf32>) -> tensor<i32>
+        %8 = "tf.H"(%7) : (tensor<i32>) -> tensor<i1>
+        "tf.Yield"(%8) : (tensor<i1>) -> ()
+      }, {
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+        %8 = "tf.D"(%arg1, %arg2) {_xla_outside_compilation = "cluster2"} : (tensor<i32>, tensor<?xf32>) -> tensor<?xf32>
+        "tf.Yield"(%7, %8) : (tensor<i32>, tensor<?xf32>) -> ()
+      }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
+
+      %5 = "tf.E"() : () -> tensor<?xi32>
+      tf_device.return %5 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    tf_device.return %2 : tensor<?xi32>
+  }
+
+  return %1 : tensor<?xi32>
+}
+
+// Tests extraction of a single outside compiled cluster inside a tf.IfRegion op
+// nested in a tf.WhileRegion.
+
+// CHECK-LABEL: func @outside_compiled_ops_inside_tf_while_if
+func @outside_compiled_ops_inside_tf_while_if(%arg0: tensor<?xi32>) -> tensor<?xi32> {
+  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
+
+  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
+  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
+  // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"()
+  // CHECK-NEXT: "tf.WhileRegion"
+  // CHECK-NEXT: %[[COND_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-SAME: device_ordinal = 0
+  // CHECK-SAME: key = "while_condition_channel_cluster1_0"
+  // CHECK: "tf.Yield"(%[[COND_RECV_OUTPUT]])
+  // CHECK: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: "tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]])
+  // CHECK: "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]])
+  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"
+  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]])
+  // CHECK-NEXT: "tf.Yield"
+  // CHECK: "tf_device.cluster"
+  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
+  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
+  // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G"
+  // CHECK-NEXT: "tf.WhileRegion"(%[[B_OUTPUT]], %[[A_OUTPUT]])
+  // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[H_OUTPUT]])
+  // CHECK-NEXT: "tf.Yield"(%[[H_OUTPUT]])
+  // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
+  // CHECK-NEXT: "tf.XlaSendToHost"(%[[G_OUTPUT]])
+  // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]])
+  // CHECK-NEXT: %[[HOST_COMPUTE_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"
+  // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUTPUT]])
+  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
+    %2 = "tf_device.cluster"() ( {
+      %3 = "tf.A"() : () -> (tensor<?xf32>)
+      %4 = "tf.B"() : () -> (tensor<i32>)
+      %6 = "tf.G"() : () -> (tensor<i1>)
+
+      "tf.WhileRegion"(%4, %3) ({
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %7 = "tf.H"(%arg1) : (tensor<i32>) -> tensor<i1>
+        "tf.Yield"(%7) : (tensor<i1>) -> ()
+      }, {
+      ^bb0(%arg1: tensor<i32>, %arg2: tensor<?xf32>):
+        %8 = "tf.C"(%arg1) : (tensor<i32>) -> tensor<i32>
+        %10 = "tf.IfRegion"(%6) ({
+          %9 = "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> tensor<?xf32>
+          "tf.Yield"(%9) : (tensor<?xf32>) -> ()
+        }, {
+          "tf.Yield"(%arg2) : (tensor<?xf32>) -> ()
+        }) {is_stateless = false} : (tensor<i1>) -> tensor<?xf32>
+        "tf.Yield"(%8, %10) : (tensor<i32>, tensor<?xf32>) -> ()
+      }) {is_stateless = false} : (tensor<i32>, tensor<?xf32>) -> (tensor<i32>, tensor<?xf32>)
+
+      %5 = "tf.E"() : () -> tensor<?xi32>
+      tf_device.return %5 : tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
+    tf_device.return %2 : tensor<?xi32>
+  }
+
+  return %1 : tensor<?xi32>
+}
 }
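(Aside: these CHECK lines are FileCheck-driven; the RUN line at the top of tpu_extract_outside_compilation.mlir sits outside this hunk. It presumably looks something like the line below, where the exact tf-opt pass flag is an assumption here, not part of this diff:

    // RUN: tf-opt %s -tf-tpu-extract-outside-compilation | FileCheck %s
)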
@@ -88,22 +88,30 @@ struct TPUExtractOutsideCompilation
 };
 
 // Holds information about control flow operations that wrap outside compiled
-// op. Currently only tf.If op is supported.
+// op. Currently only tf.IfRegion and tf.WhileRegion ops are supported.
 class ControlFlowStackInfo {
  public:
-  enum ControlFlowBranchType { kIfThen, kIfElse };
+  enum ControlFlowBranchType { kIfThen, kIfElse, kWhileCond, kWhileBody };
 
   explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op)
       : callsite_op_(wrapping_op) {
-    // Only tf.IfRegion op is supported for now.
-    auto control_flow_op = llvm::cast<TF::IfRegionOp>(callsite_op_);
-    assert(control_flow_op);
-
-    auto parent_region = nested_op->getParentRegion();
-    if (&control_flow_op.then_branch() == parent_region) {
-      type_ = ControlFlowBranchType::kIfThen;
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.then_branch() == parent_region) {
+        type_ = ControlFlowBranchType::kIfThen;
+      } else {
+        type_ = ControlFlowBranchType::kIfElse;
+      }
+    } else if (auto control_flow_op =
+                   llvm::dyn_cast<TF::WhileRegionOp>(callsite_op_)) {
+      auto parent_region = nested_op->getParentRegion();
+      if (&control_flow_op.cond() == parent_region) {
+        type_ = ControlFlowBranchType::kWhileCond;
+      } else {
+        type_ = ControlFlowBranchType::kWhileBody;
+      }
     } else {
-      type_ = ControlFlowBranchType::kIfElse;
+      assert(false);
     }
   }
@@ -133,7 +141,7 @@ llvm::SmallVector<ControlFlowStackInfo, 4> GetControlFlowStackForOp(
   Operation* op_in_stack = op;
   while (op_in_stack != tpu_cluster.getOperation()) {
     auto parent_op = op_in_stack->getParentOp();
-    if (llvm::isa<TF::IfRegionOp>(parent_op)) {
+    if (llvm::isa<TF::IfRegionOp, TF::WhileRegionOp>(parent_op)) {
       controlflow_stack.insert(controlflow_stack.begin(),
                                ControlFlowStackInfo(parent_op, op_in_stack));
     }
@@ -166,7 +174,7 @@ TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless,
 
 // Replicates tf.IfRegion op to host side computation.
 Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
-                       llvm::StringRef outside_cluster_name, ModuleOp module,
+                       llvm::StringRef outside_cluster_name,
                        Value compilation_key, OpBuilder* builder,
                        int* send_recv_counter) {
   // Create XlaSendToHostOp to send predicate value from device to host.
@@ -200,6 +208,63 @@ Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info,
                                    if_callsite_op.getLoc(), builder);
 }
 
+// Creates a WhileRegionOp with an empty body region terminated by a yield op;
+// the cond region is left for the caller to populate.
+TF::WhileRegionOp CloneEmptyWhile(bool is_stateless, APInt parallel_iterations,
+                                  Location loc, OpBuilder* builder) {
+  auto host_side_while = builder->create<TF::WhileRegionOp>(
+      loc, /*output=*/ArrayRef<Type>{}, /*input=*/ArrayRef<Value>{},
+      is_stateless, parallel_iterations);
+
+  // Create empty body region and terminate it with a yield op.
+  auto& body = host_side_while.body();
+  body.push_back(new Block);
+  builder->setInsertionPointToEnd(&body.front());
+  builder->create<TF::YieldOp>(loc, /*operands=*/ArrayRef<Value>{});
+  return host_side_while;
+}
+
+// Replicates tf.WhileRegion op to host side computation.
+Operation* ReplicateWhile(const ControlFlowStackInfo& controlflow_info,
+                          llvm::StringRef outside_cluster_name,
+                          Value compilation_key, OpBuilder* builder,
+                          int* send_recv_counter) {
+  // Create XlaSendToHostOp to send the cond region output from device to host.
+  OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
+  auto while_callsite_op =
+      llvm::cast<TF::WhileRegionOp>(controlflow_info.GetCallSiteOp());
+  builder->setInsertionPoint(while_callsite_op.cond().front().getTerminator());
+
+  const auto condition_send_recv_key =
+      llvm::formatv("while_condition_channel_{0}_{1}", outside_cluster_name,
+                    *send_recv_counter)
+          .str();
+  *send_recv_counter += 1;
+  auto condition =
+      while_callsite_op.cond().front().getTerminator()->getOperand(0);
+  builder->create<TF::XlaSendToHostOp>(while_callsite_op.getLoc(), condition,
+                                       condition_send_recv_key);
+  builder->restoreInsertionPoint(insert_point);
+
+  auto host_side_while = CloneEmptyWhile(
+      while_callsite_op.is_stateless(), while_callsite_op.parallel_iterations(),
+      while_callsite_op.getLoc(), builder);
+
+  // Create the cond region and yield the condition received from the device.
+  auto& cond = host_side_while.cond();
+  cond.push_back(new Block);
+  builder->setInsertionPointToEnd(&cond.front());
+  auto recv_condition_at_host = builder->create<TF::_XlaRecvAtHostOp>(
+      while_callsite_op.getLoc(), llvm::ArrayRef<Type>{condition.getType()},
+      /*dynamic_key=*/compilation_key,
+      builder->getStringAttr(condition_send_recv_key),
+      /*device_ordinal=*/builder->getI64IntegerAttr(0));
+  builder->create<TF::YieldOp>(while_callsite_op.getLoc(),
+                               recv_condition_at_host.getResults());
+
+  return host_side_while;
+}
+
 // TODO(b/157054714): Use a better abstraction instead of
 // _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp.
 // Creates a compilation key as placeholder. A placeholder compilation cache key
@@ -216,7 +281,7 @@ Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) {
 
 // Replicates the control flow operations that wrap outside compiled ops to
 // `destination_block`.
-Block* ReplicateControlFlowStack(
+Operation* ReplicateControlFlowStack(
     llvm::StringRef outside_cluster_name,
    const llvm::SmallVectorImpl<ControlFlowStackInfo>& stack_info,
     tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key,
@@ -227,32 +292,60 @@ Block* ReplicateControlFlowStack(
   for (const auto& controlflow_stack_info : stack_info) {
     // Create control flow op given provided insertion point and
     // ControlFlowStackInfo.
-    previous_replicated_controlflow_op =
-        ReplicateIf(controlflow_stack_info, outside_cluster_name, module,
-                    compilation_key, &builder, send_recv_counter);
-    auto if_op = llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-    auto type = controlflow_stack_info.GetBranchType();
+    if (auto control_flow_op = llvm::dyn_cast<TF::IfRegionOp>(
+            controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateIf(controlflow_stack_info, outside_cluster_name,
+                      compilation_key, &builder, send_recv_counter);
+      auto if_op =
+          llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
 
-    // Update the insertion point to proper region inside the newly created
-    // control flow op.
-    if (type == ControlFlowStackInfo::kIfThen) {
-      builder.setInsertionPoint(&if_op.then_branch().front().front());
-    } else {
-      builder.setInsertionPoint(&if_op.else_branch().front().front());
+      // Update the insertion point to proper region inside the newly created
+      // control flow op.
+      if (type == ControlFlowStackInfo::kIfThen) {
+        builder.setInsertionPoint(&if_op.then_branch().front().front());
+      } else {
+        builder.setInsertionPoint(&if_op.else_branch().front().front());
+      }
+    } else if (auto control_flow_op = llvm::dyn_cast<TF::WhileRegionOp>(
+                   controlflow_stack_info.GetCallSiteOp())) {
+      previous_replicated_controlflow_op =
+          ReplicateWhile(controlflow_stack_info, outside_cluster_name,
+                         compilation_key, &builder, send_recv_counter);
+      auto while_op =
+          llvm::cast<TF::WhileRegionOp>(previous_replicated_controlflow_op);
+      auto type = controlflow_stack_info.GetBranchType();
+      if (type == ControlFlowStackInfo::kWhileCond) {
+        builder.setInsertionPoint(&while_op.cond().front().front());
+      } else {
+        builder.setInsertionPoint(&while_op.body().front().front());
+      }
     }
   }
 
-  // Return the inner most branch at which outside compiled op is located.
-  // This block will later be used as insertion point to create send/recv ops.
-  auto inner_most_controlflow_stack = stack_info.back();
-  auto inner_most_if =
-      llvm::cast<TF::IfRegionOp>(previous_replicated_controlflow_op);
-  if (inner_most_controlflow_stack.GetBranchType() ==
-      ControlFlowStackInfo::kIfThen) {
-    return &inner_most_if.then_branch().front();
-  } else {
-    return &inner_most_if.else_branch().front();
+  // Return the operation which should be used as the insertion point to create
+  // send/recv ops.
+  if (auto inner_most_if =
+          llvm::dyn_cast<TF::IfRegionOp>(previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kIfThen) {
+      return inner_most_if.then_branch().front().getTerminator();
+    } else {
+      return inner_most_if.else_branch().front().getTerminator();
+    }
+  } else if (auto inner_most_while = llvm::dyn_cast<TF::WhileRegionOp>(
+                 previous_replicated_controlflow_op)) {
+    auto inner_most_controlflow_stack = stack_info.back();
+    if (inner_most_controlflow_stack.GetBranchType() ==
+        ControlFlowStackInfo::kWhileCond) {
+      return &inner_most_while.cond().front().front();
+    } else {
+      return inner_most_while.body().front().getTerminator();
+    }
   }
+  return destination_block->getTerminator();
 }
 
 // Collects and clusters ops in `block` with the same `_xla_outside_compilation`
@@ -279,18 +372,17 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
   return failure(walk_result.wasInterrupted());
 }
 
-// Moves `cluster_ops` to associated `block`.
-void MoveOutsideClusterOpsToBlock(Block& block,
-                                  llvm::ArrayRef<Operation*> cluster_ops,
-                                  MLIRContext* context) {
-  Operation* terminator = block.getTerminator();
+// Moves `cluster_ops` before `op`.
+void MoveOutsideClusterOpsBeforeOp(Operation* op,
+                                   llvm::ArrayRef<Operation*> cluster_ops,
+                                   MLIRContext* context) {
   for (Operation* cluster_op : cluster_ops) {
     // Remove `_xla_outside_compilation` and `device` attribute from ops in the
     // cluster as that information will be present in the `launch_op`.
     cluster_op->removeAttr(
         Identifier::get(kXlaOutsideCompilationAttr, context));
     cluster_op->removeAttr(Identifier::get(kDeviceAttr, context));
-    cluster_op->moveBefore(terminator);
+    cluster_op->moveBefore(op);
   }
 }
 
@@ -330,11 +422,18 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
       // in `host_cluster_ops`.
       for (Value v : op->getOperands()) {
         Operation* defining_op = v.getDefiningOp();
-        if (!defining_op) continue;
-        bool is_external = llvm::none_of(
-            host_cluster_ops,
-            [&](Operation* cluster_op) { return defining_op == cluster_op; });
-
+        bool is_external = false;
+        if (defining_op) {
+          is_external =
+              llvm::none_of(host_cluster_ops, [&](Operation* cluster_op) {
+                return defining_op == cluster_op;
+              });
+        } else {
+          if (auto block_arg = v.dyn_cast<BlockArgument>()) {
+            if (block_arg.getParentRegion() == cluster_op_parent_region)
+              is_external = true;
+          }
+        }
         if (is_external) external_values.insert(v);
       }
     } else {
@@ -432,24 +531,23 @@ void MoveOutsideCompiledOps(
         CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder);
   }
 
-  Block* block_to_move_host_cluster = nullptr;
+  Operation* insertion_op = nullptr;
   if (controlflow_stack.empty()) {
-    block_to_move_host_cluster = &host_launch_op.GetBody();
+    insertion_op = host_launch_op.GetBody().getTerminator();
   } else {
     int send_recv_counter = 0;
-    block_to_move_host_cluster = ReplicateControlFlowStack(
+    insertion_op = ReplicateControlFlowStack(
         outside_cluster_name, controlflow_stack, tpu_cluster, module,
         compilation_key, &host_launch_op.GetBody(), &send_recv_counter);
   }
 
   MLIRContext* context = host_launch_op.getContext();
   if (external_inputs.empty() && external_outputs.empty()) {
-    MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                                 context);
+    MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
     return;
   }
 
-  OpBuilder builder(block_to_move_host_cluster->getTerminator());
+  OpBuilder builder(insertion_op);
   llvm::SmallVector<Type, 4> host_output_types;
   for (const auto& external_input : external_inputs)
     host_output_types.push_back(external_input.getType());
@@ -470,10 +568,9 @@
   auto host_compute = CreateHostCompute(
       &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs,
       args_communication_key, retvals_communication_key);
-  MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
-                               context);
+  MoveOutsideClusterOpsBeforeOp(insertion_op, cluster_ops, context);
 
-  builder.setInsertionPoint(block_to_move_host_cluster->getTerminator());
+  builder.setInsertionPoint(insertion_op);
   builder.create<TF::_XlaSendFromHostOp>(
       tpu_cluster.getLoc(), external_outputs,
       /*dynamic_key=*/compilation_key,
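(A note on the channel naming: the key built in ReplicateWhile via llvm::formatv("while_condition_channel_{0}_{1}", ...) combines the outside compilation cluster name with a per-host-launch send/recv counter. This is exactly what the tests above match against: `while_condition_channel_cluster1_0`, and `while_condition_channel_cluster2_0` for the second cluster.)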