diff --git a/tensorflow/python/distribute/device_util.py b/tensorflow/python/distribute/device_util.py
index d1295f27019..db6009d1a45 100644
--- a/tensorflow/python/distribute/device_util.py
+++ b/tensorflow/python/distribute/device_util.py
@@ -38,13 +38,17 @@ def canonicalize(d, default=None):
   Note: This uses "job:localhost" as the default if executing eagerly.
 
   Args:
-    d: a device string.
+    d: a device string or tf.config.LogicalDevice
     default: a string for default device if d doesn't have all components.
 
   Returns:
     a canonicalized device string.
   """
-  d = tf_device.DeviceSpec.from_string(d)
+  if isinstance(d, context.LogicalDevice):
+    d = tf_device.DeviceSpec.from_string(d.name)
+  else:
+    d = tf_device.DeviceSpec.from_string(d)
+
   assert d.device_type is None or d.device_type == d.device_type.upper(), (
       "Device type '%s' must be all-caps." % (d.device_type,))
   # Fill in missing device fields using defaults.
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 38f69dba74e..0fb8ae0aafb 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -38,6 +38,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
@@ -206,7 +207,8 @@ def _is_device_list_single_worker(devices):
   """Checks whether the devices list is for single or multi-worker.
 
   Args:
-    devices: a list of device strings, either local or for remote devices.
+    devices: a list of device strings or tf.config.LogicalDevice objects, for
+      either local or for remote devices.
 
   Returns:
     a boolean indicating whether these device strings are for local or for
@@ -215,7 +217,10 @@ def _is_device_list_single_worker(devices):
   Raises:
     ValueError: if device strings are not consistent.
   """
-  specs = (tf_device.DeviceSpec.from_string(d) for d in devices)
+  specs = []
+  for d in devices:
+    name = d.name if isinstance(d, context.LogicalDevice) else d
+    specs.append(tf_device.DeviceSpec.from_string(name))
   num_workers = len({(d.job, d.task, d.replica) for d in specs})
   all_local = all(d.job in (None, "localhost") for d in specs)
   any_local = any(d.job in (None, "localhost") for d in specs)
@@ -321,9 +326,10 @@ def _infer_num_gpus_per_worker(devices):
 
 
 def all_local_devices(num_gpus=None):
-  if num_gpus is None:
-    num_gpus = context.num_gpus()
-  return device_util.local_devices_from_num_gpus(num_gpus)
+  devices = config.list_logical_devices("GPU")
+  if num_gpus is not None:
+    devices = devices[:num_gpus]
+  return devices or config.list_logical_devices("CPU")
 
 
 def all_devices():
diff --git a/tensorflow/python/distribute/remote_mirrored_strategy_eager_test.py b/tensorflow/python/distribute/remote_mirrored_strategy_eager_test.py
index 1389ec393e0..36ec919a575 100644
--- a/tensorflow/python/distribute/remote_mirrored_strategy_eager_test.py
+++ b/tensorflow/python/distribute/remote_mirrored_strategy_eager_test.py
@@ -30,7 +30,7 @@ def get_gpus():
   gpus = context.context().list_logical_devices("GPU")
   actual_gpus = []
   for gpu in gpus:
-    if "localhost" not in gpu.name:
+    if "job" in gpu.name:
       actual_gpus.append(gpu.name)
   return actual_gpus
 
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 2757cc4667d..dbcdd4a83d6 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -471,8 +471,12 @@ class Context(object):
         dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
         self._context_devices.append(pydev.canonical_name(dev_name))
         spec = pydev.DeviceSpec.from_string(dev_name)
+        # If the job is localhost, we assume that the cluster has not yet been
+        # configured and thus clear the job, replica & task.
+        if spec.job == "localhost":
+          spec = spec.replace(job=None, replica=None, task=None)
         self._logical_devices.append(
-            LogicalDevice(name=dev_name, device_type=spec.device_type))
+            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
         dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
         if dev_type == "GPU":
           self._num_gpus += 1
diff --git a/tensorflow/python/tpu/tpu_system_metadata.py b/tensorflow/python/tpu/tpu_system_metadata.py
index 8628feee418..1998e0e0aeb 100644
--- a/tensorflow/python/tpu/tpu_system_metadata.py
+++ b/tensorflow/python/tpu/tpu_system_metadata.py
@@ -19,10 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import re
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import device_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import device as tf_device
@@ -35,8 +35,6 @@
 _PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 10 min
 _RETRY_TIMES = 12 * 24  # 1 day
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
-_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
-
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
 _LOCAL_MASTERS = ('', 'local')
@@ -61,11 +59,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
 
   if context.executing_eagerly():
     logical_devices = config.list_logical_devices()
-    devices = []
 
     # We want the output type to match in both eager and session mode
-    for d in logical_devices:
-      devices.append(session_lib._DeviceAttributes(d.name, d.device_type, 0, 0))  # pylint: disable=protected-access
+    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
+                                             d.device_type, 0, 0)
+               for d in logical_devices]
   else:
     # TODO(b/120564445): Replace with standard library for retries.
     retry_count = 1
@@ -97,11 +95,9 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
       raise ValueError(msg)
 
   for device in devices:
-    match = _TPU_DEVICE_REG.match(device.name)
-    if match:
-      host_id = match.group(1)
-      core_id = match.group(2)
-      device_dict[host_id].append(core_id)
+    spec = tf_device.DeviceSpec.from_string(device.name)
+    if spec.device_type == 'TPU':
+      device_dict[spec.task].append(spec.device_index)
       tpu_core_count += 1
 
   num_of_cores_per_host = 0
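A minimal usage sketch (not part of the patch) of the behavior these changes aim for; it assumes eager execution on a single host, and the "no GPUs present" scenario is only an assumption used to illustrate the CPU fallback in all_local_devices():

# Illustrative sketch only; assumes eager execution and, for the fallback
# branch, a host with no GPUs. Device names in the comments are examples.
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import mirrored_strategy

# all_local_devices() now returns tf.config.LogicalDevice objects instead of
# plain strings, falling back to the logical CPU devices when no GPUs exist.
devices = mirrored_strategy.all_local_devices()

for d in devices:
  # canonicalize() accepts either a LogicalDevice or a plain device string.
  print(device_util.canonicalize(d))       # e.g. /job:localhost/replica:0/task:0/device:CPU:0
  print(device_util.canonicalize(d.name))  # same result from the raw name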