Support LogicalDevice in MirroredStrategy config
PiperOrigin-RevId: 280290757 Change-Id: I52dfff634e6e0ccdc81cd5cce682d7df3499b618
parent a6163f12ab
commit 36eed9c112
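
For orientation, a hedged usage sketch of what this change enables (inferred from the commit title; the snippet itself is not part of the diff): MirroredStrategy can be configured with tf.config.LogicalDevice objects instead of only device strings.

# Hedged usage sketch (inferred from the commit title, not shown in the diff):
# pass LogicalDevice objects to MirroredStrategy instead of device strings.
import tensorflow as tf

gpus = tf.config.list_logical_devices("GPU")       # may be empty on CPU-only hosts
devices = gpus or tf.config.list_logical_devices("CPU")
strategy = tf.distribute.MirroredStrategy(devices=devices)
print("replicas in sync:", strategy.num_replicas_in_sync)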
@@ -38,13 +38,17 @@ def canonicalize(d, default=None):
   Note: This uses "job:localhost" as the default if executing eagerly.
 
   Args:
-    d: a device string.
+    d: a device string or tf.config.LogicalDevice
     default: a string for default device if d doesn't have all components.
 
   Returns:
     a canonicalized device string.
   """
-  d = tf_device.DeviceSpec.from_string(d)
+  if isinstance(d, context.LogicalDevice):
+    d = tf_device.DeviceSpec.from_string(d.name)
+  else:
+    d = tf_device.DeviceSpec.from_string(d)
+
   assert d.device_type is None or d.device_type == d.device_type.upper(), (
       "Device type '%s' must be all-caps." % (d.device_type,))
   # Fill in missing device fields using defaults.
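
A minimal sketch of the branch added above, assuming only the public tf.config.LogicalDevice namedtuple and the DeviceSpec helper that the diff already uses; the device names are illustrative.

# Minimal sketch of the new canonicalize() branch; device names are illustrative.
import tensorflow as tf
from tensorflow.python.framework import device as tf_device

def to_spec(d):
  # Accept either a device string or a LogicalDevice, as the diff above does.
  if isinstance(d, tf.config.LogicalDevice):
    return tf_device.DeviceSpec.from_string(d.name)
  return tf_device.DeviceSpec.from_string(d)

print(to_spec("/job:worker/task:1/device:GPU:0").job)                        # worker
print(to_spec(tf.config.LogicalDevice("/device:CPU:0", "CPU")).device_type)  # CPU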
@@ -38,6 +38,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
@@ -206,7 +207,8 @@ def _is_device_list_single_worker(devices):
   """Checks whether the devices list is for single or multi-worker.
 
   Args:
-    devices: a list of device strings, either local or for remote devices.
+    devices: a list of device strings or tf.config.LogicalDevice objects, for
+      either local or for remote devices.
 
   Returns:
     a boolean indicating whether these device strings are for local or for
@@ -215,7 +217,10 @@ def _is_device_list_single_worker(devices):
   Raises:
     ValueError: if device strings are not consistent.
   """
-  specs = (tf_device.DeviceSpec.from_string(d) for d in devices)
+  specs = []
+  for d in devices:
+    name = d.name if isinstance(d, context.LogicalDevice) else d
+    specs.append(tf_device.DeviceSpec.from_string(name))
   num_workers = len({(d.job, d.task, d.replica) for d in specs})
   all_local = all(d.job in (None, "localhost") for d in specs)
   any_local = any(d.job in (None, "localhost") for d in specs)
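
A hedged, stand-alone illustration of the normalization introduced above: strings and LogicalDevice objects are reduced to DeviceSpecs before the single-vs-multi-worker check. The helper below is a sketch, not the library function itself.

# Stand-alone sketch of the normalization above, not the library function itself.
import tensorflow as tf
from tensorflow.python.framework import device as tf_device

def is_single_worker(devices):
  specs = []
  for d in devices:
    name = d.name if isinstance(d, tf.config.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  # A single (job, task, replica) combination means a single worker.
  return len({(s.job, s.task, s.replica) for s in specs}) == 1

print(is_single_worker(["/gpu:0", tf.config.LogicalDevice("/device:GPU:1", "GPU")]))  # True
print(is_single_worker(["/job:worker/task:0/gpu:0", "/job:worker/task:1/gpu:0"]))     # False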
@@ -321,9 +326,10 @@ def _infer_num_gpus_per_worker(devices):
 
 
 def all_local_devices(num_gpus=None):
-  if num_gpus is None:
-    num_gpus = context.num_gpus()
-  return device_util.local_devices_from_num_gpus(num_gpus)
+  devices = config.list_logical_devices("GPU")
+  if num_gpus is not None:
+    devices = devices[:num_gpus]
+  return devices or config.list_logical_devices("CPU")
 
 
 def all_devices():
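
A sketch of the new all_local_devices() behavior via the public API, which is assumed to mirror the internal config module used in the diff: prefer logical GPUs, optionally truncate to num_gpus, and fall back to logical CPUs.

# Sketch of the new behavior via the public API, which is assumed to mirror
# the internal config module used in the diff.
import tensorflow as tf

def local_devices(num_gpus=None):
  devices = tf.config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  # Fall back to logical CPUs when no GPUs are visible.
  return devices or tf.config.list_logical_devices("CPU")

print([d.name for d in local_devices(num_gpus=1)])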
@@ -30,7 +30,7 @@ def get_gpus():
   gpus = context.context().list_logical_devices("GPU")
   actual_gpus = []
   for gpu in gpus:
-    if "localhost" not in gpu.name:
+    if "job" in gpu.name:
       actual_gpus.append(gpu.name)
   return actual_gpus
 
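
A brief, hedged note on why the filter flips: after the Context change below, local logical device names no longer carry a /job:localhost prefix, so remote devices are the ones whose names still contain a job component. The name shapes below are illustrative assumptions, not taken from the diff.

# Illustrative name shapes only (assumptions): local devices drop the job
# prefix, so "job" in the name now marks a remote device.
names = [
    "/device:GPU:0",                              # local
    "/job:worker/replica:0/task:1/device:GPU:0",  # remote
]
print([n for n in names if "job" in n])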
@@ -471,8 +471,12 @@ class Context(object):
         dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
         self._context_devices.append(pydev.canonical_name(dev_name))
         spec = pydev.DeviceSpec.from_string(dev_name)
+        # If the job is localhost, we assume that the cluster has not yet been
+        # configured and thus clear the job, replica & task.
+        if spec.job == "localhost":
+          spec = spec.replace(job=None, replica=None, task=None)
         self._logical_devices.append(
-            LogicalDevice(name=dev_name, device_type=spec.device_type))
+            LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
         dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
         if dev_type == "GPU":
           self._num_gpus += 1
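
A small sketch of the localhost stripping added above; DeviceSpec.replace is the same call used in the diff, and the printed form is an assumption about to_string() formatting.

# Sketch of the localhost stripping above; the expected output is an
# assumption about DeviceSpec.to_string() formatting.
from tensorflow.python.framework import device as pydev

spec = pydev.DeviceSpec.from_string("/job:localhost/replica:0/task:0/device:GPU:0")
if spec.job == "localhost":
  spec = spec.replace(job=None, replica=None, task=None)
print(spec.to_string())  # expected: /device:GPU:0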
@@ -19,10 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import re
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
+from tensorflow.python.distribute import device_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import config
 from tensorflow.python.framework import device as tf_device
@@ -35,8 +35,6 @@ _PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000  # 10 min
 _RETRY_TIMES = 12 * 24  # 1 day
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
 
-_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
-
 _DEFAULT_JOB_NAME = 'tpu_worker'
 _DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
 _LOCAL_MASTERS = ('', 'local')
@@ -61,11 +59,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
 
   if context.executing_eagerly():
     logical_devices = config.list_logical_devices()
-    devices = []
 
     # We want the output type to match in both eager and session mode
-    for d in logical_devices:
-      devices.append(session_lib._DeviceAttributes(d.name, d.device_type, 0, 0))  # pylint: disable=protected-access
+    devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name),  # pylint: disable=protected-access
+                                             d.device_type, 0, 0)
+               for d in logical_devices]
   else:
     # TODO(b/120564445): Replace with standard library for retries.
     retry_count = 1
@@ -97,11 +95,9 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
          raise ValueError(msg)
 
   for device in devices:
-    match = _TPU_DEVICE_REG.match(device.name)
-    if match:
-      host_id = match.group(1)
-      core_id = match.group(2)
-      device_dict[host_id].append(core_id)
+    spec = tf_device.DeviceSpec.from_string(device.name)
+    if spec.device_type == 'TPU':
+      device_dict[spec.task].append(spec.device_index)
       tpu_core_count += 1
 
   num_of_cores_per_host = 0
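
A self-contained, hedged sketch of the regex-free parsing above: DeviceSpec exposes the task and device index directly, so the removed _TPU_DEVICE_REG is no longer needed. The device names are illustrative.

# Self-contained sketch of the regex-free parsing above; device names are
# illustrative, and device_dict mirrors the diff's per-task core lists.
import collections
from tensorflow.python.framework import device as tf_device

device_names = [
    "/job:tpu_worker/replica:0/task:0/device:TPU:0",
    "/job:tpu_worker/replica:0/task:0/device:TPU:1",
    "/job:tpu_worker/replica:0/task:1/device:TPU:0",
]

device_dict = collections.defaultdict(list)
tpu_core_count = 0
for name in device_names:
  spec = tf_device.DeviceSpec.from_string(name)
  if spec.device_type == "TPU":
    device_dict[spec.task].append(spec.device_index)
    tpu_core_count += 1

print(dict(device_dict))  # {0: [0, 1], 1: [0]}
print(tpu_core_count)     # 3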