Add test for connecting to v2-32 Cloud TPU.
PiperOrigin-RevId: 297219348 Change-Id: I6072853736fc98badbca598850f850b9980cc7b0
This commit is contained in:
parent
5473ad460d
commit
8df05fc187
@ -878,6 +878,24 @@ tpu_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tpu_py_test(
|
||||||
|
name = "remote_cloud_tpu_pod_test",
|
||||||
|
srcs = ["remote_cloud_tpu_test.py"],
|
||||||
|
args = ["--num_tpu_devices=32"],
|
||||||
|
main = "remote_cloud_tpu_test.py",
|
||||||
|
python_version = "PY3",
|
||||||
|
tags = [
|
||||||
|
"notap",
|
||||||
|
"tpu_pod",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":context",
|
||||||
|
":remote",
|
||||||
|
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||||
|
"//tensorflow/python/tpu:tpu_strategy_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cuda_py_test(
|
cuda_py_test(
|
||||||
name = "device_placement_test",
|
name = "device_placement_test",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
|
|||||||
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
|
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
|
||||||
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
|
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
|
||||||
|
|
||||||
|
flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.')
|
||||||
|
# TPU devices exposed by each worker task (TPU:0 through TPU:7 in the templates
# below); used to turn --num_tpu_devices into an expected task count.
DEVICES_PER_TASK = 8
|
||||||
|
|
||||||
EXPECTED_DEVICES_PRE_CONNECT = [
|
EXPECTED_DEVICES_PRE_CONNECT = [
|
||||||
'/device:CPU:0',
|
'/device:CPU:0',
|
||||||
'/device:XLA_CPU:0',
|
'/device:XLA_CPU:0',
|
||||||
]
|
]
|
||||||
EXPECTED_DEVICES_AFTER_CONNECT = [
|
EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [
|
||||||
'/device:CPU:0',
|
'/job:worker/replica:0/task:{task}/device:CPU:0',
|
||||||
'/device:XLA_CPU:0',
|
'/job:worker/replica:0/task:{task}/device:XLA_CPU:0',
|
||||||
'/job:worker/replica:0/task:0/device:CPU:0',
|
'/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0',
|
||||||
'/job:worker/replica:0/task:0/device:XLA_CPU:0',
|
'/job:worker/replica:0/task:{task}/device:TPU:0',
|
||||||
'/job:worker/replica:0/task:0/device:TPU_SYSTEM:0',
|
'/job:worker/replica:0/task:{task}/device:TPU:1',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:0',
|
'/job:worker/replica:0/task:{task}/device:TPU:2',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:1',
|
'/job:worker/replica:0/task:{task}/device:TPU:3',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:2',
|
'/job:worker/replica:0/task:{task}/device:TPU:4',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:3',
|
'/job:worker/replica:0/task:{task}/device:TPU:5',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:4',
|
'/job:worker/replica:0/task:{task}/device:TPU:6',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:5',
|
'/job:worker/replica:0/task:{task}/device:TPU:7',
|
||||||
'/job:worker/replica:0/task:0/device:TPU:6',
|
|
||||||
'/job:worker/replica:0/task:0/device:TPU:7',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase):
|
|||||||
"""Test that we can connect to a real Cloud TPU."""
|
"""Test that we can connect to a real Cloud TPU."""
|
||||||
|
|
||||||
def test_connect(self):
|
def test_connect(self):
|
||||||
|
# Log full diff on failure.
|
||||||
|
self.maxDiff = None # pylint:disable=invalid-name
|
||||||
|
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
EXPECTED_DEVICES_PRE_CONNECT,
|
EXPECTED_DEVICES_PRE_CONNECT,
|
||||||
[device.name for device in config.list_logical_devices()])
|
[device.name for device in config.list_logical_devices()])
|
||||||
@ -65,8 +69,15 @@ class RemoteCloudTPUTest(absltest.TestCase):
|
|||||||
)
|
)
|
||||||
remote.connect_to_cluster(resolver)
|
remote.connect_to_cluster(resolver)
|
||||||
|
|
||||||
|
# Copy the baseline list: extend() below would otherwise mutate the
# module-level EXPECTED_DEVICES_PRE_CONNECT constant in place.
expected_devices = list(EXPECTED_DEVICES_PRE_CONNECT)
|
||||||
|
for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
|
||||||
|
expected_devices.extend([
|
||||||
|
template.format(task=task)
|
||||||
|
for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
|
||||||
|
])
|
||||||
|
|
||||||
self.assertCountEqual(
|
self.assertCountEqual(
|
||||||
EXPECTED_DEVICES_AFTER_CONNECT,
|
expected_devices,
|
||||||
[device.name for device in config.list_logical_devices()])
|
[device.name for device in config.list_logical_devices()])
|
||||||
|
|
||||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||||
|
Loading…
Reference in New Issue
Block a user