Simplify and address missing features in TPU Extract Head Outside Compilation pass.
This updates the TPUExtractHeadTailOutsideCompilation in preparation for outside compilation tail extraction. Certain parts from outside compilation head extraction can be reused. Support for ops with no operands and pruning of aliased results in the cluster is also added. PiperOrigin-RevId: 312752658 Change-Id: I7b07773b59d2dd009ac694dea083caf4eca74c00
This commit is contained in:
parent
85c637969a
commit
60fb5dcc7d
@ -6,12 +6,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK-LABEL: func @single_head_outside_compilation
|
||||
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: tf_device.launch
|
||||
//
|
||||
// CHECK: "tf.A"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
//
|
||||
// CHECK: device
|
||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.C"
|
||||
@ -28,6 +25,88 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @ops_no_operands
|
||||
func @ops_no_operands() -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
"tf_device.cluster"() ( {
|
||||
%0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||
%1 = "tf.B"(%0) {}: (tensor<i32>) -> (tensor<i32>)
|
||||
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||
tf_device.return
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @aliased_output
|
||||
func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
|
||||
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
|
||||
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]]
|
||||
// CHECK-NEXT: {
|
||||
// CHECK-DAG: num_cores_per_replica = 1
|
||||
// CHECK-DAG: step_marker_location = ""
|
||||
// CHECK-DAG: padding_map = []
|
||||
// CHECK-DAG: topology = ""
|
||||
// CHECK-DAG: device_assignment = []
|
||||
//
|
||||
// CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1
|
||||
%0:3 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||
%2 = "tf.B"(%1) {}: (tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>)
|
||||
return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @all_head_computation_ops
|
||||
func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
|
||||
// CHECK-NEXT: tf_device.return %[[C_OUT]]
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK-NEXT: tf_device.return
|
||||
//
|
||||
// CHECK: return %[[LAUNCH_OUT]]
|
||||
%0 = "tf_device.cluster"() ( {
|
||||
%1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||
%3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||
tf_device.return %3 : tensor<i32>
|
||||
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>)
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @multiple_head_outside_compilation
|
||||
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
@ -36,8 +115,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||
// CHECK: "tf.C"
|
||||
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
||||
// CHECK: device
|
||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
||||
@ -83,8 +161,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
// CHECK: device
|
||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
@ -105,15 +182,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||
// CHECK-LABEL: func @test_replicated_head_outside_compilation
|
||||
func @test_replicated_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||
// CHECK: device
|
||||
// CHECK-SAME: "TPU_REPLICATED_HOST"
|
||||
// CHECK: device = "TPU_REPLICATED_HOST"
|
||||
//
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: "tf.B"
|
||||
|
@ -14,9 +14,10 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "llvm/ADT/Optional.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -26,6 +27,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Block.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||
@ -51,181 +53,84 @@ bool HasOutsideCompilationAttribute(Operation* op) {
|
||||
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
|
||||
}
|
||||
|
||||
// Returns whether all operands of `op` are from values inside the
|
||||
// `input_value_set`.
|
||||
bool OpContainsOperandsFromSet(Operation* op,
|
||||
const llvm::SetVector<Value>& input_value_set) {
|
||||
for (auto operand : op->getOperands())
|
||||
if (input_value_set.count(operand) == 0) return false;
|
||||
Operation* GetOpOfValue(Value value) {
|
||||
if (auto block_arg = value.dyn_cast<BlockArgument>())
|
||||
return block_arg.getOwner()->getParentOp();
|
||||
|
||||
return true;
|
||||
return value.getDefiningOp();
|
||||
}
|
||||
|
||||
void RecordOutsideCompiledOpsAndUsages(
|
||||
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
|
||||
llvm::SetVector<Value>* outside_compiled_op_usages) {
|
||||
if (HasOutsideCompilationAttribute(op) &&
|
||||
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
|
||||
outside_compiled_ops->insert(op);
|
||||
outside_compiled_op_usages->insert(op->getResults().begin(),
|
||||
op->getResults().end());
|
||||
}
|
||||
}
|
||||
// Returns a set of ops that are outside compiled and can be extracted to before
|
||||
// the TPU computation. These ops are either connected to the inputs of the TPU
|
||||
// computation or other ops that can be extracted, and have no dependencies with
|
||||
// other ops in the TPU computation that cannot be extracted.
|
||||
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||
tf_device::ClusterOp cluster) {
|
||||
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||
|
||||
// Traverses the MLIR graph and returns a set of ops that
|
||||
// are connected to inputs of TPU computation and outside compiled.
|
||||
void ExtractOutsideCompiledOpsConnectedToHead(
|
||||
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
|
||||
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
|
||||
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
|
||||
for (auto& usage : input_value.getUses()) {
|
||||
auto head_operation = usage.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(head_operation,
|
||||
&parent_outside_compiled_ops_at_head,
|
||||
values_used_in_host_cluster);
|
||||
}
|
||||
auto cluster_ops = cluster.GetBody().without_terminator();
|
||||
for (Operation& cluster_op : cluster_ops) {
|
||||
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||
// An outside compiled op can be extracted if its operands are not from
|
||||
// other ops in the cluster that cannot be extracted.
|
||||
auto result = cluster_op.walk([&](Operation* op) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation* operand_op = GetOpOfValue(operand);
|
||||
if (operand_op->isProperAncestor(cluster) ||
|
||||
cluster_op.isAncestor(operand_op) ||
|
||||
head_outside_compiled_ops.count(operand_op))
|
||||
continue;
|
||||
|
||||
// Traverse the graph and find all outside compiled ops connected from
|
||||
// the `input_value`.
|
||||
while (!parent_outside_compiled_ops_at_head.empty()) {
|
||||
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
|
||||
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
|
||||
auto op_results = head_outside_compiled_op->getOpResults();
|
||||
for (auto op_result : op_results) {
|
||||
for (auto& use : op_result.getUses()) {
|
||||
auto connected_op = use.getOwner();
|
||||
RecordOutsideCompiledOpsAndUsages(connected_op,
|
||||
&connected_outside_compiled_ops,
|
||||
values_used_in_host_cluster);
|
||||
}
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
||||
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
|
||||
parent_outside_compiled_ops_at_head.end());
|
||||
std::swap(parent_outside_compiled_ops_at_head,
|
||||
connected_outside_compiled_ops);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(hongjunchoi): Also handle ops without inputs that are outside
|
||||
// compiled.
|
||||
//
|
||||
// Returns set of ops that are outside compiled and are directly connected
|
||||
// to inputs to the TPU computation.
|
||||
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
|
||||
tf_device::ClusterOp tpu_cluster) {
|
||||
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
|
||||
llvm::SetVector<Value> values_used_in_cluster;
|
||||
auto& cluster_region = tpu_cluster.body();
|
||||
getUsedValuesDefinedAbove(cluster_region, cluster_region,
|
||||
values_used_in_cluster);
|
||||
|
||||
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
|
||||
for (auto input_value : input_value_list)
|
||||
ExtractOutsideCompiledOpsConnectedToHead(
|
||||
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
|
||||
return outside_compiled_at_head_ops;
|
||||
}
|
||||
|
||||
// Returns output values of extracted outside compiled cluster at head that
|
||||
// are used by the TPU computation.
|
||||
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
llvm::SmallVector<Value, 8> outputs;
|
||||
outputs.reserve(head_outside_compiled_ops.size());
|
||||
|
||||
for (auto op : head_outside_compiled_ops) {
|
||||
for (Operation* user : op->getUsers()) {
|
||||
if (!head_outside_compiled_ops.count(user)) {
|
||||
outputs.append(op->result_begin(), op->result_end());
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
return head_outside_compiled_ops.takeVector();
|
||||
}
|
||||
|
||||
// Creates new tf_device.launch op with outside compiled ops extracted
|
||||
// from the head of TPU computation.
|
||||
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
|
||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
||||
if (head_outside_compiled_ops.empty())
|
||||
return llvm::Optional<tf_device::LaunchOp>();
|
||||
|
||||
// Create tf_device.launch op to separate all extracted outside compiled ops
|
||||
// before the tf_device.cluster.
|
||||
auto output_values =
|
||||
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
|
||||
|
||||
llvm::SmallVector<Type, 8> output_return_types;
|
||||
output_return_types.reserve(output_values.size());
|
||||
for (auto output : output_values)
|
||||
output_return_types.emplace_back(output.getType());
|
||||
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto host_launch_op = builder->create<tf_device::LaunchOp>(
|
||||
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
|
||||
|
||||
// Replace all usages of outside compiled ops that are used in TPU
|
||||
// computation with the results of the above created launch op.
|
||||
for (auto output_and_index : llvm::enumerate(output_values)) {
|
||||
auto output_index = output_and_index.index();
|
||||
auto output = output_and_index.value();
|
||||
for (auto& use : output.getUses()) {
|
||||
if (!head_outside_compiled_ops.count(use.getOwner()))
|
||||
use.set(host_launch_op.getResult(output_index));
|
||||
}
|
||||
// Parses TPU compilation and execution devices from a TPU cluster and returns
|
||||
// the host device for the head and tail computations. If the TPU computation is
|
||||
// replicated, kTPUReplicatedHost is returned instead.
|
||||
LogicalResult GetHostDeviceForHeadTailComputation(
|
||||
mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
|
||||
std::string* host_device) {
|
||||
auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
|
||||
if (replicate) {
|
||||
*host_device = tensorflow::kTPUReplicatedHost;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Create terminator op for the newly created launch op.
|
||||
host_launch_op.body().push_back(new Block());
|
||||
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
|
||||
auto terminator = builder->create<tf_device::ReturnOp>(
|
||||
host_launch_op.getLoc(), output_values);
|
||||
|
||||
// Move all outside compile ops from cluster op to launch op.
|
||||
for (auto outside_compiled_op : head_outside_compiled_ops)
|
||||
outside_compiled_op->moveBefore(terminator);
|
||||
|
||||
return host_launch_op;
|
||||
}
|
||||
|
||||
// Parses TPU compilation and execution device form tpu cluster and assigns
|
||||
// host device to `host_launch` device attribute.
|
||||
LogicalResult SetCompilationDeviceToHostLaunch(
|
||||
OpBuilder* builder, mlir::TF::RuntimeDevices devices,
|
||||
tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
|
||||
auto num_cores_per_replica_attr = tpu_cluster.getAttrOfType<IntegerAttr>(
|
||||
tensorflow::kNumCoresPerReplicaAttr);
|
||||
auto num_cores_per_replica_attr =
|
||||
cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return tpu_cluster.emitOpError(
|
||||
return cluster.emitOpError(
|
||||
"cluster op missing `num_cores_per_replica` attribute");
|
||||
|
||||
if (num_cores_per_replica_attr.getInt() != 1)
|
||||
return tpu_cluster.emitOpError(
|
||||
return cluster.emitOpError(
|
||||
"outside compilation is not supported with model parallelism.");
|
||||
|
||||
auto topology_attr =
|
||||
tpu_cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return tpu_cluster.emitOpError("cluster op missing `topology` attribute");
|
||||
return cluster.emitOpError("cluster op missing `topology` attribute");
|
||||
|
||||
auto device_assignment_attr = tpu_cluster.getAttrOfType<mlir::ArrayAttr>(
|
||||
tensorflow::kDeviceAssignmentAttr);
|
||||
auto device_assignment_attr =
|
||||
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return tpu_cluster.emitOpError(
|
||||
llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
|
||||
auto status_or_device_coodinates =
|
||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||
|
||||
if (!status_or_device_coodinates.ok())
|
||||
return tpu_cluster.emitError()
|
||||
return cluster.emitError()
|
||||
<< "error in fetching tpu device coordinates: "
|
||||
<< status_or_device_coodinates.status().error_message();
|
||||
|
||||
@ -236,37 +141,96 @@ LogicalResult SetCompilationDeviceToHostLaunch(
|
||||
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
||||
status_or_device_coodinates.ConsumeValueOrDie());
|
||||
if (!status_or_tpu_device_assignment.ok())
|
||||
return tpu_cluster.emitError()
|
||||
return cluster.emitError()
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
<< status_or_tpu_device_assignment.status().error_message();
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
host_launch.deviceAttr(
|
||||
builder->getStringAttr(tpu_device_assignment.tpu_devices[0][0].host));
|
||||
|
||||
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Assigns host device attribute to host launch op or enclosing
|
||||
// tf_device.replicate op if TPU computation is replicated.
|
||||
LogicalResult HandleHostLaunchDeviceAssignment(
|
||||
OpBuilder* builder, mlir::TF::RuntimeDevices devices,
|
||||
tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
|
||||
auto parent_replicate_op =
|
||||
llvm::dyn_cast_or_null<tf_device::ReplicateOp>(host_launch.getParentOp());
|
||||
// If computation is replicated, then add TPU_REPLICATED_HOST device alias
|
||||
// to the host launch op. This device alias would later be a reference to
|
||||
// host device string in the device map of tf_device.replicate op
|
||||
// during tpu_rewrite pass.
|
||||
if (parent_replicate_op) {
|
||||
host_launch.deviceAttr(
|
||||
builder->getStringAttr(tensorflow::kTPUReplicatedHost));
|
||||
} else {
|
||||
if (failed(SetCompilationDeviceToHostLaunch(builder, devices, tpu_cluster,
|
||||
host_launch)))
|
||||
return failure();
|
||||
// Moves head outside compiled ops into its own `tf_device.LaunchOp`
|
||||
// computation.
|
||||
tf_device::LaunchOp CreateHeadComputation(
|
||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
|
||||
llvm::StringRef host_device) {
|
||||
Block* launch_block = new Block;
|
||||
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
|
||||
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
|
||||
|
||||
// Find results of ops in head computation that needs to returned.
|
||||
llvm::SmallVector<Value, 4> launch_results;
|
||||
llvm::SmallVector<Type, 4> launch_result_types;
|
||||
for (Operation& head_outside_compiled_op : *launch_block) {
|
||||
for (Value result : head_outside_compiled_op.getResults()) {
|
||||
bool has_uses_in_cluster = false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (user->getParentRegion() &&
|
||||
cluster.body().isAncestor(user->getParentRegion())) {
|
||||
has_uses_in_cluster = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_uses_in_cluster) {
|
||||
launch_results.push_back(result);
|
||||
launch_result_types.push_back(result.getType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto launch = builder->create<tf_device::LaunchOp>(
|
||||
cluster.getLoc(), builder->getStringAttr(host_device),
|
||||
launch_result_types);
|
||||
launch.body().push_back(launch_block);
|
||||
|
||||
builder->setInsertionPointToEnd(&launch.GetBody());
|
||||
builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results);
|
||||
|
||||
for (auto result : llvm::zip(launch_results, launch.getResults()))
|
||||
replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
|
||||
cluster.body());
|
||||
|
||||
return launch;
|
||||
}
|
||||
|
||||
// Removes aliased outputs in cluster from head computation after head
|
||||
// computation has been extracted.
|
||||
void RemoveHeadComputationAliasedOutputs(OpBuilder* builder,
|
||||
tf_device::LaunchOp head_computation,
|
||||
tf_device::ClusterOp cluster) {
|
||||
llvm::SmallVector<Value, 4> used_old_cluster_results;
|
||||
llvm::SmallVector<Value, 4> new_cluster_results;
|
||||
llvm::SmallVector<Type, 4> new_cluster_result_types;
|
||||
Operation* cluster_terminator = cluster.GetBody().getTerminator();
|
||||
for (auto result :
|
||||
llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
|
||||
Value cluster_terminator_operand = std::get<0>(result);
|
||||
if (cluster_terminator_operand.getDefiningOp() == head_computation) {
|
||||
std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
|
||||
} else {
|
||||
new_cluster_results.push_back(cluster_terminator_operand);
|
||||
new_cluster_result_types.push_back(cluster_terminator_operand.getType());
|
||||
used_old_cluster_results.push_back(std::get<1>(result));
|
||||
}
|
||||
}
|
||||
|
||||
if (new_cluster_results.size() == cluster.getNumResults()) return;
|
||||
|
||||
builder->setInsertionPoint(cluster);
|
||||
auto new_cluster = builder->create<tf_device::ClusterOp>(
|
||||
cluster.getLoc(), new_cluster_result_types,
|
||||
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
|
||||
new_cluster.body().takeBody(cluster.body());
|
||||
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
|
||||
|
||||
for (auto result :
|
||||
llvm::zip(used_old_cluster_results, new_cluster.getResults()))
|
||||
std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
|
||||
|
||||
cluster.erase();
|
||||
}
|
||||
|
||||
struct TPUExtractHeadTailOutsideCompilation
|
||||
@ -283,22 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
||||
return signalPassFailure();
|
||||
|
||||
OpBuilder builder(&getContext());
|
||||
auto result = module.walk([&](tf_device::ClusterOp cluster) {
|
||||
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
|
||||
auto host_launch_op = IsolateHeadExtractedOpsToLaunchOp(
|
||||
&builder, cluster, head_outside_compiled_ops);
|
||||
if (host_launch_op) {
|
||||
if (failed(HandleHostLaunchDeviceAssignment(&builder, devices, cluster,
|
||||
*host_launch_op))) {
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
}
|
||||
llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
|
||||
module.walk(
|
||||
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
|
||||
|
||||
// TODO(b/155115766): Implement tail outside compiled op extraction.
|
||||
return WalkResult::advance();
|
||||
});
|
||||
for (tf_device::ClusterOp cluster : clusters) {
|
||||
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||
FindOutsideCompiledOpsAtHead(cluster);
|
||||
if (head_outside_compiled_ops.empty()) continue;
|
||||
std::string host_device;
|
||||
if (failed(GetHostDeviceForHeadTailComputation(devices, cluster,
|
||||
&host_device)))
|
||||
return signalPassFailure();
|
||||
|
||||
if (result.wasInterrupted()) signalPassFailure();
|
||||
tf_device::LaunchOp head_computation = CreateHeadComputation(
|
||||
&builder, cluster, head_outside_compiled_ops, host_device);
|
||||
RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster);
|
||||
|
||||
// TODO(b/157160906): Implement tail outside compiled op extraction.
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
Loading…
Reference in New Issue
Block a user