Update TPU rewrite pass to populate replicated host devices on tf_device.replicate.
Replicated host devices under data parallelism may be necessary if outside compilation is present.

PiperOrigin-RevId: 311819706
Change-Id: Iad2775559374d481e3b39ba1a8681f660ee6787e
commit 22608ca0c2
parent a133be3d31
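To make the intended IR change concrete, here is a rough sketch (assembled from the updated FileCheck expectations below, not literal pass output): under data parallelism with two replicas on a single worker task, the rewritten `tf_device.replicate` now carries a TPU_REPLICATED_HOST device alias next to TPU_REPLICATED_CORE_0, with one host entry per replica.

// Sketch only: expected shape of the rewritten replicate op for 2 replicas,
// 1 core per replica, on a single task (device strings as in the test below).
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>)
    {devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0",
                                         "/job:worker/replica:0/task:0/device:TPU:1"],
                TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0",
                                       "/job:worker/replica:0/task:0/device:CPU:0"]},
     n = 2 : i32} {
  // ... replicated computation; outside compiled ops can be assigned the
  // TPU_REPLICATED_HOST alias and resolve to the per-replica host device ...
  tf_device.return %ri_0 : tensor<?xi32>
}

Both TPU_REPLICATED_HOST entries are the same CPU device here only because both replicas are placed on the same task.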
@@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor

 // -----

-// Tests simple case of `tf_device.cluster_func` on TPU with replication.
+// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under
+// data parallelism replicated host devices are also added to the
+// tf_device.replicate

 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
   // CHECK-LABEL: func @replicated_tpu_cluster_func
@@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor

     // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
     // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
-    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}
+    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
     // CHECK-SAME: n = 2
     %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
       // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
@@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor

 // -----

-// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
+// Tests simple case of `tf_device.cluster_func` on TPU with replication and
+// parallel_execute.

 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
   // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
@@ -1240,7 +1243,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
         tf_device.return
       }, {
         %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
-
         tf_device.return %4 : tensor<?xi32>
       }) : () -> (tensor<?xi32>)
       tf_device.return %3 : tensor<?xi32>
@@ -1317,15 +1319,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
 // "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"

 // -----

-// Tests devices are set properly for replicated model parallelism.
+// Tests devices are set properly for replicated model parallelism. No
+// replicated host device should be present.

 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @replicated_parallel_execute
   func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
     // CHECK: tf_device.replicate
-    // CHECK-SAME: devices =
-    // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
-    // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
+    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
     %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
       // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
       // CHECK-NEXT: "tf._TPUCompileMlir"()
@@ -1357,8 +1358,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc

 // -----

-// Tests that inputs are inputs with maximal and replicate sharding are set properly
-// for replicated model parallelism.
+// Tests that inputs are inputs with maximal and replicate sharding are set
+// properly for replicated model parallelism.

 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
@@ -1392,8 +1393,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc

 // -----

-// Tests devices are set properly for replicated model parallelism with
-// outputs to TPU computation placed on logical device 0.
+// Tests devices are set properly for replicated model parallelism with outputs
+// to TPU computation placed on logical device 0.

 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @parallel_execute_with_different_outputs
@@ -1469,8 +1470,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc

 // -----

-// Tests inputs are correctly split and fed into TPU computation for
-// tiled input sharding.
+// Tests inputs are correctly split and fed into TPU computation for tiled input
+// sharding.

 // The following OpSharding is used for TPU computation inputs in below test:
 // Proto debug string:
@@ -437,6 +437,18 @@ void AssignDevicesToReplicate(
                               builder->getStrArrayAttr(devices_by_core)));
   }

+  // For data parallelism, also add replicated host devices, as these are
+  // necessary for outside compilation.
+  if (num_cores_per_replica == 1) {
+    llvm::SmallVector<StringRef, 8> hosts;
+    hosts.reserve(num_replicas);
+    for (int replica = 0; replica < num_replicas; ++replica)
+      hosts.push_back(tpu_devices[replica][0].host);
+
+    device_attrs.push_back(builder->getNamedAttr(
+        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
+  }
+
   replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
 }
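For contrast, a sketch of the model-parallel case (num_cores_per_replica > 1), mirroring the updated `replicated_parallel_execute` checks above: the added branch is skipped, so the devices dictionary keeps only the per-core aliases and gains no TPU_REPLICATED_HOST entry.

// Sketch only: devices attribute for 2 replicas x 2 cores per replica
// (model parallelism); no replicated host alias is added.
devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0",
                                    "/job:localhost/replica:0/task:1/device:TPU:1"],
           TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1",
                                    "/job:localhost/replica:0/task:1/device:TPU:0"]}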