Change tpu_system_metadata to work in 2.0 mode

PiperOrigin-RevId: 238311803
This commit is contained in:
Sourabh Bajaj 2019-03-13 14:37:22 -07:00 committed by TensorFlower Gardener
parent 38d9711a73
commit 3e377e72fb

View File

@@ -23,6 +23,7 @@ import re
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -34,6 +35,7 @@ _RETRY_TIMES = 120
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
@@ -57,6 +59,18 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
devices = []
device_dict = collections.defaultdict(list)
if context.executing_eagerly():
device_names = context.list_devices()
devices = []
# We want the output type to match in both eager and session mode
for name in device_names:
device_match = _DEVICE_TYPE_REGEX.match(name)
device_type = 'CPU'
if device_match:
device_type = device_match.group(1)
devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0)) # pylint: disable=protected-access
else:
# TODO(b/120564445): Replace with standard library for retries.
retry_count = 1
while True:
@@ -70,18 +84,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
_PINGING_MASTER_TIMEOUT_IN_MS,
cluster_def)) as sess:
devices = sess.list_devices()
for device in devices:
match = _TPU_DEVICE_REG.match(device.name)
if match:
host_id = match.group(1)
core_id = match.group(2)
device_dict[host_id].append(core_id)
tpu_core_count += 1
break
except errors.DeadlineExceededError:
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
'not be ready (still scheduling) or the Tensorflow master address '
'is incorrect: got (%s).' %
'not be ready (still scheduling) or the Tensorflow master '
'address is incorrect: got (%s).' %
(master_address))
# TODO(xiejw): For local or grpc master we might not need retry logic
@@ -93,6 +100,14 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
else:
raise ValueError(msg)
for device in devices:
match = _TPU_DEVICE_REG.match(device.name)
if match:
host_id = match.group(1)
core_id = match.group(2)
device_dict[host_id].append(core_id)
tpu_core_count += 1
num_of_cores_per_host = 0
if tpu_core_count:
num_cores_per_host_set = set(