Change tpu_system_metadata to work in 2.0 mode
PiperOrigin-RevId: 238311803
This commit is contained in:
parent
38d9711a73
commit
3e377e72fb
@ -23,6 +23,7 @@ import re
|
|||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import device as tf_device
|
from tensorflow.python.framework import device as tf_device
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -34,6 +35,7 @@ _RETRY_TIMES = 120
|
|||||||
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
|
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
|
||||||
|
|
||||||
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
|
_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
|
||||||
|
_DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')
|
||||||
|
|
||||||
_DEFAULT_JOB_NAME = 'tpu_worker'
|
_DEFAULT_JOB_NAME = 'tpu_worker'
|
||||||
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
|
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
|
||||||
@ -57,6 +59,18 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
|
|||||||
devices = []
|
devices = []
|
||||||
device_dict = collections.defaultdict(list)
|
device_dict = collections.defaultdict(list)
|
||||||
|
|
||||||
|
if context.executing_eagerly():
|
||||||
|
device_names = context.list_devices()
|
||||||
|
devices = []
|
||||||
|
|
||||||
|
# We want the output type to match in both eager and session mode
|
||||||
|
for name in device_names:
|
||||||
|
device_match = _DEVICE_TYPE_REGEX.match(name)
|
||||||
|
device_type = 'CPU'
|
||||||
|
if device_match:
|
||||||
|
device_type = device_match.group(1)
|
||||||
|
devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0)) # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
# TODO(b/120564445): Replace with standard library for retries.
|
# TODO(b/120564445): Replace with standard library for retries.
|
||||||
retry_count = 1
|
retry_count = 1
|
||||||
while True:
|
while True:
|
||||||
@ -70,18 +84,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
|
|||||||
_PINGING_MASTER_TIMEOUT_IN_MS,
|
_PINGING_MASTER_TIMEOUT_IN_MS,
|
||||||
cluster_def)) as sess:
|
cluster_def)) as sess:
|
||||||
devices = sess.list_devices()
|
devices = sess.list_devices()
|
||||||
for device in devices:
|
|
||||||
match = _TPU_DEVICE_REG.match(device.name)
|
|
||||||
if match:
|
|
||||||
host_id = match.group(1)
|
|
||||||
core_id = match.group(2)
|
|
||||||
device_dict[host_id].append(core_id)
|
|
||||||
tpu_core_count += 1
|
|
||||||
break
|
break
|
||||||
except errors.DeadlineExceededError:
|
except errors.DeadlineExceededError:
|
||||||
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
|
msg = ('Failed to connect to the Tensorflow master. The TPU worker may '
|
||||||
'not be ready (still scheduling) or the Tensorflow master address '
|
'not be ready (still scheduling) or the Tensorflow master '
|
||||||
'is incorrect: got (%s).' %
|
'address is incorrect: got (%s).' %
|
||||||
(master_address))
|
(master_address))
|
||||||
|
|
||||||
# TODO(xiejw): For local or grpc master we might not need retry logic
|
# TODO(xiejw): For local or grpc master we might not need retry logic
|
||||||
@ -93,6 +100,14 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
|
|||||||
else:
|
else:
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
for device in devices:
|
||||||
|
match = _TPU_DEVICE_REG.match(device.name)
|
||||||
|
if match:
|
||||||
|
host_id = match.group(1)
|
||||||
|
core_id = match.group(2)
|
||||||
|
device_dict[host_id].append(core_id)
|
||||||
|
tpu_core_count += 1
|
||||||
|
|
||||||
num_of_cores_per_host = 0
|
num_of_cores_per_host = 0
|
||||||
if tpu_core_count:
|
if tpu_core_count:
|
||||||
num_cores_per_host_set = set(
|
num_cores_per_host_set = set(
|
||||||
|
Loading…
Reference in New Issue
Block a user