From 3e377e72fbdcb486e9f82e3d1af91a96c3459c41 Mon Sep 17 00:00:00 2001
From: Sourabh Bajaj
Date: Wed, 13 Mar 2019 14:37:22 -0700
Subject: [PATCH] Change tpu_system_metadata to work in 2.0 mode

PiperOrigin-RevId: 238311803
---
 tensorflow/python/tpu/tpu_system_metadata.py | 83 ++++++++++++--------
 1 file changed, 49 insertions(+), 34 deletions(-)

diff --git a/tensorflow/python/tpu/tpu_system_metadata.py b/tensorflow/python/tpu/tpu_system_metadata.py
index 30e3b8f0051..57518df0044 100644
--- a/tensorflow/python/tpu/tpu_system_metadata.py
+++ b/tensorflow/python/tpu/tpu_system_metadata.py
@@ -23,6 +23,7 @@ import re
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import context
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -34,6 +35,7 @@
 _RETRY_TIMES = 120
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
 
 _TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
+_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
@@ -57,41 +59,54 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
   devices = []
   device_dict = collections.defaultdict(list)
 
-  # TODO(b/120564445): Replace with standard library for retries.
-  retry_count = 1
-  while True:
-    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
-                 master_address)
-    try:
-      with ops.Graph().as_default():
-        with session_lib.Session(
-            master_address,
-            config=get_session_config_with_timeout(
-                _PINGING_MASTER_TIMEOUT_IN_MS,
-                cluster_def)) as sess:
-          devices = sess.list_devices()
-          for device in devices:
-            match = _TPU_DEVICE_REG.match(device.name)
-            if match:
-              host_id = match.group(1)
-              core_id = match.group(2)
-              device_dict[host_id].append(core_id)
-              tpu_core_count += 1
-          break
-    except errors.DeadlineExceededError:
-      msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
-             'not be ready (still scheduling) or the Tensorflow master address '
-             'is incorrect: got (%s).' %
-             (master_address))
+  if context.executing_eagerly():
+    device_names = context.list_devices()
+    devices = []
 
-      # TODO(xiejw): For local or grpc master we might not need retry logic
-      # here.
-      if retry_count <= _RETRY_TIMES:
-        logging.warning('%s', msg)
-        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
-        retry_count += 1
-      else:
-        raise ValueError(msg)
+    # We want the output type to match in both eager and session mode
+    for name in device_names:
+      device_match = _DEVICE_TYPE_REGEX.match(name)
+      device_type = 'CPU'
+      if device_match:
+        device_type = device_match.group(1)
+      devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
+  else:
+    # TODO(b/120564445): Replace with standard library for retries.
+    retry_count = 1
+    while True:
+      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
+                   master_address)
+      try:
+        with ops.Graph().as_default():
+          with session_lib.Session(
+              master_address,
+              config=get_session_config_with_timeout(
+                  _PINGING_MASTER_TIMEOUT_IN_MS,
+                  cluster_def)) as sess:
+            devices = sess.list_devices()
+            break
+      except errors.DeadlineExceededError:
+        msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
+               'not be ready (still scheduling) or the Tensorflow master '
+               'address is incorrect: got (%s).' %
+               (master_address))
+
+        # TODO(xiejw): For local or grpc master we might not need retry logic
+        # here.
+        if retry_count <= _RETRY_TIMES:
+          logging.warning('%s', msg)
+          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
+          retry_count += 1
+        else:
+          raise ValueError(msg)
+
+  for device in devices:
+    match = _TPU_DEVICE_REG.match(device.name)
+    if match:
+      host_id = match.group(1)
+      core_id = match.group(2)
+      device_dict[host_id].append(core_id)
+      tpu_core_count += 1
 
   num_of_cores_per_host = 0
   if tpu_core_count: