Support LogicalDevice in MirroredStrategy config

PiperOrigin-RevId: 280290757
Change-Id: I52dfff634e6e0ccdc81cd5cce682d7df3499b618
Gaurav Jain 2019-11-13 15:14:46 -08:00 committed by TensorFlower Gardener
parent a6163f12ab
commit 36eed9c112
5 changed files with 30 additions and 20 deletions
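For reference, a minimal usage sketch (not part of the diff) of what this change enables, assuming a TF 2.1-era build where tf.config.list_logical_devices and tf.config.LogicalDevice are available and at least one GPU is visible: the LogicalDevice objects can be passed straight to tf.distribute.MirroredStrategy instead of being converted to name strings first.

import tensorflow as tf

# Illustrative only: logical devices can now be handed to MirroredStrategy
# directly; previously callers typically passed [g.name for g in gpus].
gpus = tf.config.list_logical_devices("GPU")  # assumes at least one GPU
strategy = tf.distribute.MirroredStrategy(devices=gpus)
print(strategy.num_replicas_in_sync)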


@@ -38,13 +38,17 @@ def canonicalize(d, default=None):
Note: This uses "job:localhost" as the default if executing eagerly.
Args:
-d: a device string.
+d: a device string or tf.config.LogicalDevice
default: a string for default device if d doesn't have all components.
Returns:
a canonicalized device string.
"""
-d = tf_device.DeviceSpec.from_string(d)
+if isinstance(d, context.LogicalDevice):
+d = tf_device.DeviceSpec.from_string(d.name)
+else:
+d = tf_device.DeviceSpec.from_string(d)
assert d.device_type is None or d.device_type == d.device_type.upper(), (
"Device type '%s' must be all-caps." % (d.device_type,))
# Fill in missing device fields using defaults.
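A small sketch of the dispatch this hunk adds to canonicalize, written against the public tf.DeviceSpec and tf.config.LogicalDevice APIs rather than the internal device_util module; the helper name _to_device_spec is made up for illustration.

import tensorflow as tf

def _to_device_spec(d):
  # A LogicalDevice carries its device string in .name, so it is unwrapped
  # before parsing; plain strings are parsed as before.
  if isinstance(d, tf.config.LogicalDevice):
    return tf.DeviceSpec.from_string(d.name)
  return tf.DeviceSpec.from_string(d)

print(_to_device_spec("/job:localhost/device:GPU:0").to_string())
cpu = tf.config.list_logical_devices("CPU")[0]
print(_to_device_spec(cpu).device_type)  # "CPU"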


@@ -38,6 +38,7 @@ from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import tape
+from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
@@ -206,7 +207,8 @@ def _is_device_list_single_worker(devices):
"""Checks whether the devices list is for single or multi-worker.
Args:
-devices: a list of device strings, either local or for remote devices.
+devices: a list of device strings or tf.config.LogicalDevice objects, for
+either local or remote devices.
Returns:
a boolean indicating whether these device strings are for local or for
@@ -215,7 +217,10 @@ def _is_device_list_single_worker(devices):
Raises:
ValueError: if device strings are not consistent.
"""
-specs = (tf_device.DeviceSpec.from_string(d) for d in devices)
+specs = []
+for d in devices:
+name = d.name if isinstance(d, context.LogicalDevice) else d
+specs.append(tf_device.DeviceSpec.from_string(name))
num_workers = len({(d.job, d.task, d.replica) for d in specs})
all_local = all(d.job in (None, "localhost") for d in specs)
any_local = any(d.job in (None, "localhost") for d in specs)
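A simplified sketch of the grouping logic above (the helper name is hypothetical; the real _is_device_list_single_worker also validates the local/remote mix and raises ValueError on inconsistent lists). Device strings and tf.config.LogicalDevice objects can now be mixed in one list.

import tensorflow as tf

def _sketch_is_single_worker(devices):
  # Reduce LogicalDevice entries to their name string, then group the parsed
  # specs by (job, task, replica); one group means a single worker.
  specs = []
  for d in devices:
    name = d.name if isinstance(d, tf.config.LogicalDevice) else d
    specs.append(tf.DeviceSpec.from_string(name))
  return len({(s.job, s.task, s.replica) for s in specs}) == 1

local = tf.config.list_logical_devices("CPU") + ["/device:GPU:0"]
print(_sketch_is_single_worker(local))   # True: everything is on one worker
remote = ["/job:worker/task:0/device:GPU:0", "/job:worker/task:1/device:GPU:0"]
print(_sketch_is_single_worker(remote))  # False: two distinct tasks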
@@ -321,9 +326,10 @@ def _infer_num_gpus_per_worker(devices):
def all_local_devices(num_gpus=None):
-if num_gpus is None:
-num_gpus = context.num_gpus()
-return device_util.local_devices_from_num_gpus(num_gpus)
+devices = config.list_logical_devices("GPU")
+if num_gpus is not None:
+devices = devices[:num_gpus]
+return devices or config.list_logical_devices("CPU")
def all_devices():
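The rewritten all_local_devices above now returns tf.config.LogicalDevice objects rather than device strings: the logical GPUs if any are visible, optionally truncated to num_gpus, otherwise the logical CPU devices. A public-API sketch of the same behavior (the helper name is made up):

import tensorflow as tf

def _sketch_all_local_devices(num_gpus=None):
  devices = tf.config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]   # keep at most num_gpus GPUs
  return devices or tf.config.list_logical_devices("CPU")

print([d.name for d in _sketch_all_local_devices()])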


@@ -30,7 +30,7 @@ def get_gpus():
gpus = context.context().list_logical_devices("GPU")
actual_gpus = []
for gpu in gpus:
if "localhost" not in gpu.name:
if "job" in gpu.name:
actual_gpus.append(gpu.name)
return actual_gpus
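Why the filter changes here: with the context.py change in the next hunk, purely local logical devices no longer carry a /job:localhost/... prefix, so the presence of a job component is what now distinguishes cluster-owned GPUs. An equivalent check using tf.DeviceSpec instead of a substring test (illustrative only):

import tensorflow as tf

remote_gpus = [
    g.name for g in tf.config.list_logical_devices("GPU")
    if tf.DeviceSpec.from_string(g.name).job is not None  # has an explicit job
]
print(remote_gpus)  # empty on a single local machine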


@@ -471,8 +471,12 @@ class Context(object):
dev_name = pywrap_tensorflow.TF_DeviceListName(device_list, i)
self._context_devices.append(pydev.canonical_name(dev_name))
spec = pydev.DeviceSpec.from_string(dev_name)
+# If the job is localhost, we assume that the cluster has not yet been
+# configured and thus clear the job, replica & task.
+if spec.job == "localhost":
+spec = spec.replace(job=None, replica=None, task=None)
self._logical_devices.append(
-LogicalDevice(name=dev_name, device_type=spec.device_type))
+LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
dev_type = pywrap_tensorflow.TF_DeviceListType(device_list, i)
if dev_type == "GPU":
self._num_gpus += 1
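The renaming applied above, shown with the public tf.DeviceSpec API: when the job is localhost the cluster is assumed to be unconfigured, so job, replica and task are stripped from the logical device name.

import tensorflow as tf

spec = tf.DeviceSpec.from_string("/job:localhost/replica:0/task:0/device:GPU:0")
if spec.job == "localhost":
  spec = spec.replace(job=None, replica=None, task=None)
print(spec.to_string())  # "/device:GPU:0"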


@@ -19,10 +19,10 @@ from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
@@ -35,8 +35,6 @@ _PINGING_MASTER_TIMEOUT_IN_MS = 5 * 60 * 1000 # 10 min
_RETRY_TIMES = 12 * 24 # 1 day
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
-_TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')
_DEFAULT_JOB_NAME = 'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
_LOCAL_MASTERS = ('', 'local')
@@ -61,11 +59,11 @@ def _query_tpu_system_metadata(master_address, cluster_def=None,
if context.executing_eagerly():
logical_devices = config.list_logical_devices()
-devices = []
-for d in logical_devices:
-devices.append(session_lib._DeviceAttributes(d.name, d.device_type, 0, 0)) # pylint: disable=protected-access
+# We want the output type to match in both eager and session mode
+devices = [session_lib._DeviceAttributes(device_util.canonicalize(d.name), # pylint: disable=protected-access
+d.device_type, 0, 0)
+for d in logical_devices]
else:
# TODO(b/120564445): Replace with standard library for retries.
retry_count = 1
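The eager branch above canonicalizes each logical device name so that the eager result matches the fully qualified names a Session reports in graph mode. A rough public-API equivalent of that canonicalization (the helper and its localhost defaults are assumptions; the internal device_util.canonicalize fills defaults from the current context):

import tensorflow as tf

def _sketch_canonical_name(name, job="localhost", replica=0, task=0):
  # Merge the parsed name over a default spec; fields present in the name win.
  default = tf.DeviceSpec(job=job, replica=replica, task=task)
  return default.make_merged_spec(tf.DeviceSpec.from_string(name)).to_string()

print(_sketch_canonical_name("/device:CPU:0"))
# -> "/job:localhost/replica:0/task:0/device:CPU:0"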
@@ -97,11 +95,9 @@
raise ValueError(msg)
for device in devices:
-match = _TPU_DEVICE_REG.match(device.name)
-if match:
-host_id = match.group(1)
-core_id = match.group(2)
-device_dict[host_id].append(core_id)
+spec = tf_device.DeviceSpec.from_string(device.name)
+if spec.device_type == 'TPU':
+device_dict[spec.task].append(spec.device_index)
tpu_core_count += 1
num_of_cores_per_host = 0
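The regex-based parsing removed above is replaced by structured parsing with DeviceSpec; the task and core index come straight from attributes. A standalone illustration (the device name is an example):

import tensorflow as tf

spec = tf.DeviceSpec.from_string("/job:tpu_worker/replica:0/task:1/device:TPU:3")
if spec.device_type == "TPU":
  print(spec.task, spec.device_index)  # 1 3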