Simplify and address missing features in TPU Extract Head Outside Compilation pass.
This updates the TPUExtractHeadTailOutsideCompilation in preparation for outside compilation tail extraction. Certain parts from outside compilation head extraction can be reused. Support for ops with no operands and pruning of aliased results in the cluster is also added. PiperOrigin-RevId: 312752658 Change-Id: I7b07773b59d2dd009ac694dea083caf4eca74c00
This commit is contained in:
parent
85c637969a
commit
60fb5dcc7d
@ -6,12 +6,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
// CHECK-LABEL: func @single_head_outside_compilation
|
// CHECK-LABEL: func @single_head_outside_compilation
|
||||||
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
func @single_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||||
// CHECK: tf_device.launch
|
// CHECK: tf_device.launch
|
||||||
//
|
|
||||||
// CHECK: "tf.A"
|
// CHECK: "tf.A"
|
||||||
// CHECK-NEXT: tf_device.return
|
// CHECK-NEXT: tf_device.return
|
||||||
//
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
// CHECK: device
|
|
||||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
|
||||||
//
|
//
|
||||||
// CHECK: "tf_device.cluster"
|
// CHECK: "tf_device.cluster"
|
||||||
// CHECK: "tf.C"
|
// CHECK: "tf.C"
|
||||||
@ -28,6 +25,88 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||||
|
// CHECK-LABEL: func @ops_no_operands
|
||||||
|
func @ops_no_operands() -> () {
|
||||||
|
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||||
|
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||||
|
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||||
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
|
//
|
||||||
|
// CHECK: "tf_device.cluster"
|
||||||
|
// CHECK-NEXT: "tf.B"(%[[LAUNCH_OUT]])
|
||||||
|
// CHECK-NEXT: "tf.C"
|
||||||
|
// CHECK-NEXT: tf_device.return
|
||||||
|
"tf_device.cluster"() ( {
|
||||||
|
%0 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||||
|
%1 = "tf.B"(%0) {}: (tensor<i32>) -> (tensor<i32>)
|
||||||
|
"tf.C"(%1) : (tensor<i32>) -> ()
|
||||||
|
tf_device.return
|
||||||
|
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||||
|
// CHECK-LABEL: func @aliased_output
|
||||||
|
func @aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
|
||||||
|
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||||
|
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||||
|
// CHECK-NEXT: tf_device.return %[[A_OUT]]
|
||||||
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
|
//
|
||||||
|
// CHECK: %[[CLUSTER_OUT:.*]]:2 = "tf_device.cluster"
|
||||||
|
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[LAUNCH_OUT]])
|
||||||
|
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"
|
||||||
|
// CHECK-NEXT: tf_device.return %[[C_OUT]], %[[B_OUT]]
|
||||||
|
// CHECK-NEXT: {
|
||||||
|
// CHECK-DAG: num_cores_per_replica = 1
|
||||||
|
// CHECK-DAG: step_marker_location = ""
|
||||||
|
// CHECK-DAG: padding_map = []
|
||||||
|
// CHECK-DAG: topology = ""
|
||||||
|
// CHECK-DAG: device_assignment = []
|
||||||
|
//
|
||||||
|
// CHECK: return %[[LAUNCH_OUT]], %[[CLUSTER_OUT]]#0, %[[CLUSTER_OUT]]#1
|
||||||
|
%0:3 = "tf_device.cluster"() ( {
|
||||||
|
%1 = "tf.A"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<i32>)
|
||||||
|
%2 = "tf.B"(%1) {}: (tensor<i32>) -> (tensor<i32>)
|
||||||
|
%3 = "tf.C"(%2) : (tensor<i32>) -> (tensor<i32>)
|
||||||
|
tf_device.return %1, %3, %2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||||
|
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>, tensor<i32>, tensor<i32>)
|
||||||
|
return %0#0, %0#1, %0#2 : tensor<i32>, tensor<i32>, tensor<i32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||||
|
// CHECK-LABEL: func @all_head_computation_ops
|
||||||
|
func @all_head_computation_ops(%arg0 : tensor<i32>) -> (tensor<i32>) {
|
||||||
|
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
|
||||||
|
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||||
|
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||||
|
// CHECK: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
|
||||||
|
// CHECK-NEXT: tf_device.return %[[C_OUT]]
|
||||||
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
|
//
|
||||||
|
// CHECK: "tf_device.cluster"
|
||||||
|
// CHECK-NEXT: tf_device.return
|
||||||
|
//
|
||||||
|
// CHECK: return %[[LAUNCH_OUT]]
|
||||||
|
%0 = "tf_device.cluster"() ( {
|
||||||
|
%1 = "tf.A"(%arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||||
|
%2 = "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<i32>) -> (tensor<i32>)
|
||||||
|
%3 = "tf.C"(%2, %arg0) {_xla_outside_compilation = "cluster1"} : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||||
|
tf_device.return %3 : tensor<i32>
|
||||||
|
}) {num_cores_per_replica = 1, step_marker_location = "", padding_map = [], topology = "", device_assignment = []} : () -> (tensor<i32>)
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
|
||||||
// CHECK-LABEL: func @multiple_head_outside_compilation
|
// CHECK-LABEL: func @multiple_head_outside_compilation
|
||||||
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
func @multiple_head_outside_compilation(%arg0 : tensor<i32>) -> () {
|
||||||
@ -36,8 +115,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
// CHECK: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
|
||||||
// CHECK: "tf.C"
|
// CHECK: "tf.C"
|
||||||
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
// CHECK-NEXT: tf_device.return %[[B_OUT]]
|
||||||
// CHECK: device
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
|
||||||
//
|
//
|
||||||
// CHECK: "tf_device.cluster"
|
// CHECK: "tf_device.cluster"
|
||||||
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
// CHECK: "tf.D"(%[[LAUNCH_OUT]])
|
||||||
@ -83,8 +161,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||||
// CHECK: device
|
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
|
||||||
// CHECK-SAME: "/job:worker/replica:0/task:0/device:CPU:0"
|
|
||||||
//
|
//
|
||||||
// CHECK: "tf_device.cluster"
|
// CHECK: "tf_device.cluster"
|
||||||
// CHECK: "tf.B"
|
// CHECK: "tf.B"
|
||||||
@ -112,8 +189,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
|
|||||||
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
// CHECK: %[[A_OUT:.*]] = "tf.A"
|
||||||
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
// CHECK: %[[D_OUT:.*]] = "tf.D"(%[[A_OUT]])
|
||||||
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
// CHECK-NEXT: tf_device.return %[[D_OUT]]
|
||||||
// CHECK: device
|
// CHECK: device = "TPU_REPLICATED_HOST"
|
||||||
// CHECK-SAME: "TPU_REPLICATED_HOST"
|
|
||||||
//
|
//
|
||||||
// CHECK: "tf_device.cluster"
|
// CHECK: "tf_device.cluster"
|
||||||
// CHECK: "tf.B"
|
// CHECK: "tf.B"
|
||||||
|
@ -14,9 +14,10 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <tuple>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
#include "llvm/ADT/Optional.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
@ -26,6 +27,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Block.h" // from @llvm-project
|
#include "mlir/IR/Block.h" // from @llvm-project
|
||||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Value.h" // from @llvm-project
|
||||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||||
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
|
||||||
@ -51,181 +53,84 @@ bool HasOutsideCompilationAttribute(Operation* op) {
|
|||||||
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
|
return op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr) != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether all operands of `op` are from values inside the
|
Operation* GetOpOfValue(Value value) {
|
||||||
// `input_value_set`.
|
if (auto block_arg = value.dyn_cast<BlockArgument>())
|
||||||
bool OpContainsOperandsFromSet(Operation* op,
|
return block_arg.getOwner()->getParentOp();
|
||||||
const llvm::SetVector<Value>& input_value_set) {
|
|
||||||
for (auto operand : op->getOperands())
|
|
||||||
if (input_value_set.count(operand) == 0) return false;
|
|
||||||
|
|
||||||
return true;
|
return value.getDefiningOp();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecordOutsideCompiledOpsAndUsages(
|
// Returns a set of ops that are outside compiled and can be extracted to before
|
||||||
Operation* op, llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops,
|
// the TPU computation. These ops are either connected to the inputs of the TPU
|
||||||
llvm::SetVector<Value>* outside_compiled_op_usages) {
|
// computation or other ops that can be extracted, and have no dependencies with
|
||||||
if (HasOutsideCompilationAttribute(op) &&
|
// other ops in the TPU computation that cannot be extracted.
|
||||||
OpContainsOperandsFromSet(op, *outside_compiled_op_usages)) {
|
llvm::SmallVector<Operation*, 4> FindOutsideCompiledOpsAtHead(
|
||||||
outside_compiled_ops->insert(op);
|
tf_device::ClusterOp cluster) {
|
||||||
outside_compiled_op_usages->insert(op->getResults().begin(),
|
llvm::SmallSetVector<Operation*, 4> head_outside_compiled_ops;
|
||||||
op->getResults().end());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Traverses the MLIR graph and returns a set of ops that
|
auto cluster_ops = cluster.GetBody().without_terminator();
|
||||||
// are connected to inputs of TPU computation and outside compiled.
|
for (Operation& cluster_op : cluster_ops) {
|
||||||
void ExtractOutsideCompiledOpsConnectedToHead(
|
if (!HasOutsideCompilationAttribute(&cluster_op)) continue;
|
||||||
Value input_value, llvm::SetVector<Value>* values_used_in_host_cluster,
|
// An outside compiled op can be extracted if its operands are not from
|
||||||
llvm::SmallSetVector<Operation*, 4>* outside_compiled_ops) {
|
// other ops in the cluster that cannot be extracted.
|
||||||
llvm::SmallSetVector<Operation*, 4> parent_outside_compiled_ops_at_head;
|
auto result = cluster_op.walk([&](Operation* op) {
|
||||||
for (auto& usage : input_value.getUses()) {
|
for (Value operand : op->getOperands()) {
|
||||||
auto head_operation = usage.getOwner();
|
Operation* operand_op = GetOpOfValue(operand);
|
||||||
RecordOutsideCompiledOpsAndUsages(head_operation,
|
if (operand_op->isProperAncestor(cluster) ||
|
||||||
&parent_outside_compiled_ops_at_head,
|
cluster_op.isAncestor(operand_op) ||
|
||||||
values_used_in_host_cluster);
|
head_outside_compiled_ops.count(operand_op))
|
||||||
}
|
continue;
|
||||||
|
|
||||||
// Traverse the graph and find all outside compiled ops connected from
|
return WalkResult::interrupt();
|
||||||
// the `input_value`.
|
|
||||||
while (!parent_outside_compiled_ops_at_head.empty()) {
|
|
||||||
llvm::SmallSetVector<Operation*, 4> connected_outside_compiled_ops;
|
|
||||||
for (auto head_outside_compiled_op : parent_outside_compiled_ops_at_head) {
|
|
||||||
auto op_results = head_outside_compiled_op->getOpResults();
|
|
||||||
for (auto op_result : op_results) {
|
|
||||||
for (auto& use : op_result.getUses()) {
|
|
||||||
auto connected_op = use.getOwner();
|
|
||||||
RecordOutsideCompiledOpsAndUsages(connected_op,
|
|
||||||
&connected_outside_compiled_ops,
|
|
||||||
values_used_in_host_cluster);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
|
||||||
outside_compiled_ops->insert(parent_outside_compiled_ops_at_head.begin(),
|
if (!result.wasInterrupted()) head_outside_compiled_ops.insert(&cluster_op);
|
||||||
parent_outside_compiled_ops_at_head.end());
|
|
||||||
std::swap(parent_outside_compiled_ops_at_head,
|
|
||||||
connected_outside_compiled_ops);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(hongjunchoi): Also handle ops without inputs that are outside
|
|
||||||
// compiled.
|
|
||||||
//
|
|
||||||
// Returns set of ops that are outside compiled and are directly connected
|
|
||||||
// to inputs to the TPU computation.
|
|
||||||
llvm::SmallSetVector<Operation*, 4> IdentifyOutsideCompiledOpsAtHead(
|
|
||||||
tf_device::ClusterOp tpu_cluster) {
|
|
||||||
llvm::SmallSetVector<Operation*, 4> outside_compiled_at_head_ops;
|
|
||||||
llvm::SetVector<Value> values_used_in_cluster;
|
|
||||||
auto& cluster_region = tpu_cluster.body();
|
|
||||||
getUsedValuesDefinedAbove(cluster_region, cluster_region,
|
|
||||||
values_used_in_cluster);
|
|
||||||
|
|
||||||
auto input_value_list = llvm::to_vector<8>(values_used_in_cluster);
|
|
||||||
for (auto input_value : input_value_list)
|
|
||||||
ExtractOutsideCompiledOpsConnectedToHead(
|
|
||||||
input_value, &values_used_in_cluster, &outside_compiled_at_head_ops);
|
|
||||||
return outside_compiled_at_head_ops;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns output values of extracted outside compiled cluster at head that
|
|
||||||
// are used by the TPU computation.
|
|
||||||
llvm::SmallVector<Value, 8> GetHeadExtractedClusterOutputs(
|
|
||||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
|
||||||
llvm::SmallVector<Value, 8> outputs;
|
|
||||||
outputs.reserve(head_outside_compiled_ops.size());
|
|
||||||
|
|
||||||
for (auto op : head_outside_compiled_ops) {
|
|
||||||
for (Operation* user : op->getUsers()) {
|
|
||||||
if (!head_outside_compiled_ops.count(user)) {
|
|
||||||
outputs.append(op->result_begin(), op->result_end());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return outputs;
|
return head_outside_compiled_ops.takeVector();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Creates new tf_device.launch op with outside compiled ops extracted
|
// Parses TPU compilation and execution devices from a TPU cluster and returns
|
||||||
// from the head of TPU computation.
|
// the host device for the head and tail computations. If the TPU computation is
|
||||||
llvm::Optional<tf_device::LaunchOp> IsolateHeadExtractedOpsToLaunchOp(
|
// replicated, kTPUReplicatedHost is returned instead.
|
||||||
OpBuilder* builder, tf_device::ClusterOp cluster,
|
LogicalResult GetHostDeviceForHeadTailComputation(
|
||||||
const llvm::SmallSetVector<Operation*, 4>& head_outside_compiled_ops) {
|
mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
|
||||||
if (head_outside_compiled_ops.empty())
|
std::string* host_device) {
|
||||||
return llvm::Optional<tf_device::LaunchOp>();
|
auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
|
||||||
|
if (replicate) {
|
||||||
// Create tf_device.launch op to separate all extracted outside compiled ops
|
*host_device = tensorflow::kTPUReplicatedHost;
|
||||||
// before the tf_device.cluster.
|
return success();
|
||||||
auto output_values =
|
|
||||||
GetHeadExtractedClusterOutputs(head_outside_compiled_ops);
|
|
||||||
|
|
||||||
llvm::SmallVector<Type, 8> output_return_types;
|
|
||||||
output_return_types.reserve(output_values.size());
|
|
||||||
for (auto output : output_values)
|
|
||||||
output_return_types.emplace_back(output.getType());
|
|
||||||
|
|
||||||
builder->setInsertionPoint(cluster);
|
|
||||||
auto host_launch_op = builder->create<tf_device::LaunchOp>(
|
|
||||||
cluster.getLoc(), builder->getStringAttr(""), output_return_types);
|
|
||||||
|
|
||||||
// Replace all usages of outside compiled ops that are used in TPU
|
|
||||||
// computation with the results of the above created launch op.
|
|
||||||
for (auto output_and_index : llvm::enumerate(output_values)) {
|
|
||||||
auto output_index = output_and_index.index();
|
|
||||||
auto output = output_and_index.value();
|
|
||||||
for (auto& use : output.getUses()) {
|
|
||||||
if (!head_outside_compiled_ops.count(use.getOwner()))
|
|
||||||
use.set(host_launch_op.getResult(output_index));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create terminator op for the newly created launch op.
|
auto num_cores_per_replica_attr =
|
||||||
host_launch_op.body().push_back(new Block());
|
cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
|
||||||
builder->setInsertionPointToEnd(&host_launch_op.GetBody());
|
|
||||||
auto terminator = builder->create<tf_device::ReturnOp>(
|
|
||||||
host_launch_op.getLoc(), output_values);
|
|
||||||
|
|
||||||
// Move all outside compile ops from cluster op to launch op.
|
|
||||||
for (auto outside_compiled_op : head_outside_compiled_ops)
|
|
||||||
outside_compiled_op->moveBefore(terminator);
|
|
||||||
|
|
||||||
return host_launch_op;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parses TPU compilation and execution device form tpu cluster and assigns
|
|
||||||
// host device to `host_launch` device attribute.
|
|
||||||
LogicalResult SetCompilationDeviceToHostLaunch(
|
|
||||||
OpBuilder* builder, mlir::TF::RuntimeDevices devices,
|
|
||||||
tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
|
|
||||||
auto num_cores_per_replica_attr = tpu_cluster.getAttrOfType<IntegerAttr>(
|
|
||||||
tensorflow::kNumCoresPerReplicaAttr);
|
|
||||||
if (!num_cores_per_replica_attr)
|
if (!num_cores_per_replica_attr)
|
||||||
return tpu_cluster.emitOpError(
|
return cluster.emitOpError(
|
||||||
"cluster op missing `num_cores_per_replica` attribute");
|
"cluster op missing `num_cores_per_replica` attribute");
|
||||||
|
|
||||||
if (num_cores_per_replica_attr.getInt() != 1)
|
if (num_cores_per_replica_attr.getInt() != 1)
|
||||||
return tpu_cluster.emitOpError(
|
return cluster.emitOpError(
|
||||||
"outside compilation is not supported with model parallelism.");
|
"outside compilation is not supported with model parallelism.");
|
||||||
|
|
||||||
auto topology_attr =
|
auto topology_attr =
|
||||||
tpu_cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||||
if (!topology_attr)
|
if (!topology_attr)
|
||||||
return tpu_cluster.emitOpError("cluster op missing `topology` attribute");
|
return cluster.emitOpError("cluster op missing `topology` attribute");
|
||||||
|
|
||||||
auto device_assignment_attr = tpu_cluster.getAttrOfType<mlir::ArrayAttr>(
|
auto device_assignment_attr =
|
||||||
tensorflow::kDeviceAssignmentAttr);
|
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
|
||||||
if (!device_assignment_attr)
|
if (!device_assignment_attr)
|
||||||
return tpu_cluster.emitOpError(
|
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
|
||||||
llvm::formatv("requires attribute '{0}'",
|
tensorflow::kDeviceAssignmentAttr)
|
||||||
tensorflow::kDeviceAssignmentAttr)
|
.str());
|
||||||
.str());
|
|
||||||
|
|
||||||
auto status_or_device_coodinates =
|
auto status_or_device_coodinates =
|
||||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||||
|
|
||||||
if (!status_or_device_coodinates.ok())
|
if (!status_or_device_coodinates.ok())
|
||||||
return tpu_cluster.emitError()
|
return cluster.emitError()
|
||||||
<< "error in fetching tpu device coordinates: "
|
<< "error in fetching tpu device coordinates: "
|
||||||
<< status_or_device_coodinates.status().error_message();
|
<< status_or_device_coodinates.status().error_message();
|
||||||
|
|
||||||
@ -236,37 +141,96 @@ LogicalResult SetCompilationDeviceToHostLaunch(
|
|||||||
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
||||||
status_or_device_coodinates.ConsumeValueOrDie());
|
status_or_device_coodinates.ConsumeValueOrDie());
|
||||||
if (!status_or_tpu_device_assignment.ok())
|
if (!status_or_tpu_device_assignment.ok())
|
||||||
return tpu_cluster.emitError()
|
return cluster.emitError()
|
||||||
<< "error in fetching TPU compilation/execution devices: "
|
<< "error in fetching TPU compilation/execution devices: "
|
||||||
<< status_or_tpu_device_assignment.status().error_message();
|
<< status_or_tpu_device_assignment.status().error_message();
|
||||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||||
host_launch.deviceAttr(
|
|
||||||
builder->getStringAttr(tpu_device_assignment.tpu_devices[0][0].host));
|
|
||||||
|
|
||||||
|
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assigns host device attribute to host launch op or enclosing
|
// Moves head outside compiled ops into its own `tf_device.LaunchOp`
|
||||||
// tf_device.replicate op if TPU computation is replicated.
|
// computation.
|
||||||
LogicalResult HandleHostLaunchDeviceAssignment(
|
tf_device::LaunchOp CreateHeadComputation(
|
||||||
OpBuilder* builder, mlir::TF::RuntimeDevices devices,
|
OpBuilder* builder, tf_device::ClusterOp cluster,
|
||||||
tf_device::ClusterOp tpu_cluster, tf_device::LaunchOp host_launch) {
|
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
|
||||||
auto parent_replicate_op =
|
llvm::StringRef host_device) {
|
||||||
llvm::dyn_cast_or_null<tf_device::ReplicateOp>(host_launch.getParentOp());
|
Block* launch_block = new Block;
|
||||||
// If computation is replicated, then add TPU_REPLICATED_HOST device alias
|
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
|
||||||
// to the host launch op. This device alias would later be a reference to
|
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
|
||||||
// host device string in the device map of tf_device.replicate op
|
|
||||||
// during tpu_rewrite pass.
|
// Find results of ops in head computation that needs to returned.
|
||||||
if (parent_replicate_op) {
|
llvm::SmallVector<Value, 4> launch_results;
|
||||||
host_launch.deviceAttr(
|
llvm::SmallVector<Type, 4> launch_result_types;
|
||||||
builder->getStringAttr(tensorflow::kTPUReplicatedHost));
|
for (Operation& head_outside_compiled_op : *launch_block) {
|
||||||
} else {
|
for (Value result : head_outside_compiled_op.getResults()) {
|
||||||
if (failed(SetCompilationDeviceToHostLaunch(builder, devices, tpu_cluster,
|
bool has_uses_in_cluster = false;
|
||||||
host_launch)))
|
for (Operation* user : result.getUsers()) {
|
||||||
return failure();
|
if (user->getParentRegion() &&
|
||||||
|
cluster.body().isAncestor(user->getParentRegion())) {
|
||||||
|
has_uses_in_cluster = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (has_uses_in_cluster) {
|
||||||
|
launch_results.push_back(result);
|
||||||
|
launch_result_types.push_back(result.getType());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return success();
|
builder->setInsertionPoint(cluster);
|
||||||
|
auto launch = builder->create<tf_device::LaunchOp>(
|
||||||
|
cluster.getLoc(), builder->getStringAttr(host_device),
|
||||||
|
launch_result_types);
|
||||||
|
launch.body().push_back(launch_block);
|
||||||
|
|
||||||
|
builder->setInsertionPointToEnd(&launch.GetBody());
|
||||||
|
builder->create<tf_device::ReturnOp>(cluster.getLoc(), launch_results);
|
||||||
|
|
||||||
|
for (auto result : llvm::zip(launch_results, launch.getResults()))
|
||||||
|
replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result),
|
||||||
|
cluster.body());
|
||||||
|
|
||||||
|
return launch;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes aliased outputs in cluster from head computation after head
|
||||||
|
// computation has been extracted.
|
||||||
|
void RemoveHeadComputationAliasedOutputs(OpBuilder* builder,
|
||||||
|
tf_device::LaunchOp head_computation,
|
||||||
|
tf_device::ClusterOp cluster) {
|
||||||
|
llvm::SmallVector<Value, 4> used_old_cluster_results;
|
||||||
|
llvm::SmallVector<Value, 4> new_cluster_results;
|
||||||
|
llvm::SmallVector<Type, 4> new_cluster_result_types;
|
||||||
|
Operation* cluster_terminator = cluster.GetBody().getTerminator();
|
||||||
|
for (auto result :
|
||||||
|
llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) {
|
||||||
|
Value cluster_terminator_operand = std::get<0>(result);
|
||||||
|
if (cluster_terminator_operand.getDefiningOp() == head_computation) {
|
||||||
|
std::get<1>(result).replaceAllUsesWith(cluster_terminator_operand);
|
||||||
|
} else {
|
||||||
|
new_cluster_results.push_back(cluster_terminator_operand);
|
||||||
|
new_cluster_result_types.push_back(cluster_terminator_operand.getType());
|
||||||
|
used_old_cluster_results.push_back(std::get<1>(result));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (new_cluster_results.size() == cluster.getNumResults()) return;
|
||||||
|
|
||||||
|
builder->setInsertionPoint(cluster);
|
||||||
|
auto new_cluster = builder->create<tf_device::ClusterOp>(
|
||||||
|
cluster.getLoc(), new_cluster_result_types,
|
||||||
|
/*operands=*/llvm::ArrayRef<Value>{}, cluster.getAttrs());
|
||||||
|
new_cluster.body().takeBody(cluster.body());
|
||||||
|
new_cluster.GetBody().getTerminator()->setOperands(new_cluster_results);
|
||||||
|
|
||||||
|
for (auto result :
|
||||||
|
llvm::zip(used_old_cluster_results, new_cluster.getResults()))
|
||||||
|
std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
|
||||||
|
|
||||||
|
cluster.erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TPUExtractHeadTailOutsideCompilation
|
struct TPUExtractHeadTailOutsideCompilation
|
||||||
@ -283,22 +247,25 @@ void TPUExtractHeadTailOutsideCompilation::runOnOperation() {
|
|||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
||||||
OpBuilder builder(&getContext());
|
OpBuilder builder(&getContext());
|
||||||
auto result = module.walk([&](tf_device::ClusterOp cluster) {
|
llvm::SmallVector<tf_device::ClusterOp, 4> clusters;
|
||||||
auto head_outside_compiled_ops = IdentifyOutsideCompiledOpsAtHead(cluster);
|
module.walk(
|
||||||
auto host_launch_op = IsolateHeadExtractedOpsToLaunchOp(
|
[&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); });
|
||||||
&builder, cluster, head_outside_compiled_ops);
|
|
||||||
if (host_launch_op) {
|
|
||||||
if (failed(HandleHostLaunchDeviceAssignment(&builder, devices, cluster,
|
|
||||||
*host_launch_op))) {
|
|
||||||
return WalkResult::interrupt();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(b/155115766): Implement tail outside compiled op extraction.
|
for (tf_device::ClusterOp cluster : clusters) {
|
||||||
return WalkResult::advance();
|
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||||
});
|
FindOutsideCompiledOpsAtHead(cluster);
|
||||||
|
if (head_outside_compiled_ops.empty()) continue;
|
||||||
|
std::string host_device;
|
||||||
|
if (failed(GetHostDeviceForHeadTailComputation(devices, cluster,
|
||||||
|
&host_device)))
|
||||||
|
return signalPassFailure();
|
||||||
|
|
||||||
if (result.wasInterrupted()) signalPassFailure();
|
tf_device::LaunchOp head_computation = CreateHeadComputation(
|
||||||
|
&builder, cluster, head_outside_compiled_ops, host_device);
|
||||||
|
RemoveHeadComputationAliasedOutputs(&builder, head_computation, cluster);
|
||||||
|
|
||||||
|
// TODO(b/157160906): Implement tail outside compiled op extraction.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user