Add test for connecting to v2-32 Cloud TPU.
PiperOrigin-RevId: 297219348 Change-Id: I6072853736fc98badbca598850f850b9980cc7b0
This commit is contained in:
parent
5473ad460d
commit
8df05fc187
tensorflow/python/eager
@ -878,6 +878,24 @@ tpu_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tpu_py_test(
|
||||
name = "remote_cloud_tpu_pod_test",
|
||||
srcs = ["remote_cloud_tpu_test.py"],
|
||||
args = ["--num_tpu_devices=32"],
|
||||
main = "remote_cloud_tpu_test.py",
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"notap",
|
||||
"tpu_pod",
|
||||
],
|
||||
deps = [
|
||||
":context",
|
||||
":remote",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/tpu:tpu_strategy_util",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "device_placement_test",
|
||||
size = "small",
|
||||
|
@ -31,24 +31,25 @@ flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
|
||||
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
|
||||
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
|
||||
|
||||
flags.DEFINE_integer('num_tpu_devices', 8, 'The expected number of TPUs.')
|
||||
DEVICES_PER_TASK = 8
|
||||
|
||||
EXPECTED_DEVICES_PRE_CONNECT = [
|
||||
'/device:CPU:0',
|
||||
'/device:XLA_CPU:0',
|
||||
]
|
||||
EXPECTED_DEVICES_AFTER_CONNECT = [
|
||||
'/device:CPU:0',
|
||||
'/device:XLA_CPU:0',
|
||||
'/job:worker/replica:0/task:0/device:CPU:0',
|
||||
'/job:worker/replica:0/task:0/device:XLA_CPU:0',
|
||||
'/job:worker/replica:0/task:0/device:TPU_SYSTEM:0',
|
||||
'/job:worker/replica:0/task:0/device:TPU:0',
|
||||
'/job:worker/replica:0/task:0/device:TPU:1',
|
||||
'/job:worker/replica:0/task:0/device:TPU:2',
|
||||
'/job:worker/replica:0/task:0/device:TPU:3',
|
||||
'/job:worker/replica:0/task:0/device:TPU:4',
|
||||
'/job:worker/replica:0/task:0/device:TPU:5',
|
||||
'/job:worker/replica:0/task:0/device:TPU:6',
|
||||
'/job:worker/replica:0/task:0/device:TPU:7',
|
||||
EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES = [
|
||||
'/job:worker/replica:0/task:{task}/device:CPU:0',
|
||||
'/job:worker/replica:0/task:{task}/device:XLA_CPU:0',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU_SYSTEM:0',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:0',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:1',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:2',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:3',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:4',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:5',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:6',
|
||||
'/job:worker/replica:0/task:{task}/device:TPU:7',
|
||||
]
|
||||
|
||||
|
||||
@ -56,6 +57,9 @@ class RemoteCloudTPUTest(absltest.TestCase):
|
||||
"""Test that we can connect to a real Cloud TPU."""
|
||||
|
||||
def test_connect(self):
|
||||
# Log full diff on failure.
|
||||
self.maxDiff = None # pylint:disable=invalid-name
|
||||
|
||||
self.assertCountEqual(
|
||||
EXPECTED_DEVICES_PRE_CONNECT,
|
||||
[device.name for device in config.list_logical_devices()])
|
||||
@ -65,8 +69,15 @@ class RemoteCloudTPUTest(absltest.TestCase):
|
||||
)
|
||||
remote.connect_to_cluster(resolver)
|
||||
|
||||
expected_devices = EXPECTED_DEVICES_PRE_CONNECT
|
||||
for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
|
||||
expected_devices.extend([
|
||||
template.format(task=task)
|
||||
for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
|
||||
])
|
||||
|
||||
self.assertCountEqual(
|
||||
EXPECTED_DEVICES_AFTER_CONNECT,
|
||||
expected_devices,
|
||||
[device.name for device in config.list_logical_devices()])
|
||||
|
||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
|
Loading…
Reference in New Issue
Block a user