Update TPUExtractHeadTailOutsideCompilation pass to support tail outside compiled computation extraction.

This extends the pass to lift ops to after the cluster if they are outside compiled and have no dependencies, other than other ops that are to be lifted.

PiperOrigin-RevId: 313877620
Change-Id: Ia5d068d74383206dc6ffaed06f429b7b93767271
This commit is contained in:
Andy Ly 2020-05-29 17:12:33 -07:00 committed by TensorFlower Gardener
parent 481aec85d6
commit 00f027a374
2 changed files with 406 additions and 39 deletions

View File

@ -157,4 +157,232 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
} }
return return
} }
// CHECK-LABEL: func @tail_single_outside_compiled_op
func @tail_single_outside_compiled_op() {
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
"tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @tail_single_outside_compiled_op_user
func @tail_single_outside_compiled_op_user() -> tensor<i32> {
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
"tf.C"() : () -> ()
tf_device.return %b : tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
// CHECK: return %[[LAUNCH_OUT]]
return %cluster : tensor<i32>
}
// CHECK-LABEL: func @tail_multiple_outside_compiled_ops
func @tail_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: tf_device.return %[[B_OUT]], %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
// CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
%a = "tf.A"() : () -> tensor<i32>
%b = "tf.B"(%arg0) : (tensor<i32>) -> tensor<i32>
%c = "tf.C"(%arg0, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.D"(%c, %arg0, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @tail_aliased_output
func @tail_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
%a = "tf.A"() : () -> tensor<i32>
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
%b = "tf.B"() : () -> tensor<i32>
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
// CHECK-NEXT: %[[E_OUT:.*]] = "tf.E"
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[E_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster:5 = "tf_device.cluster"() ( {
%c = "tf.C"() : () -> tensor<i32>
%d = "tf.D"(%c, %a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%e = "tf.E"() : () -> tensor<i32>
tf_device.return %a, %b, %c, %d, %e : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
// CHECK: return %[[A_OUT]], %[[B_OUT]], %[[CLUSTER_OUT]]#0, %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#1
return %cluster#0, %cluster#1, %cluster#2, %cluster#3, %cluster#4 : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
}
// CHECK-LABEL: func @tail_replicated_outside_compilation
func @tail_replicated_outside_compilation(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
//
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
"tf_device.cluster"() ( {
%a = "tf.A"(%ri) : (tensor<i32>) -> tensor<i32>
%b = "tf.B"(%a, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func @head_tail_no_extraction_middle_outside_compiled_ops
func @head_tail_no_extraction_middle_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK-NOT: "tf_device.launch"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: tf_device.return
"tf_device.cluster"() ( {
%a = "tf.A"(%arg0) : (tensor<i32>) -> tensor<i32>
%b = "tf.B"(%a) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
"tf.C"(%b) : (tensor<i32>) -> ()
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
return
}
// CHECK-LABEL: func @head_tail_simple_extraction
func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0)
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[HEAD_LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
%a = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
%b = "tf.B"(%a) : (tensor<i32>) -> tensor<i32>
%c = "tf.C"(%b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
tf_device.return %c : tensor<i32>
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> tensor<i32>
// CHECK: return %[[TAIL_LAUNCH_OUT]]
return %cluster : tensor<i32>
}
// CHECK-LABEL: func @head_tail_replicated_outside_compilation
func @head_tail_replicated_outside_compilation(%arg0: tensor<i32>, %arg1: tensor<i32>) {
// CHECK: tf_device.replicate([%arg0, %arg1] as %[[RI:.*]]: tensor<i32>)
//
// CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
//
// CHECK: %[[CLUSTER_OUT:.*]] = "tf_device.cluster"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[RI]], %[[B_OUT]])
// CHECK-NEXT: "tf.E"(%[[C_OUT]], %[[HEAD_LAUNCH_OUT]])
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: {
// CHECK-DAG: num_cores_per_replica = 1
// CHECK-DAG: step_marker_location = ""
// CHECK-DAG: padding_map = []
// CHECK-DAG: topology = ""
// CHECK-DAG: device_assignment = []
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
"tf_device.cluster"() ( {
%a = "tf.A"(%ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> tensor<i32>
%b = "tf.B"() : () -> tensor<i32>
%c = "tf.C"(%ri, %b) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> tensor<i32>
%d = "tf.D"(%a, %c, %ri) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<i32>
%e = "tf.E"(%c, %a) : (tensor<i32>, tensor<i32>) -> tensor<i32>
tf_device.return
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
tf_device.return
}
return
}
} }

View File

@ -78,9 +78,11 @@ bool OpInBlock(Operation* op, Block* block) {
} }
// Wraps block in a Launch. External uses of ops in the block will be return // Wraps block in a Launch. External uses of ops in the block will be return
// values of the Launch and remapped to the Launch results. // values of the Launch and remapped to the Launch results. If `before` is set
// to true, the Launch is created before `op`. Otherwise the Launch is created
// after `op`.
tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op, tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
Block* launch_block, bool before, Block* launch_block,
llvm::StringRef host_device) { llvm::StringRef host_device) {
// Find results and result types of ops in block that needs to returned. // Find results and result types of ops in block that needs to returned.
llvm::SmallVector<Value, 4> launch_results; llvm::SmallVector<Value, 4> launch_results;
@ -100,7 +102,7 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
} }
} }
builder->setInsertionPoint(op); before ? builder->setInsertionPoint(op) : builder->setInsertionPointAfter(op);
auto launch = builder->create<tf_device::LaunchOp>( auto launch = builder->create<tf_device::LaunchOp>(
op->getLoc(), builder->getStringAttr(host_device), launch_result_types); op->getLoc(), builder->getStringAttr(host_device), launch_result_types);
launch.body().push_back(launch_block); launch.body().push_back(launch_block);
@ -178,7 +180,12 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
Region* cluster_region = &cluster.body(); Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops; llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
auto walk_operands = [&](Operation* op) { auto cluster_ops = cluster.GetBody().without_terminator();
for (Operation& cluster_op : cluster_ops) {
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
// An outside compiled op can be extracted if its operands are not from
// other ops in the cluster that cannot be extracted.
auto walk_result = cluster_op.walk([&](Operation* op) {
for (Value operand : op->getOperands()) { for (Value operand : op->getOperands()) {
Operation* operand_op = GetOpOfValue(operand); Operation* operand_op = GetOpOfValue(operand);
if (head_outside_compiled_ops.count(operand_op)) continue; if (head_outside_compiled_ops.count(operand_op)) continue;
@ -187,13 +194,7 @@ llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
return WalkResult::interrupt(); return WalkResult::interrupt();
} }
return WalkResult::advance(); return WalkResult::advance();
}; });
auto cluster_ops = cluster.GetBody().without_terminator();
for (Operation& cluster_op : cluster_ops) {
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
// An outside compiled op can be extracted if its operands are not from
// other ops in the cluster that cannot be extracted.
auto walk_result = cluster_op.walk(walk_operands);
if (!walk_result.wasInterrupted()) if (!walk_result.wasInterrupted())
head_outside_compiled_ops.insert(&cluster_op); head_outside_compiled_ops.insert(&cluster_op);
@ -211,8 +212,8 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
for (Operation* head_outside_compiled_op : head_outside_compiled_ops) for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
tf_device::LaunchOp launch = tf_device::LaunchOp launch = CreateLaunchForBlock(
CreateLaunchForBlock(builder, cluster, launch_block, host_device); builder, cluster, /*before=*/true, launch_block, host_device);
for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(), for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
launch.getResults())) launch.getResults()))
@ -220,6 +221,160 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
cluster.body()); cluster.body());
} }
// Extracts and move outside compiled ops that have no dependencies in the
// cluster to before the cluster.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
tf_device::ClusterOp cluster, std::string* host_device,
bool* cluster_updated) {
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
FindOutsideCompiledOpsAtHead(cluster);
if (head_outside_compiled_ops.empty()) return success();
if (failed(
GetHostDeviceForHeadTailComputation(devices, cluster, host_device)))
return failure();
CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
*host_device);
*cluster_updated = true;
return success();
}
// Fills `tail_outside_compiled_ops` with ops that are outside compiled and
// can be extracted to after the TPU computation, and `cluster_results` with new
// results of the cluster. These ops are either connected to the output of the
// TPU computation or other ops that can be extracted, and have no results used
// by other ops in the TPU computation that cannot be extracted.
void FindOutsideCompiledOpsAtTailAndClusterResults(
tf_device::ClusterOp cluster,
llvm::SmallVectorImpl<Operation*>* tail_outside_compiled_ops,
llvm::SmallVectorImpl<Value>* cluster_results) {
Region* cluster_region = &cluster.body();
llvm::SmallSetVector<Operation*, 4> tail_outside_compiled_ops_set;
Operation* terminator = cluster.GetBody().getTerminator();
llvm::SmallSetVector<Value, 4> cluster_results_set;
cluster_results_set.insert(terminator->getOperands().begin(),
terminator->getOperands().end());
auto cluster_ops = llvm::reverse(cluster.GetBody().without_terminator());
for (Operation& cluster_op : cluster_ops) {
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
llvm::SmallVector<int, 4> results_to_forward;
bool can_be_extracted =
llvm::all_of(cluster_op.getUsers(), [&](Operation* op) {
return op == terminator || tail_outside_compiled_ops_set.count(op);
});
if (!can_be_extracted) continue;
// Collect operands of cluster op that are generated within the cluster.
// These values should be returned by the cluster.
cluster_op.walk([&](Operation* op) {
for (Value operand : op->getOperands()) {
Operation* operand_op = GetOpOfValue(operand);
if (operand_op->getParentRegion() == cluster_region)
cluster_results_set.insert(operand);
}
});
// Remove results of op to be extracted as there are no uses in the cluster.
for (Value result : cluster_op.getResults())
cluster_results_set.remove(result);
tail_outside_compiled_ops_set.insert(&cluster_op);
}
*tail_outside_compiled_ops = tail_outside_compiled_ops_set.takeVector();
*cluster_results = cluster_results_set.takeVector();
}
// Moves tail outside compiled ops into its own `tf_device.LaunchOp`
// computation after the cluster.
void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops)
tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/false, launch_block, host_device);
auto operand_not_in_launch = [&](OpOperand& operand) {
return !launch.getOperation()->isProperAncestor(operand.getOwner());
};
for (auto result : llvm::zip(launch.GetBody().getTerminator()->getOperands(),
launch.getResults()))
std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
operand_not_in_launch);
}
// Updates cluster with updated cluster results after extracting tail outside
// compiled ops.
tf_device::ClusterOp UpdateClusterResults(
OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Value> new_cluster_results) {
Operation* old_terminator = cluster.GetBody().getTerminator();
builder->setInsertionPoint(old_terminator);
builder->create<tf_device::ReturnOp>(old_terminator->getLoc(),
new_cluster_results);
old_terminator->erase();
builder->setInsertionPoint(cluster);
llvm::SmallVector<Type, 4> new_cluster_result_types;
new_cluster_result_types.reserve(new_cluster_results.size());
for (const auto& new_cluster_result : new_cluster_results)
new_cluster_result_types.push_back(new_cluster_result.getType());
auto new_cluster = builder->create<tf_device::ClusterOp>(
cluster.getLoc(), new_cluster_result_types,
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
new_cluster.body().takeBody(cluster.body());
auto operand_not_in_cluster = [&](OpOperand& operand) {
return !new_cluster.getOperation()->isProperAncestor(operand.getOwner());
};
for (auto result :
llvm::zip(new_cluster.GetBody().getTerminator()->getOperands(),
new_cluster.getResults()))
std::get<0>(result).replaceUsesWithIf(std::get<1>(result),
operand_not_in_cluster);
cluster.erase();
return new_cluster;
}
// Extracts and move outside compiled ops that do not create dependencies in the
// cluster to after the cluster.
mlir::LogicalResult LiftTailOutsideCompiledOps(
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
std::string host_device, tf_device::ClusterOp* cluster,
bool* cluster_updated) {
llvm::SmallVector<Operation*, 4> tail_outside_compiled_ops;
llvm::SmallVector<Value, 4> cluster_results;
FindOutsideCompiledOpsAtTailAndClusterResults(
*cluster, &tail_outside_compiled_ops, &cluster_results);
if (tail_outside_compiled_ops.empty()) return success();
if (host_device.empty())
if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster,
&host_device)))
return failure();
// Forward all results of cluster first. These results will be remapped once
// a new cluster is formed.
cluster->replaceAllUsesWith(
cluster->GetBody().getTerminator()->getOperands());
CreateTailComputation(builder, *cluster, tail_outside_compiled_ops,
host_device);
*cluster = UpdateClusterResults(builder, *cluster, cluster_results);
*cluster_updated = true;
return success();
}
// Removes aliased outputs in cluster from ops outside of cluster. // Removes aliased outputs in cluster from ops outside of cluster.
void RemoveClusterAliasedOutputs(OpBuilder* builder, void RemoveClusterAliasedOutputs(OpBuilder* builder,
tf_device::ClusterOp cluster) { tf_device::ClusterOp cluster) {
@ -256,26 +411,6 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder,
cluster.erase(); cluster.erase();
} }
// Extracts and move outside compiled ops that have no dependencies in the
// cluster to before the cluster.
mlir::LogicalResult LiftHeadOutsideCompiledOps(
OpBuilder* builder, const mlir::TF::RuntimeDevices& devices,
tf_device::ClusterOp cluster) {
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
FindOutsideCompiledOpsAtHead(cluster);
if (head_outside_compiled_ops.empty()) return success();
std::string host_device;
if (failed(
GetHostDeviceForHeadTailComputation(devices, cluster, &host_device)))
return failure();
CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
host_device);
RemoveClusterAliasedOutputs(builder, cluster);
return success();
}
struct TPUExtractHeadTailOutsideCompilation struct TPUExtractHeadTailOutsideCompilation
: public PassWrapper<TPUExtractHeadTailOutsideCompilation, : public PassWrapper<TPUExtractHeadTailOutsideCompilation,
OperationPass<ModuleOp>> { OperationPass<ModuleOp>> {
@ -295,10 +430,14 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
for (tf_device::ClusterOp cluster : clusters) { for (tf_device::ClusterOp cluster : clusters) {
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster))) std::string host_device;
bool cluster_updated = false;
if (failed(LiftHeadOutsideCompiledOps(&builder, devices, cluster,
&host_device, &cluster_updated)) ||
failed(LiftTailOutsideCompiledOps(&builder, devices, host_device,
&cluster, &cluster_updated)))
return signalPassFailure(); return signalPassFailure();
if (cluster_updated) RemoveClusterAliasedOutputs(&builder, cluster);
// TODO(b/157160906): Implement tail outside compiled op extraction.
} }
} }