Update TPUExtractHeadTailOutsideCompilation pass to support tail outside compiled computation extraction.
This extends the pass to lift ops to after the cluster when they are outside compiled and have no dependencies other than other ops that are also to be lifted. PiperOrigin-RevId: 313877620 Change-Id: Ia5d068d74383206dc6ffaed06f429b7b93767271
This commit is contained in:
parent
481aec85d6
commit
00f027a374
|
@ -157,4 +157,232 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Tests that a single outside compiled op with no users, at the tail of the
// cluster, is extracted into a host launch placed after the cluster.
// CHECK-LABEL: func @tail_single_outside_compiled_op
func @tail_single_outside_compiled_op() {
  // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
  // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
  // CHECK-NEXT:   "tf.C"
  // CHECK-NEXT:   tf_device.return %[[A_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK:      "tf_device.launch"
  // CHECK-NEXT:   "tf.B"(%[[CLUSTER_OUT]])
  // CHECK-NEXT:   tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  "tf_device.cluster"() ( {
    %a = "tf.A"() : () -> tensor<i32>
    // "tf.B" is outside compiled and only consumes %a, so it can be lifted.
    "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
    "tf.C"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
  return
}
|
||||
|
||||
// Tests that a tail outside compiled op whose result is returned by the
// cluster is extracted into a host launch, and the launch result replaces the
// cluster result for external users.
// CHECK-LABEL: func @tail_single_outside_compiled_op_user
func @tail_single_outside_compiled_op_user() -> tensor<i32> {
  // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
  // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
  // CHECK-NEXT:   "tf.C"
  // CHECK-NEXT:   tf_device.return %[[A_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
  // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
  // CHECK-NEXT:   tf_device.return %[[B_OUT]]
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  %cluster = "tf_device.cluster"() ( {
    %a = "tf.A"() : () -> tensor<i32>
    %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    "tf.C"() : () -> ()
    tf_device.return %b : tensor<i32>
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
  // CHECK: return %[[LAUNCH_OUT]]
  return %cluster : tensor<i32>
}
|
||||
|
||||
// Tests that multiple chained tail outside compiled ops ("tf.C" feeding
// "tf.D") are extracted together into one host launch, with the values they
// consume forwarded as new cluster results.
// CHECK-LABEL: func @tail_multiple_outside_compiled_ops
func @tail_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
  // CHECK:      %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
  // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"
  // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"
  // CHECK-NEXT:   tf_device.return %[[B_OUT]], %[[A_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK:      "tf_device.launch"
  // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
  // CHECK-NEXT:   "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
  // CHECK-NEXT:   tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  "tf_device.cluster"() ( {
    %a = "tf.A"() : () -> tensor<i32>
    %b = "tf.B"(%arg0) : (tensor<i32>) -> tensor<i32>
    %c = "tf.C"(%arg0, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "tf.D"(%c, %arg0, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
  return
}
|
||||
|
||||
// Tests tail extraction when the cluster returns values defined outside the
// cluster (%a, %b) and a value produced by the extracted op (%d): aliased
// outputs are forwarded directly, and remaining results are remapped to the
// new cluster / launch results.
// CHECK-LABEL: func @tail_aliased_output
func @tail_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
  // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
  %a = "tf.A"() : () -> tensor<i32>
  // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
  %b = "tf.B"() : () -> tensor<i32>
  // CHECK:      %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
  // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"
  // CHECK-NEXT:   %[[E_OUT:.*]] = "tf.E"
  // CHECK-NEXT:   tf_device.return %[[C_OUT]], %[[E_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK:      %[[LAUNCH_OUT:.*]] = "tf_device.launch"
  // CHECK-NEXT:   %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
  // CHECK-NEXT:   tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  %cluster:5 = "tf_device.cluster"() ( {
    %c = "tf.C"() : () -> tensor<i32>
    %d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %e = "tf.E"() : () -> tensor<i32>
    tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
  // CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
  return %cluster#0, %cluster#1, %cluster#2, %cluster#3, %cluster#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
|
||||
|
||||
// Tests tail extraction inside a tf_device.replicate: the extracted launch is
// placed on the replicated host alias device rather than an explicit CPU.
// CHECK-LABEL: func @tail_replicated_outside_compilation
func @tail_replicated_outside_compilation(%arg0: tensor<i32>, %arg1: tensor<i32>) {
  // CHECK:      tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
  //
  // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
  // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"(%[[RI]])
  // CHECK-NEXT:   tf_device.return %[[A_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK-NEXT: "tf_device.launch"()
  // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
  // CHECK-NEXT:   tf_device.return
  // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
  tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
    "tf_device.cluster"() ( {
      %a = "tf.A"(%ri) : (tensor<i32>) -> tensor<i32>
      // Outside compiled, unused in the cluster, so liftable to the host.
      %b = "tf.B"(%a, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
      tf_device.return
    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
    tf_device.return
  }
  return
}
|
||||
|
||||
// Tests that an outside compiled op sandwiched between non-outside-compiled
// producers and consumers ("tf.A" -> "tf.B" -> "tf.C") is not extracted in
// either direction — no launch is created.
// CHECK-LABEL: func @head_tail_no_extraction_middle_outside_compiled_ops
func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
  // CHECK-NOT:  "tf_device.launch"
  // CHECK:      "tf_device.cluster"
  // CHECK-NEXT:   "tf.A"
  // CHECK-NEXT:   "tf.B"
  // CHECK-NEXT:   "tf.C"
  // CHECK-NEXT:   tf_device.return
  "tf_device.cluster"() ( {
    %a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
    %b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    "tf.C"(%b) : (tensor<i32>) -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
  return
}
|
||||
|
||||
// Tests combined head and tail extraction from one cluster: "tf.A" is lifted
// to a launch before the cluster, "tf.C" to a launch after it, leaving only
// "tf.B" inside.
// CHECK-LABEL: func @head_tail_simple_extraction
func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
  // CHECK:      %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
  // CHECK-NEXT:   %[[A_OUT:.*]] = "tf.A"(%arg0)
  // CHECK-NEXT:   tf_device.return %[[A_OUT]]
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  //
  // CHECK:      %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
  // CHECK-NEXT:   %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]])
  // CHECK-NEXT:   tf_device.return %[[B_OUT]]
  // CHECK-NEXT: {
  // CHECK-DAG:  num_cores_per_replica = 1
  // CHECK-DAG:  step_marker_location = ""
  // CHECK-DAG:  padding_map = []
  // CHECK-DAG:  topology = ""
  // CHECK-DAG:  device_assignment = []
  //
  // CHECK:      %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
  // CHECK-NEXT:   %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
  // CHECK-NEXT:   tf_device.return %[[C_OUT]]
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  %cluster = "tf_device.cluster"() ( {
    %a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    %b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
    %c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
    tf_device.return %c : tensor<i32>
  }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
  // CHECK: return %[[TAIL_LAUNCH_OUT]]
  return %cluster : tensor<i32>
}
|
||||
|
||||
// Tests combined head and tail extraction under tf_device.replicate: "tf.A"
// moves to a head launch, "tf.D" to a tail launch, while "tf.C" stays in the
// cluster because its result feeds the non-extractable "tf.E".
// CHECK-LABEL: func @head_tail_replicated_outside_compilation
func @head_tail_replicated_outside_compilation(%arg0: tensor<i32>, %arg1: tensor<i32>) {
  // CHECK:      tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
  //
  // CHECK-NEXT:   %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
  // CHECK-NEXT:     %[[A_OUT:.*]] = "tf.A"(%[[RI]])
  // CHECK-NEXT:     tf_device.return %[[A_OUT]]
  // CHECK-NEXT:   device = "TPU_REPLICATED_HOST"
  //
  // CHECK:        %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
  // CHECK-NEXT:     %[[B_OUT:.*]] = "tf.B"
  // CHECK-NEXT:     %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
  // CHECK-NEXT:     "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
  // CHECK-NEXT:     tf_device.return %[[C_OUT]]
  // CHECK-NEXT:   {
  // CHECK-DAG:    num_cores_per_replica = 1
  // CHECK-DAG:    step_marker_location = ""
  // CHECK-DAG:    padding_map = []
  // CHECK-DAG:    topology = ""
  // CHECK-DAG:    device_assignment = []
  //
  // CHECK-NEXT:   "tf_device.launch"()
  // CHECK-NEXT:     "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
  // CHECK-NEXT:     tf_device.return
  // CHECK-NEXT:   device = "TPU_REPLICATED_HOST"
  tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
    "tf_device.cluster"() ( {
      %a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
      %b = "tf.B"() : () -> tensor<i32>
      %c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
      %d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
      %e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      tf_device.return
    }) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
    tf_device.return
  }
  return
}
|
||||
}
|
||||
|
|
|
@ -78,9 +78,11 @@ bool OpInBlock(Operation* op, Block* block) {
|
|||
}
|
||||
|
||||
// Wraps block in a Launch. External uses of ops in the block will be return
|
||||
// values of the Launch and remapped to the Launch results.
|
||||
// values of the Launch and remapped to the Launch results. If `before` is set
|
||||
// to true, the Launch is created before `op`. Otherwise the Launch is created
|
||||
// after `op`.
|
||||
tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
|
||||
Block* launch_block,
|
||||
bool before, Block* launch_block,
|
||||
llvm::StringRef host_device) {
|
||||
// Find results and result types of ops in block that needs to returned.
|
||||
llvm::SmallVector<Value, 4> launch_results;
|
||||
|
@ -100,7 +102,7 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
|
|||
}
|
||||
}
|
||||
|
||||
builder->setInsertionPoint(op);
|
||||
before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
|
||||
auto launch = builder->create<tf_device::LaunchOp>(
|
||||
op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
|
||||
launch.body().push_back(launch_block);
|
||||
|
@ -178,22 +180,21 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
|||
Region* cluster_region = &cluster.body();
|
||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||
|
||||
auto walk_operands = [&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = GetOpOfValue(operand);
|
||||
if (head_outside_compiled_ops.count(operand_op)) continue;
|
||||
|
||||
if (operand_op->getParentRegion() == cluster_region)
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
};
|
||||
auto cluster_ops = cluster.GetBody().without_terminator();
|
||||
for (Operation& cluster_op : cluster_ops) {
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
// An outside compiled op can be extracted if its operands are not from
|
||||
// other ops in the cluster that cannot be extracted.
|
||||
auto walk_result = cluster_op.walk(walk_operands);
|
||||
auto walk_result = cluster_op.walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = GetOpOfValue(operand);
|
||||
if (head_outside_compiled_ops.count(operand_op)) continue;
|
||||
|
||||
if (operand_op->getParentRegion() == cluster_region)
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
if (!walk_result.wasInterrupted())
|
||||
head_outside_compiled_ops.insert(&cluster_op);
|
||||
|
@ -211,8 +212,8 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
|
|||
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
|
||||
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
|
||||
|
||||
tf_device::LaunchOp launch =
|
||||
CreateLaunchForBlock(builder, cluster, launch_block, host_device);
|
||||
tf_device::LaunchOp launch = CreateLaunchForBlock(
|
||||
builder, cluster, /*before=*/true, launch_block, host_device);
|
||||
|
||||
for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
|
||||
launch.getResults()))
|
||||
|
@ -220,6 +221,160 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
|
|||
cluster.body());
|
||||
}
|
||||
|
||||
// Extracts and moves outside compiled ops that have no dependencies in the
// cluster to before the cluster.
//
// On success, if any ops were lifted: `host_device` holds the host device the
// created launch was assigned (filled by GetHostDeviceForHeadTailComputation)
// and `cluster_updated` is set to true. If there are no head outside compiled
// ops, returns success without touching either out-parameter.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
    tf_device::ClusterOp cluster, std::string* host_device,
    bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
      FindOutsideCompiledOpsAtHead(cluster);
  // Nothing to lift; not an error.
  if (head_outside_compiled_ops.empty()) return success();
  if (failed(
          GetHostDeviceForHeadTailComputation(devices, cluster, host_device)))
    return failure();

  // Wrap the lifted ops in a launch on the host device, before the cluster.
  CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
                        *host_device);

  *cluster_updated = true;
  return success();
}
|
||||
|
||||
// Fills `tail_outside_compiled_ops` with ops that are outside compiled and
// can be extracted to after the TPU computation, and `cluster_results` with new
// results of the cluster. These ops are either connected to the output of the
// TPU computation or other ops that can be extracted, and have no results used
// by other ops in the TPU computation that cannot be extracted.
void FindOutsideCompiledOpsAtTailAndClusterResults(
    tf_device::ClusterOp cluster,
    llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
    llvm::SmallVectorImpl<Value>* cluster_results) {
  Region* cluster_region = &cluster.body();
  llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
  Operation* terminator = cluster.GetBody().getTerminator();
  // Seed the result set with the cluster's current results; it is adjusted
  // below as ops are marked for extraction.
  llvm::SmallSetVector<Value, 4> cluster_results_set;
  cluster_results_set.insert(terminator->getOperands().begin(),
                             terminator->getOperands().end());

  // Visit ops in reverse program order so every user of an op is processed
  // (and possibly marked extractable) before the op itself.
  auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
  for (Operation& cluster_op : cluster_ops) {
    if (!HasOutsideCompilationAttribute(&cluster_op)) continue;

    // NOTE(review): `results_to_forward` is never used in this function —
    // candidate for removal.
    llvm::SmallVector<int, 4> results_to_forward;
    // Extractable only if every user is the terminator or another op already
    // scheduled for extraction.
    bool can_be_extracted =
        llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
          return op == terminator || tail_outside_compiled_ops_set.count(op);
        });
    if (!can_be_extracted) continue;

    // Collect operands of cluster op that are generated within the cluster.
    // These values should be returned by the cluster.
    cluster_op.walk([&](Operation* op) {
      for (Value operand : op->getOperands()) {
        Operation* operand_op = GetOpOfValue(operand);
        if (operand_op->getParentRegion() == cluster_region)
          cluster_results_set.insert(operand);
      }
    });

    // Remove results of op to be extracted as there are no uses in the cluster.
    for (Value result : cluster_op.getResults())
      cluster_results_set.remove(result);
    tail_outside_compiled_ops_set.insert(&cluster_op);
  }

  *tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector();
  *cluster_results = cluster_results_set.takeVector();
}
|
||||
|
||||
// Moves tail outside compiled ops into its own `tf_device.LaunchOp`
// computation after the cluster.
//
// `tail_outside_compiled_ops` is expected in reverse program order (as
// produced by FindOutsideCompiledOpsAtTailAndClusterResults); inserting each
// op at the front of the launch block restores the original program order.
void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
                           llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
                           llvm::StringRef host_device) {
  // Block ownership is transferred to the launch via push_back inside
  // CreateLaunchForBlock.
  Block* launch_block = new Block;
  for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops)
    tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());

  // /*before=*/false places the launch immediately after the cluster.
  tf_device::LaunchOp launch = CreateLaunchForBlock(
      builder, cluster, /*before=*/false, launch_block, host_device);

  // Remap external uses of the moved ops' results to the launch results,
  // leaving uses inside the launch itself untouched.
  auto operand_not_in_launch = [&](OpOperand& operand) {
    return !launch.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
                               launch.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_launch);
}
|
||||
|
||||
// Updates cluster with updated cluster results after extracting tail outside
// compiled ops.
//
// Rebuilds the cluster op with `new_cluster_results` as its result list: the
// old terminator is replaced, a new cluster op with matching result types is
// created, the body is moved over wholesale, external uses are remapped to the
// new results, and the old cluster is erased. Returns the replacement cluster.
tf_device::ClusterOp UpdateClusterResults(
    OpBuilder* builder, tf_device::ClusterOp cluster,
    llvm::ArrayRef<Value> new_cluster_results) {
  // Swap the terminator for one returning the new result set.
  Operation* old_terminator = cluster.GetBody().getTerminator();
  builder->setInsertionPoint(old_terminator);
  builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
                                       new_cluster_results);
  old_terminator->erase();

  builder->setInsertionPoint(cluster);
  llvm::SmallVector<Type, 4> new_cluster_result_types;
  new_cluster_result_types.reserve(new_cluster_results.size());
  for (const auto& new_cluster_result : new_cluster_results)
    new_cluster_result_types.push_back(new_cluster_result.getType());

  // New cluster keeps the old attributes; the body region is moved, not
  // cloned.
  auto new_cluster = builder->create<tf_device::ClusterOp>(
      cluster.getLoc(), new_cluster_result_types,
      /*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
  new_cluster.body().takeBody(cluster.body());

  // Remap only uses outside the new cluster; internal uses already refer to
  // the moved body values.
  auto operand_not_in_cluster = [&](OpOperand& operand) {
    return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
  };
  for (auto result :
       llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
                 new_cluster.getResults()))
    std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
                                          operand_not_in_cluster);

  cluster.erase();
  return new_cluster;
}
|
||||
|
||||
// Extracts and moves outside compiled ops that do not create dependencies in
// the cluster to after the cluster.
//
// `host_device` may be pre-populated by a prior head extraction; it is only
// recomputed when empty. On lifting, `cluster` is replaced in-place with the
// rebuilt cluster op and `cluster_updated` is set to true.
mlir::LogicalResult LiftTailOutsideCompiledOps(
    OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
    std::string host_device, tf_device::ClusterOp* cluster,
    bool* cluster_updated) {
  llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
  llvm::SmallVector<Value, 4> cluster_results;
  FindOutsideCompiledOpsAtTailAndClusterResults(
      *cluster, &tail_outside_compiled_ops, &cluster_results);
  // Nothing to lift; not an error.
  if (tail_outside_compiled_ops.empty()) return success();

  // Reuse the host device from head extraction if one was already determined.
  if (host_device.empty())
    if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster,
                                                   &host_device)))
      return failure();

  // Forward all results of cluster first. These results will be remapped once
  // a new cluster is formed.
  cluster->replaceAllUsesWith(
      cluster->GetBody().getTerminator()->getOperands());

  CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
                        host_device);

  // Rebuild the cluster with the reduced result list computed above.
  *cluster = UpdateClusterResults(builder, *cluster, cluster_results);

  *cluster_updated = true;
  return success();
}
|
||||
|
||||
// Removes aliased outputs in cluster from ops outside of cluster.
|
||||
void RemoveClusterAliasedOutputs(OpBuilder* builder,
|
||||
tf_device::ClusterOp cluster) {
|
||||
|
@ -256,26 +411,6 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder,
|
|||
cluster.erase();
|
||||
}
|
||||
|
||||
// Extracts and move outside compiled ops that have no dependencies in the
|
||||
// cluster to before the cluster.
|
||||
mlir::LogicalResult LiftHeadOutsideCompiledOps(
|
||||
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
|
||||
tf_device::ClusterOp cluster) {
|
||||
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||
FindOutsideCompiledOpsAtHead(cluster);
|
||||
if (head_outside_compiled_ops.empty()) return success();
|
||||
std::string host_device;
|
||||
if (failed(
|
||||
GetHostDeviceForHeadTailComputation(devices, cluster, &host_device)))
|
||||
return failure();
|
||||
|
||||
CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
|
||||
host_device);
|
||||
|
||||
RemoveClusterAliasedOutputs(builder, cluster);
|
||||
return success();
|
||||
}
|
||||
|
||||
struct TPUExtractHeadTailOutsideCompilation
|
||||
: public PassWrapper<TPUExtractHeadTailOutsideCompilation,
|
||||
OperationPass<ModuleOp>> {
|
||||
|
@ -295,10 +430,14 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
|||
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
|
||||
|
||||
for (tf_device::ClusterOp cluster : clusters) {
|
||||
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster)))
|
||||
std::string host_device;
|
||||
bool cluster_updated = false;
|
||||
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
|
||||
&host_device, &cluster_updated)) ||
|
||||
failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
|
||||
&cluster, &cluster_updated)))
|
||||
return signalPassFailure();
|
||||
|
||||
// TODO(b/157160906): Implement tail outside compiled op extraction.
|
||||
if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue