Add test for connecting to v2-32 Cloud TPU.

PiperOrigin-RevId: 297219348
Change-Id: I6072853736fc98badbca598850f850b9980cc7b0
This commit is contained in:
Revan Sopher 2020-02-25 15:31:27 -08:00 committed by TensorFlower Gardener
parent 5473ad460d
commit 8df05fc187
2 changed files with 44 additions and 15 deletions
tensorflow/python/eager

View File

@ -878,6 +878,24 @@ tpu_py_test(
],
)
# Runs remote_cloud_tpu_test.py with --num_tpu_devices=32, so the test expects
# 32 TPU devices to be visible after connecting.
# NOTE(review): the "notap" and "tpu_pod" tags presumably control which test
# infrastructure picks this target up -- confirm against the tpu_py_test macro.
tpu_py_test(
name = "remote_cloud_tpu_pod_test",
srcs = ["remote_cloud_tpu_test.py"],
args = ["--num_tpu_devices=32"],
main = "remote_cloud_tpu_test.py",
python_version = "PY3",
tags = [
"notap",
"tpu_pod",
],
deps = [
":context",
":remote",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/tpu:tpu_strategy_util",
],
)
cuda_py_test(
name = "device_placement_test",
size = "small",

View File

@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.')
# Number of TPU devices expected per worker task. The test divides
# --num_tpu_devices by this value to get the number of tasks to expect.
DEVICES_PER_TASK = 8
# Logical device names the test expects to see before connecting to the
# remote cluster.
EXPECTED_DEVICES_PRE_CONNECT = [
'/device:CPU:0',
'/device:XLA_CPU:0',
]
EXPECTED_DEVICES_AFTER_CONNECT = [
'/device:CPU:0',
'/device:XLA_CPU:0',
'/job:worker/replica:0/task:0/device:CPU:0',
'/job:worker/replica:0/task:0/device:XLA_CPU:0',
'/job:worker/replica:0/task:0/device:TPU_SYSTEM:0',
'/job:worker/replica:0/task:0/device:TPU:0',
'/job:worker/replica:0/task:0/device:TPU:1',
'/job:worker/replica:0/task:0/device:TPU:2',
'/job:worker/replica:0/task:0/device:TPU:3',
'/job:worker/replica:0/task:0/device:TPU:4',
'/job:worker/replica:0/task:0/device:TPU:5',
'/job:worker/replica:0/task:0/device:TPU:6',
'/job:worker/replica:0/task:0/device:TPU:7',
# Device-name templates for a single remote worker task. The `{task}`
# placeholder is filled in with the task index via str.format, once per
# expected task, to build the device names expected after connecting.
EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [
'/job:worker/replica:0/task:{task}/device:CPU:0',
'/job:worker/replica:0/task:{task}/device:XLA_CPU:0',
'/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0',
'/job:worker/replica:0/task:{task}/device:TPU:0',
'/job:worker/replica:0/task:{task}/device:TPU:1',
'/job:worker/replica:0/task:{task}/device:TPU:2',
'/job:worker/replica:0/task:{task}/device:TPU:3',
'/job:worker/replica:0/task:{task}/device:TPU:4',
'/job:worker/replica:0/task:{task}/device:TPU:5',
'/job:worker/replica:0/task:{task}/device:TPU:6',
'/job:worker/replica:0/task:{task}/device:TPU:7',
]
@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase):
"""Test that we can connect to a real Cloud TPU."""
def test_connect(self):
# Log full diff on failure.
self.maxDiff = None # pylint:disable=invalid-name
self.assertCountEqual(
EXPECTED_DEVICES_PRE_CONNECT,
[device.name for device in config.list_logical_devices()])
@ -65,8 +69,15 @@ class RemoteCloudTPUTest(absltest.TestCase):
)
remote.connect_to_cluster(resolver)
# Build the full expected device list: the local devices plus one set of
# remote devices per worker task. Copy the pre-connect list first, because
# extending it in place would mutate the module-level
# EXPECTED_DEVICES_PRE_CONNECT constant and leak remote device names into
# any later use of it in the same process.
expected_devices = list(EXPECTED_DEVICES_PRE_CONNECT)
for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
  expected_devices.extend([
      template.format(task=task)
      for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
  ])
self.assertCountEqual(
EXPECTED_DEVICES_AFTER_CONNECT,
expected_devices,
[device.name for device in config.list_logical_devices()])
tpu_strategy_util.initialize_tpu_system(resolver)