diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
index b8a48bbb379..332b46f427f 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir
@@ -747,7 +747,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
 // -----
 
-// Tests simple case of `tf_device.cluster_func` on TPU with replication.
+// Tests simple case of `tf_device.cluster_func` on TPU with replication. Under
+// data parallelism, replicated host devices are also added to the
+// tf_device.replicate.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
   // CHECK-LABEL: func @replicated_tpu_cluster_func
@@ -758,7 +760,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
     // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
     // CHECK-SAME: ([%[[A_OUTPUT]], %[[ARG_0]]] as %[[RI_0:[a-z0-9]*]]: tensor<?xi32>)
-    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]}
+    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"], TPU_REPLICATED_HOST = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:CPU:0"]}
     // CHECK-SAME: n = 2
     %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
       // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[RI_0]])
@@ -1222,7 +1224,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
 
 // -----
 
-// Tests simple case of `tf_device.cluster_func` on TPU with replication and parallel_execute.
+// Tests simple case of `tf_device.cluster_func` on TPU with replication and
+// parallel_execute.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1"]} {
   // CHECK-LABEL: func @replicated_parallel_tpu_cluster_func
@@ -1240,7 +1243,6 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
         tf_device.return
       }, {
         %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"]} : (tensor<?xi32>) -> tensor<?xi32>
-        tf_device.return %4 : tensor<?xi32>
       }) : () -> (tensor<?xi32>)
       tf_device.return %3 : tensor<?xi32>
@@ -1317,15 +1319,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
 // "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01"
 
 // -----
 
-// Tests devices are set properly for replicated model parallelism.
+// Tests devices are set properly for replicated model parallelism. No
+// replicated host device should be present.
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @replicated_parallel_execute
   func @replicated_parallel_execute(%arg0: tensor<8xi32>, %arg1: tensor<8xi32>) -> (tensor<8xi32>, tensor<8xi32>) {
     // CHECK: tf_device.replicate
-    // CHECK-SAME: devices =
-    // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"]
-    // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]
+    // CHECK-SAME: devices = {TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"], TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"]}
     %0:2 = tf_device.replicate([%arg0, %arg1] as %ri: tensor<8xi32>) {n = 2 : i32} {
       // CHECK-NEXT: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"
       // CHECK-NEXT: "tf._TPUCompileMlir"()
@@ -1357,8 +1358,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
 
 // -----
 
-// Tests that inputs are inputs with maximal and replicate sharding are set properly
-// for replicated model parallelism.
+// Tests that inputs with maximal and replicate sharding are set properly for
+// replicated model parallelism.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @parallel_execute_with_input_with_sharding_configurations
@@ -1392,8 +1393,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
 
 // -----
 
-// Tests devices are set properly for replicated model parallelism with
-// outputs to TPU computation placed on logical device 0.
+// Tests devices are set properly for replicated model parallelism with outputs
+// to TPU computation placed on logical device 0.
 
 module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} {
   // CHECK-LABEL: func @parallel_execute_with_different_outputs
@@ -1469,8 +1470,8 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc
 
 // -----
 
-// Tests inputs are correctly split and fed into TPU computation for
-// tiled input sharding.
+// Tests inputs are correctly split and fed into TPU computation for tiled input
+// sharding.
 // The following OpSharding is used for TPU computation inputs in below test:
 // Proto debug string:
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
index 986736a9502..a7ad6a964b9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc
@@ -437,6 +437,18 @@ void AssignDevicesToReplicate(
                               builder->getStrArrayAttr(devices_by_core)));
   }
 
+  // For data parallelism, also add replicated host devices, as these are
+  // necessary for outside compilation.
+  if (num_cores_per_replica == 1) {
+    llvm::SmallVector<StringRef, 8> hosts;
+    hosts.reserve(num_replicas);
+    for (int replica = 0; replica < num_replicas; ++replica)
+      hosts.push_back(tpu_devices[replica][0].host);
+
+    device_attrs.push_back(builder->getNamedAttr(
+        tensorflow::kTPUReplicatedHost, builder->getStrArrayAttr(hosts)));
+  }
+
   replicate.setAttr(kDevicesAttr, builder->getDictionaryAttr(device_attrs));
 }
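The sketch below is a minimal, self-contained illustration (plain C++, no MLIR/LLVM dependencies) of the host-collection logic added to AssignDevicesToReplicate above. DeviceAndHost and CollectReplicatedHosts are hypothetical stand-ins for tensorflow::TPUDeviceAndHost and the in-pass loop; in the real pass the resulting list is attached to the `devices` dictionary attribute of tf_device.replicate under the TPU_REPLICATED_HOST alias rather than returned.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for tensorflow::TPUDeviceAndHost: one TPU core device
// and the CPU host device that feeds it.
struct DeviceAndHost {
  std::string device;
  std::string host;
};

// Mirrors the added loop: tpu_devices is indexed [replica][core], and only the
// data-parallel case (one logical core per replica) gets a replicated host list.
std::vector<std::string> CollectReplicatedHosts(
    const std::vector<std::vector<DeviceAndHost>>& tpu_devices) {
  std::vector<std::string> hosts;
  const int num_replicas = tpu_devices.size();
  const int num_cores_per_replica = tpu_devices.front().size();
  if (num_cores_per_replica != 1) return hosts;  // model parallelism: no host list
  hosts.reserve(num_replicas);
  for (int replica = 0; replica < num_replicas; ++replica)
    hosts.push_back(tpu_devices[replica][0].host);
  return hosts;
}

int main() {
  // Two replicas on the same task share one host CPU, so the list repeats it,
  // matching the TPU_REPLICATED_HOST devices checked in
  // @replicated_tpu_cluster_func.
  std::vector<std::vector<DeviceAndHost>> tpu_devices = {
      {{"/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:CPU:0"}},
      {{"/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:0/device:CPU:0"}}};
  for (const std::string& host : CollectReplicatedHosts(tpu_devices)) {
    std::cout << host << "\n";
  }
  return 0;
}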