Change tpu_system_metadata to work in 2.0 mode
PiperOrigin-RevId: 238311803
parent 38d9711a73
commit 3e377e72fb
@@ -23,6 +23,7 @@ import re
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import context
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -34,6 +35,7 @@ _RETRY_TIMES = 120
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
 
 _TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
+_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')
 
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
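For reference, a minimal sketch (not part of this commit) of what these two patterns extract from a TensorFlow device name. The device strings below are made-up examples in the standard /job:.../task:N/device:TYPE:M naming scheme:

import re

_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')

# Made-up device names in the standard TensorFlow format.
tpu_name = '/job:tpu_worker/replica:0/task:1/device:TPU:3'
cpu_name = '/job:tpu_worker/replica:0/task:1/device:CPU:0'

m = _TPU_DEVICE_REG.match(tpu_name)
print(m.group(1), m.group(2))  # task (host) id '1', TPU core id '3'

print(_TPU_DEVICE_REG.match(cpu_name))              # None: only TPU devices match
print(_DEVICE_TYPE_REGEX.match(cpu_name).group(1))  # 'CPU': bare device type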
@@ -57,41 +59,54 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
   devices = []
   device_dict = collections.defaultdict(list)
 
-  # TODO(b/120564445): Replace with standard library for retries.
-  retry_count = 1
-  while True:
-    logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
-                 master_address)
-    try:
-      with ops.Graph().as_default():
-        with session_lib.Session(
-            master_address,
-            config=get_session_config_with_timeout(
-                _PINGING_MASTER_TIMEOUT_IN_MS,
-                cluster_def)) as sess:
-          devices = sess.list_devices()
-          for device in devices:
-            match = _TPU_DEVICE_REG.match(device.name)
-            if match:
-              host_id = match.group(1)
-              core_id = match.group(2)
-              device_dict[host_id].append(core_id)
-              tpu_core_count += 1
-          break
-    except errors.DeadlineExceededError:
-      msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
-             'not be ready (still scheduling) or the Tensorflow master address '
-             'is incorrect: got (%s).' %
-             (master_address))
+  if context.executing_eagerly():
+    device_names = context.list_devices()
+    devices = []
 
-      # TODO(xiejw): For local or grpc master we might not need retry logic
-      # here.
-      if retry_count <= _RETRY_TIMES:
-        logging.warning('%s', msg)
-        logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
-        retry_count += 1
-      else:
-        raise ValueError(msg)
+    # We want the output type to match in both eager and session mode
+    for name in device_names:
+      device_match = _DEVICE_TYPE_REGEX.match(name)
+      device_type = 'CPU'
+      if device_match:
+        device_type = device_match.group(1)
+      devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
+  else:
+    # TODO(b/120564445): Replace with standard library for retries.
+    retry_count = 1
+    while True:
+      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
+                   master_address)
+      try:
+        with ops.Graph().as_default():
+          with session_lib.Session(
+              master_address,
+              config=get_session_config_with_timeout(
+                  _PINGING_MASTER_TIMEOUT_IN_MS,
+                  cluster_def)) as sess:
+            devices = sess.list_devices()
+            break
+      except errors.DeadlineExceededError:
+        msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
+               'not be ready (still scheduling) or the Tensorflow master '
+               'address is incorrect: got (%s).' %
+               (master_address))
 
+        # TODO(xiejw): For local or grpc master we might not need retry logic
+        # here.
+        if retry_count <= _RETRY_TIMES:
+          logging.warning('%s', msg)
+          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
+          retry_count += 1
+        else:
+          raise ValueError(msg)
+
+  for device in devices:
+    match = _TPU_DEVICE_REG.match(device.name)
+    if match:
+      host_id = match.group(1)
+      core_id = match.group(2)
+      device_dict[host_id].append(core_id)
+      tpu_core_count += 1
+
   num_of_cores_per_host = 0
   if tpu_core_count:
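As a sanity check on the new code path, here is a small self-contained sketch (mine, not from the commit) of how the eager branch's output feeds the shared per-device loop: bare device names are wrapped in an attributes record so that both the eager path and the session path hand the same shape of object to the grouping code. A namedtuple stands in for session_lib._DeviceAttributes, and the device names are hypothetical.

import collections
import re

# Stand-in for session_lib._DeviceAttributes; only .name is read by the loop below.
DeviceAttributes = collections.namedtuple(
    'DeviceAttributes',
    ['name', 'device_type', 'memory_limit_bytes', 'incarnation'])

_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')

# Hypothetical output of context.list_devices() on a single-host TPU worker.
device_names = [
    '/job:tpu_worker/replica:0/task:0/device:CPU:0',
    '/job:tpu_worker/replica:0/task:0/device:TPU:0',
    '/job:tpu_worker/replica:0/task:0/device:TPU:1',
]

# Mirror of the eager branch: wrap bare names in attribute records so both
# code paths hand the same kind of object to the shared grouping loop.
devices = []
for name in device_names:
  device_match = _DEVICE_TYPE_REGEX.match(name)
  device_type = device_match.group(1) if device_match else 'CPU'
  devices.append(DeviceAttributes(name, device_type, 0, 0))

# Mirror of the shared grouping loop: TPU cores bucketed by host (task) id.
tpu_core_count = 0
device_dict = collections.defaultdict(list)
for device in devices:
  match = _TPU_DEVICE_REG.match(device.name)
  if match:
    device_dict[match.group(1)].append(match.group(2))
    tpu_core_count += 1

print(dict(device_dict))  # {'0': ['0', '1']}
print(tpu_core_count)     # 2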