Set the host device for the outside compilation LaunchOp.

PiperOrigin-RevId: 318092701
Change-Id: Ib8f9de05110030a54be12c36ba004939c4c9d832
This commit is contained in:
Author: Ken Franko <2020-06-24 10:34:59 -07:00>, committed by TensorFlower Gardener
parent fcad3004bc
commit 2434d24013
2 changed files with 424 additions and 415 deletions

View File

@ -2,18 +2,7 @@
// Tests that missing `_xla_outside_compilation` attribute value results in an error. // Tests that missing `_xla_outside_compilation` attribute value results in an error.
func @missing_outside_compilation_attribute() -> () { module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
// expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
"tf.B"() {_xla_outside_compilation = ""} : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// -----
// Tests that TPU cluster with no outside compilation does not generate parallel_execute. // Tests that TPU cluster with no outside compilation does not generate parallel_execute.
// CHECK-LABEL: func @no_outside_compilation // CHECK-LABEL: func @no_outside_compilation
@ -22,7 +11,7 @@ func @no_outside_compilation() -> tensor<?xi32> {
%1 = "tf.A"() : () -> tensor<?xi32> %1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32> %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
} }
@ -36,15 +25,17 @@ func @nodep_single_outside_compilation() -> () {
// CHECK-NEXT: "tf_device.launch" // CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B" // CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation // CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.cluster" // CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A" // CHECK-NEXT: "tf.A"
// CHECK: cluster_attr = "cluster_attr" // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( { "tf_device.cluster"() ( {
"tf.A"() : () -> () "tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> () "tf.C"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return return
} }
@ -61,7 +52,7 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
// CHECK: "tf_device.cluster" // CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A" // CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.E" // CHECK-NEXT: "tf.E"
// CHECK: cluster_attr = "cluster_attr" // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( { "tf_device.cluster"() ( {
"tf.A"() : () -> () "tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
@ -69,7 +60,7 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
"tf.D"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.E"() : () -> () "tf.E"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return return
} }
@ -87,7 +78,7 @@ func @nodep_multiple_outside_compilation() -> () {
"tf.D"() {_xla_outside_compilation = "cluster2"} : () -> () "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
"tf.E"() : () -> () "tf.E"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return return
} }
@ -99,6 +90,9 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch" // CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster" // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
// CHECK: tf_device.return // CHECK: tf_device.return
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]] // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
@ -109,7 +103,7 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%3 = "tf.C"() : () -> tensor<?xi32> %3 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %3 : tensor<?xi32> tf_device.return %3 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -134,7 +128,7 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> () "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%5 = "tf.C"() : () -> tensor<?xi32> %5 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32> tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>) }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xf32>, tensor<?xi32>)
tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32> tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
} }
@ -163,7 +157,7 @@ func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32> %4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32> tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -194,7 +188,7 @@ func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?x
%4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>) %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32> %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32> tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -224,7 +218,7 @@ func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>) %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>) %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
tf_device.return %4 : tensor<?xi32> tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -255,7 +249,7 @@ func @single_outside_compiled_input_output_single_outside_compilation(%arg0: ten
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>) %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32> %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32> tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -290,7 +284,7 @@ func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: t
%7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32> %7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
%8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32> %8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %8 : tensor<?xi32> tf_device.return %8 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -332,7 +326,7 @@ func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>) %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
%7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32> %7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %7 : tensor<?xi32> tf_device.return %7 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -361,7 +355,7 @@ func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> () "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32> %4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32> tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -399,7 +393,7 @@ func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?
%4 = "tf.C"() : () -> tensor<?xi32> %4 = "tf.C"() : () -> tensor<?xi32>
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> () "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
tf_device.return %4 : tensor<?xi32> tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -432,7 +426,7 @@ func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> () "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%5 = "tf.E"() : () -> tensor<?xi32> %5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32> tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32> }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32> tf_device.return %2 : tensor<?xi32>
} }
@ -454,8 +448,9 @@ func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>) %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>) %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32> tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>) }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %2#1 : tensor<?xi32> tf_device.return %2#1 : tensor<?xi32>
} }
return %1 : tensor<?xi32> return %1 : tensor<?xi32>
} }
}

View File

@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir { namespace mlir {
namespace TFTPU { namespace TFTPU {
@ -91,13 +93,14 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
// Creates a `tf_device::LaunchOp` to wrap cluster ops. // Creates a `tf_device::LaunchOp` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster( tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
OpBuilder* builder, Operation* last_cluster_op) { OpBuilder* builder, Operation* last_cluster_op,
// TODO(b/154363171): Set the CPU device. llvm::StringRef host_device) {
// An empty string placeholder is used for the device as that will be later // An empty string placeholder is used for the device as that will be later
// populated with the device of the associated TPUReplicateMetadata op. // populated with the device of the associated TPUReplicateMetadata op.
llvm::SmallVector<Type, 8> result_types; llvm::SmallVector<Type, 8> result_types;
auto launch_op = builder->create<tf_device::LaunchOp>( auto launch_op = builder->create<tf_device::LaunchOp>(
last_cluster_op->getLoc(), builder->getStringAttr(""), result_types); last_cluster_op->getLoc(), builder->getStringAttr(host_device),
result_types);
launch_op.body().push_back(new Block); launch_op.body().push_back(new Block);
@ -253,8 +256,9 @@ void MoveOutsideCompiledOps(
// Creates a `parallel_execute` op in place of launch with 'clusters` and // Creates a `parallel_execute` op in place of launch with 'clusters` and
// 'launch` as regions. // 'launch` as regions.
void CreateParallelExecuteFromOutsideClusters( void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster,
tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) { const OutsideClusterMap& clusters,
llvm::StringRef host_device) {
OpBuilder builder(tpu_cluster); OpBuilder builder(tpu_cluster);
// Create parallel_execute regions. The original TPU cluster computation // Create parallel_execute regions. The original TPU cluster computation
// is the extra region. // is the extra region.
@ -269,8 +273,8 @@ void CreateParallelExecuteFromOutsideClusters(
Block& outside_block = Block& outside_block =
parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
builder.setInsertionPointToEnd(&outside_block); builder.setInsertionPointToEnd(&outside_block);
tf_device::LaunchOp host_launch_op = tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(
CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); &builder, cluster_ops.back(), host_device);
// Determine if there are any inputs that are provided out of cluster. // Determine if there are any inputs that are provided out of cluster.
auto external_inputs = GetExternalOperands(cluster_ops); auto external_inputs = GetExternalOperands(cluster_ops);
@ -307,8 +311,14 @@ void CreateParallelExecuteFromOutsideClusters(
} }
void TPUExtractOutsideCompilation::runOnOperation() { void TPUExtractOutsideCompilation::runOnOperation() {
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
auto extract_result = auto extract_result =
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) { module.walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters; OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters))) &clusters)))
@ -316,7 +326,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
if (clusters.empty()) return WalkResult::advance(); if (clusters.empty()) return WalkResult::advance();
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters); std::string host_device;
tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
&host_device);
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters,
host_device);
return WalkResult::advance(); return WalkResult::advance();
}); });