Change tpu_system_metadata to work in 2.0 mode

PiperOrigin-RevId: 238311803
Sourabh Bajaj authored 2019-03-13 14:37:22 -07:00, committed by TensorFlower Gardener
parent 38d9711a73
commit 3e377e72fb
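
Not part of the commit: a minimal, stdlib-only sketch of what _query_tpu_system_metadata derives from the device list in either mode, using the same _TPU_DEVICE_REG pattern as the file changed below. The device names are illustrative, not output captured from a real TPU worker.

  import collections
  import re

  _TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')

  # Illustrative names in TensorFlow's /job:.../task:N/device:TYPE:M form.
  device_names = [
      '/job:tpu_worker/replica:0/task:0/device:CPU:0',
      '/job:tpu_worker/replica:0/task:0/device:TPU:0',
      '/job:tpu_worker/replica:0/task:0/device:TPU:1',
  ]

  tpu_core_count = 0
  device_dict = collections.defaultdict(list)
  for name in device_names:
    match = _TPU_DEVICE_REG.match(name)
    if match:
      # Group 1 is the host (task) id, group 2 the TPU core id on that host.
      host_id, core_id = match.group(1), match.group(2)
      device_dict[host_id].append(core_id)
      tpu_core_count += 1

  print(dict(device_dict), tpu_core_count)  # {'0': ['0', '1']} 2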


@@ -23,6 +23,7 @@ import re
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import context
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -34,6 +35,7 @@ _RETRY_TIMES = 120
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
 _TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
+_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
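
The added pattern pulls the device type out of a fully qualified device name; a quick check (illustrative name, not captured from a real worker):

  import re

  _DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')

  name = '/job:tpu_worker/replica:0/task:0/device:TPU:1'
  match = _DEVICE_TYPE_REGEX.match(name)
  print(match.group(1))  # TPU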
@@ -57,41 +59,54 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
   devices = []
   device_dict = collections.defaultdict(list)
-  # TODO(b/120564445): Replace with standard library for retries.
-  retry_count = 1
-  while True:
-    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
-                 master_address)
-    try:
-      with ops.Graph().as_default():
-        with session_lib.Session(
-            master_address,
-            config=get_session_config_with_timeout(
-                _PINGING_MASTER_TIMEOUT_IN_MS,
-                cluster_def)) as sess:
-          devices = sess.list_devices()
-          for device in devices:
-            match = _TPU_DEVICE_REG.match(device.name)
-            if match:
-              host_id = match.group(1)
-              core_id = match.group(2)
-              device_dict[host_id].append(core_id)
-              tpu_core_count += 1
-          break
-    except errors.DeadlineExceededError:
-      msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
-             'not be ready (still scheduling) or the Tensorflow master address '
-             'is incorrect: got (%s).' %
-             (master_address))
-      # TODO(xiejw): For local or grpc master we might not need retry logic
-      # here.
-      if retry_count <= _RETRY_TIMES:
-        logging.warning('%s', msg)
-        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
-        retry_count += 1
-      else:
-        raise ValueError(msg)
+  if context.executing_eagerly():
+    device_names = context.list_devices()
+    devices = []
+
+    # We want the output type to match in both eager and session mode
+    for name in device_names:
+      device_match = _DEVICE_TYPE_REGEX.match(name)
+      device_type = 'CPU'
+      if device_match:
+        device_type = device_match.group(1)
+      devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
+  else:
+    # TODO(b/120564445): Replace with standard library for retries.
+    retry_count = 1
+    while True:
+      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
+                   master_address)
+      try:
+        with ops.Graph().as_default():
+          with session_lib.Session(
+              master_address,
+              config=get_session_config_with_timeout(
+                  _PINGING_MASTER_TIMEOUT_IN_MS,
+                  cluster_def)) as sess:
+            devices = sess.list_devices()
+            break
+      except errors.DeadlineExceededError:
+        msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
+               'not be ready (still scheduling) or the Tensorflow master '
+               'address is incorrect: got (%s).' %
+               (master_address))
+        # TODO(xiejw): For local or grpc master we might not need retry logic
+        # here.
+        if retry_count <= _RETRY_TIMES:
+          logging.warning('%s', msg)
+          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
+          retry_count += 1
+        else:
+          raise ValueError(msg)
+
+  for device in devices:
+    match = _TPU_DEVICE_REG.match(device.name)
+    if match:
+      host_id = match.group(1)
+      core_id = match.group(2)
+      device_dict[host_id].append(core_id)
+      tpu_core_count += 1
   num_of_cores_per_host = 0
   if tpu_core_count: