diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 7aef5da11f2..a6bc0b7ec56 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -878,6 +878,24 @@ tpu_py_test(
     ],
 )
 
+tpu_py_test(
+    name = "remote_cloud_tpu_pod_test",
+    srcs = ["remote_cloud_tpu_test.py"],
+    args = ["--num_tpu_devices=32"],
+    main = "remote_cloud_tpu_test.py",
+    python_version = "PY3",
+    tags = [
+        "notap",
+        "tpu_pod",
+    ],
+    deps = [
+        ":context",
+        ":remote",
+        "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+        "//tensorflow/python/tpu:tpu_strategy_util",
+    ],
+)
+
 cuda_py_test(
     name = "device_placement_test",
     size = "small",
diff --git a/tensorflow/python/eager/remote_cloud_tpu_test.py b/tensorflow/python/eager/remote_cloud_tpu_test.py
index d63a8924bc8..8ba11a3e6ac 100644
--- a/tensorflow/python/eager/remote_cloud_tpu_test.py
+++ b/tensorflow/python/eager/remote_cloud_tpu_test.py
@@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
 flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
 flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
 
+flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.')
+DEVICES_PER_TASK = 8
+
 EXPECTED_DEVICES_PRE_CONNECT = [
     '/device:CPU:0',
     '/device:XLA_CPU:0',
 ]
-EXPECTED_DEVICES_AFTER_CONNECT = [
-    '/device:CPU:0',
-    '/device:XLA_CPU:0',
-    '/job:worker/replica:0/task:0/device:CPU:0',
-    '/job:worker/replica:0/task:0/device:XLA_CPU:0',
-    '/job:worker/replica:0/task:0/device:TPU_SYSTEM:0',
-    '/job:worker/replica:0/task:0/device:TPU:0',
-    '/job:worker/replica:0/task:0/device:TPU:1',
-    '/job:worker/replica:0/task:0/device:TPU:2',
-    '/job:worker/replica:0/task:0/device:TPU:3',
-    '/job:worker/replica:0/task:0/device:TPU:4',
-    '/job:worker/replica:0/task:0/device:TPU:5',
-    '/job:worker/replica:0/task:0/device:TPU:6',
-    '/job:worker/replica:0/task:0/device:TPU:7',
+EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [
+    '/job:worker/replica:0/task:{task}/device:CPU:0',
+    '/job:worker/replica:0/task:{task}/device:XLA_CPU:0',
+    '/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0',
+    '/job:worker/replica:0/task:{task}/device:TPU:0',
+    '/job:worker/replica:0/task:{task}/device:TPU:1',
+    '/job:worker/replica:0/task:{task}/device:TPU:2',
+    '/job:worker/replica:0/task:{task}/device:TPU:3',
+    '/job:worker/replica:0/task:{task}/device:TPU:4',
+    '/job:worker/replica:0/task:{task}/device:TPU:5',
+    '/job:worker/replica:0/task:{task}/device:TPU:6',
+    '/job:worker/replica:0/task:{task}/device:TPU:7',
 ]
 
 
@@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase):
   """Test that we can connect to a real Cloud TPU."""
 
   def test_connect(self):
+    # Log full diff on failure.
+    self.maxDiff = None  # pylint:disable=invalid-name
+
     self.assertCountEqual(
         EXPECTED_DEVICES_PRE_CONNECT,
         [device.name for device in config.list_logical_devices()])
@@ -65,8 +69,16 @@ class RemoteCloudTPUTest(absltest.TestCase):
     )
     remote.connect_to_cluster(resolver)
 
+    # Copy the list so extend() does not mutate the module-level constant.
+    expected_devices = list(EXPECTED_DEVICES_PRE_CONNECT)
+    for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
+      expected_devices.extend([
+          template.format(task=task)
+          for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
+      ])
+
     self.assertCountEqual(
-        EXPECTED_DEVICES_AFTER_CONNECT,
+        expected_devices,
         [device.name for device in config.list_logical_devices()])
 
     tpu_strategy_util.initialize_tpu_system(resolver)