Support LogicalDevice in MirroredStrategy config
PiperOrigin-RevId: 280290757
Change-Id: I52dfff634e6e0ccdc81cd5cce682d7df3499b618
commit 36eed9c112
parent a6163f12ab
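For orientation, a minimal sketch of the usage this commit enables: the LogicalDevice objects returned by tf.config.list_logical_devices can be passed to MirroredStrategy directly, instead of device name strings. The strategy call below is inferred from the commit title and the _is_device_list_single_worker change, not copied from this diff:

    import tensorflow as tf

    # LogicalDevice objects, e.g. LogicalDevice(name='/device:GPU:0', ...).
    gpus = tf.config.list_logical_devices("GPU")
    strategy = tf.distribute.MirroredStrategy(devices=gpus)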
@@ -38,13 +38,17 @@ def canonicalize(d, default=None):
   Note: This uses "job:localhost" as the default if executing eagerly.
 
   Args:
-    d: a device string.
+    d: a device string or tf.config.LogicalDevice
     default: a string for default device if d doesn't have all components.
 
   Returns:
     a canonicalized device string.
   """
-  d = tf_device.DeviceSpec.from_string(d)
+  if isinstance(d, context.LogicalDevice):
+    d = tf_device.DeviceSpec.from_string(d.name)
+  else:
+    d = tf_device.DeviceSpec.from_string(d)
+
   assert d.device_type is None or d.device_type == d.device_type.upper(), (
       "Device type '%s' must be all-caps." % (d.device_type,))
   # Fill in missing device fields using defaults.
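A hedged illustration of the new branch in canonicalize: both plain strings and LogicalDevice objects are accepted, and the latter are unwrapped to their .name before parsing (the calls below are illustrative, not from the commit's tests):

    from tensorflow.python.distribute import device_util
    from tensorflow.python.eager import context

    device_util.canonicalize("/device:CPU:0")      # string, as before
    cpu = context.context().list_logical_devices("CPU")[0]
    device_util.canonicalize(cpu)                  # LogicalDevice, now accepted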
@@ -38,6 +38,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolve
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
@@ -206,7 +207,8 @@ def _is_device_list_single_worker(devices):
   """Checks whether the devices list is for single or multi-worker.
 
   Args:
-    devices: a list of device strings, either local or for remote devices.
+    devices: a list of device strings or tf.config.LogicalDevice objects,
+      either local or remote.
 
   Returns:
     a boolean indicating whether these device strings are for local or for
@@ -215,7 +217,10 @@ def _is_device_list_single_worker(devices):
   Raises:
     ValueError: if device strings are not consistent.
   """
-  specs = (tf_device.DeviceSpec.from_string(d) for d in devices)
+  specs = []
+  for d in devices:
+    name = d.name if isinstance(d, context.LogicalDevice) else d
+    specs.append(tf_device.DeviceSpec.from_string(name))
   num_workers = len({(d.job, d.task, d.replica) for d in specs})
   all_local = all(d.job in (None, "localhost") for d in specs)
   any_local = any(d.job in (None, "localhost") for d in specs)
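The added loop is a normalization step; an equivalent comprehension-style sketch, using the same names as the surrounding code:

    names = [d.name if isinstance(d, context.LogicalDevice) else d
             for d in devices]
    specs = [tf_device.DeviceSpec.from_string(name) for name in names]

Note that replacing the old generator expression with a list matters on its own: specs is iterated three times below, and a generator would be exhausted after the first pass.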
@@ -321,9 +326,10 @@ def _infer_num_gpus_per_worker(devices):
 
 
 def all_local_devices(num_gpus=None):
-  if num_gpus is None:
-    num_gpus = context.num_gpus()
-  return device_util.local_devices_from_num_gpus(num_gpus)
+  devices = config.list_logical_devices("GPU")
+  if num_gpus is not None:
+    devices = devices[:num_gpus]
+  return devices or config.list_logical_devices("CPU")
 
 
 def all_devices():
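all_local_devices now returns LogicalDevice objects from the framework config rather than synthesizing device strings from a GPU count; a sketch of the new contract (example names assumed):

    devices = all_local_devices(num_gpus=2)
    # At most two GPU LogicalDevices, e.g. LogicalDevice(name='/device:GPU:0',
    # device_type='GPU'); falls back to the CPU logical devices when no GPU
    # is configured.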
@@ -30,7 +30,7 @@ def get_gpus():
   gpus = context.context().list_logical_devices("GPU")
   actual_gpus = []
   for gpu in gpus:
-    if "localhost" not in gpu.name:
+    if "job" in gpu.name:
       actual_gpus.append(gpu.name)
   return actual_gpus
@@ -471,8 +471,12 @@ class Context(object):
         dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
         self._context_devices.append(pydev.canonical_name(dev_name))
         spec = pydev.DeviceSpec.from_string(dev_name)
+        # If the job is localhost, we assume that the cluster has not yet been
+        # configured and thus clear the job, replica & task.
+        if spec.job == "localhost":
+          spec = spec.replace(job=None, replica=None, task=None)
         self._logical_devices.append(
-            LogicalDevice(name=dev_name, device_type=spec.device_type))
+            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
         dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
         if dev_type == "GPU":
           self._num_gpus += 1
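The effect of the renaming, as a standalone sketch (the input name is an assumed example):

    from tensorflow.python.framework import device as pydev

    spec = pydev.DeviceSpec.from_string(
        "/job:localhost/replica:0/task:0/device:GPU:0")
    spec = spec.replace(job=None, replica=None, task=None)
    assert spec.to_string() == "/device:GPU:0"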
@@ -19,10 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import re
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import device_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import device as tf_device
@@ -35,8 +35,6 @@ _PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 10 min
 _RETRY_TIMES = 12 * 24  # 1 day
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
 
-_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
-
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
 _LOCAL_MASTERS = ('', 'local')
@@ -61,11 +59,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
 
   if context.executing_eagerly():
     logical_devices = config.list_logical_devices()
-    devices = []
 
     # We want the output type to match in both eager and session mode
-    for d in logical_devices:
-      devices.append(session_lib._DeviceAttributes(d.name, d.device_type, 0, 0))  # pylint: disable=protected-access
+    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
+                                             d.device_type, 0, 0)
+               for d in logical_devices]
   else:
     # TODO(b/120564445): Replace with standard library for retries.
     retry_count = 1
@@ -97,11 +95,9 @@
       raise ValueError(msg)
 
   for device in devices:
-    match = _TPU_DEVICE_REG.match(device.name)
-    if match:
-      host_id = match.group(1)
-      core_id = match.group(2)
-      device_dict[host_id].append(core_id)
+    spec = tf_device.DeviceSpec.from_string(device.name)
+    if spec.device_type == 'TPU':
+      device_dict[spec.task].append(spec.device_index)
       tpu_core_count += 1
 
   num_of_cores_per_host = 0
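The ad-hoc regex is replaced by structured parsing; a sketch of what DeviceSpec extracts from a TPU device name (the name below is an assumed example):

    from tensorflow.python.framework import device as tf_device

    spec = tf_device.DeviceSpec.from_string(
        '/job:tpu_worker/replica:0/task:3/device:TPU:5')
    assert spec.device_type == 'TPU'
    assert spec.task == 3 and spec.device_index == 5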