Update TPU rewrite pass to populate replicated host devices on tf_device.replicate.

Under data parallelism, replicated host devices may be necessary when outside compilation is present.

PiperOrigin-RevId: 311819706
Change-Id: Iad2775559374d481e3b39ba1a8681f660ee6787e
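
To make the intent concrete, here is a minimal standalone C++ sketch of the derivation (illustration only, not code from this change; the struct and function names are hypothetical): under data parallelism each replica has a single TPU core, and that core's owning host is collected into a per-replica list which the pass records under the TPU_REPLICATED_HOST alias, alongside TPU_REPLICATED_CORE_0.

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the pass's per-device record: each (replica, core)
// entry carries a TPU device name and the CPU device of the host that owns it.
struct TPUDeviceAndHost {
  std::string device;  // e.g. "/job:worker/replica:0/task:0/device:TPU:0"
  std::string host;    // e.g. "/job:worker/replica:0/task:0/device:CPU:0"
};

// For data parallelism (one core per replica), collect one host device per
// replica; this is the list that backs the TPU_REPLICATED_HOST alias.
std::vector<std::string> ReplicatedHostDevices(
    const std::vector<std::vector<TPUDeviceAndHost>>& tpu_devices,
    int num_cores_per_replica) {
  std::vector<std::string> hosts;
  if (num_cores_per_replica != 1) return hosts;  // model parallelism: no alias
  hosts.reserve(tpu_devices.size());
  for (const auto& replica : tpu_devices) {
    assert(!replica.empty());
    hosts.push_back(replica[0].host);  // host of the replica's only core
  }
  return hosts;
}

int main() {
  // Two replicas on one task, mirroring the updated test below: both TPU
  // cores are owned by the same host CPU, so it appears once per replica.
  std::vector<std::vector<TPUDeviceAndHost>> tpu_devices = {
      {{"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"}}};
  std::vector<std::string> hosts =
      ReplicatedHostDevices(tpu_devices, /*num_cores_per_replica=*/1);
  for (const std::string& host : hosts) std::cout << host << "\n";
  return 0;
}

Note that the same host may legitimately appear once per replica, as in the updated test below, because the list is indexed by replica exactly like the TPU_REPLICATED_CORE_* lists.
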
Andy Ly 2020-05-15 16:13:31 -07:00 committed by TensorFlower Gardener
parent a133be3d31
commit 22608ca0c2
2 changed files with 27 additions and 14 deletions

@@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// -----
// Tests simple case of `tf_device.cluster_func` on TPU with replication.
// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under
// data parallelism, replicated host devices are also added to the
// tf_device.replicate.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_tpu_cluster_func
@@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
// CHECK-SAME: n = 2
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
// CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
@@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// -----
// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
// Tests simple case of `tf_device.cluster_func` on TPU with replication and
// parallel_execute.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
// CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
@@ -1240,7 +1243,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
tf_device.return
}, {
%4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) : () -> (tensor<?xi32>)
tf_device.return %3 : tensor<?xi32>
@@ -1317,15 +1319,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"
// -----
// Tests devices are set properly for replicated model parallelism.
// Tests devices are set properly for replicated model parallelism. No
// replicated host device should be present.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @replicated_parallel_execute
func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
// CHECK: tf_device.replicate
// CHECK-SAME: devices =
// CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
// CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
// CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
%0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
// CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
// CHECK-NEXT: "tf._TPUCompileMlir"()
@@ -1357,8 +1358,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests that inputs are inputs with maximal and replicate sharding are set properly
// for replicated model parallelism.
// Tests that inputs with maximal and replicate sharding are set properly for
// replicated model parallelism.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
@@ -1392,8 +1393,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests devices are set properly for replicated model parallelism with
// outputs to TPU computation placed on logical device 0.
// Tests devices are set properly for replicated model parallelism with outputs
// to TPU computation placed on logical device 0.
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
// CHECK-LABEL: func @parallel_execute_with_different_outputs
@@ -1469,8 +1470,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
// -----
// Tests inputs are correctly split and fed into TPU computation for
// tiled input sharding.
// Tests inputs are correctly split and fed into TPU computation for tiled input
// sharding.
// The following OpSharding is used for TPU computation inputs in below test:
// Proto debug string:

@@ -437,6 +437,18 @@ void AssignDevicesToReplicate(
        builder->getStrArrayAttr(devices_by_core)));
  }
  // For data parallelism, also add replicated host devices, as these are
  // necessary for outside compilation.
  if (num_cores_per_replica == 1) {
    llvm::SmallVector<StringRef, 8> hosts;
    hosts.reserve(num_replicas);
    for (int replica = 0; replica < num_replicas; ++replica)
      hosts.push_back(tpu_devices[replica][0].host);
    device_attrs.push_back(builder->getNamedAttr(
        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
  }
  replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
}
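
To show how the new alias would be consumed, here is a hedged sketch (plain C++ with std containers standing in for the MLIR `devices` dictionary attribute; the helper name is illustrative, not an existing API): a pass placing host-side ops for outside compilation can resolve the host device of a given replica from the TPU_REPLICATED_HOST entry, and finds no entry for model-parallel clusters, where the alias is not populated.

#include <map>
#include <optional>
#include <string>
#include <vector>

// Stand-in for the `devices` dictionary attribute on tf_device.replicate:
// alias name -> one device string per replica.
using ReplicateDevices = std::map<std::string, std::vector<std::string>>;

// Returns the host device for `replica_id` if the pass recorded one, i.e. the
// cluster was purely data parallel; std::nullopt otherwise.
std::optional<std::string> HostDeviceForReplica(const ReplicateDevices& devices,
                                                int replica_id) {
  auto it = devices.find("TPU_REPLICATED_HOST");
  if (it == devices.end()) return std::nullopt;  // model parallelism: no alias
  const std::vector<std::string>& hosts = it->second;
  if (replica_id < 0 || replica_id >= static_cast<int>(hosts.size()))
    return std::nullopt;
  return hosts[replica_id];
}

int main() {
  // Dictionary shape produced for the data-parallel test case above.
  ReplicateDevices devices = {
      {"TPU_REPLICATED_CORE_0",
       {"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:TPU:1"}},
      {"TPU_REPLICATED_HOST",
       {"/job:worker/replica:0/task:0/device:CPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"}}};
  // A host-side launch for replica 1 would target .../device:CPU:0 here.
  std::optional<std::string> host = HostDeviceForReplica(devices, 1);
  return host.has_value() ? 0 : 1;
}

Keying the host list by replica keeps it parallel to the TPU_REPLICATED_CORE_* lists, so per-replica device resolution works the same way for host-side and device-side computations.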