Add InputOption support to all remaining strategies.
PiperOrigin-RevId: 318160598
Change-Id: I99ec2aa41a6e4e5fb4720b878fd82ad10802dcc9
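
For orientation, a minimal usage sketch of the API surface this change adds (assuming TensorFlow with this change applied; the strategy and dataset below are illustrative, not taken from the diff):

import tensorflow as tf

# Any tf.distribute strategy; MirroredStrategy is just an example here.
strategy = tf.distribute.MirroredStrategy()

dataset = tf.data.Dataset.range(100).batch(strategy.num_replicas_in_sync)

# experimental_prefetch_to_device=False keeps dataset elements in host
# memory instead of prefetching them to the accelerator.
options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)

# The new `options` argument accepted by the remaining strategies.
dist_dataset = strategy.experimental_distribute_dataset(dataset, options=options)

for batch in dist_dataset:
  # Per-replica values; their placement follows the options above.
  print(strategy.experimental_local_results(batch))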
@@ -74,7 +74,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
   def _from_num_gpus(cls, num_gpus):
     return cls(device_util.local_devices_from_num_gpus(num_gpus))
 
-  def experimental_distribute_dataset(self, dataset):  # pylint: disable=useless-super-delegation
+  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
     """Distributes a tf.data.Dataset instance provided via dataset.
 
     The returned dataset is a wrapped strategy dataset which creates a
@@ -96,14 +96,17 @@ class CentralStorageStrategy(distribute_lib.Strategy):
     ```
     Args:
       dataset: `tf.data.Dataset` to be prefetched to device.
+      options: `tf.distribute.InputOptions` used to control options on how this
+        dataset is distributed.
 
     Returns:
       A "distributed `Dataset`" that the caller can iterate over.
     """
     return super(CentralStorageStrategy, self).experimental_distribute_dataset(
-        dataset)
+        dataset, options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn):  # pylint: disable=useless-super-delegation
+  def experimental_distribute_datasets_from_function(self, dataset_fn,  # pylint: disable=useless-super-delegation
+                                                     options=None):
     """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
 
     `dataset_fn` will be called once for each worker in the strategy. In this
@@ -136,6 +139,8 @@ class CentralStorageStrategy(distribute_lib.Strategy):
     Args:
       dataset_fn: A function taking a `tf.distribute.InputContext` instance and
         returning a `tf.data.Dataset`.
+      options: `tf.distribute.InputOptions` used to control options on how this
+        dataset is distributed.
 
     Returns:
       A "distributed `Dataset`", which the caller can iterate over like regular
@@ -143,7 +148,8 @@ class CentralStorageStrategy(distribute_lib.Strategy):
     """
     return super(
         CentralStorageStrategy,
-        self).experimental_distribute_datasets_from_function(dataset_fn)
+        self).experimental_distribute_datasets_from_function(dataset_fn,
+                                                             options)
 
   def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
     """Returns the list of all local per-replica values contained in `value`.
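
The docstring above describes `dataset_fn` receiving a `tf.distribute.InputContext`; a short sketch of that path with the new `options` argument (assuming this change is applied; the batch size and sharding choices below are illustrative):

import tensorflow as tf

strategy = tf.distribute.experimental.CentralStorageStrategy()

def dataset_fn(input_context):
  # Per-worker pipeline: shard by input pipeline id, use a per-replica batch.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=64)
  ds = tf.data.Dataset.range(1000).batch(batch_size)
  return ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)

options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)
dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn, options=options)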
@@ -358,9 +358,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     )
     super(CollectiveAllReduceExtended, self)._initialize_single_worker(
         local_devices)
-    host_device = device_util.get_host_for_device(self._worker_device)
-    self._input_workers = input_lib.InputWorkers(
-        [(host_device, self.worker_devices)])
 
     # Add a default device so that ops without specified devices will not end up
     # on other workers.
@@ -378,6 +375,20 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         task_id, self._num_workers, local_devices,
         self._communication)
 
+  def _input_workers_with_options(self, options=None):
+    host_device = device_util.get_host_for_device(self._worker_device)
+    if not options or options.experimental_prefetch_to_device:
+      return input_lib.InputWorkers([(host_device, self.worker_devices)])
+    else:
+      return input_lib.InputWorkers([(
+          host_device,
+          [device_util.get_host_for_device(worker) for worker in
+           self.worker_devices])])
+
+  @property
+  def _input_workers(self):
+    return self._input_workers_with_options()
+
   def _get_variable_creator_initial_value(self,
                                           replica_id,
                                           device,
@@ -441,7 +452,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     input_context = self._make_input_context()
     return input_lib.get_distributed_dataset(
         dataset,
-        self._input_workers,
+        self._input_workers_with_options(options),
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync,
         input_context=input_context)
@@ -451,7 +462,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
     input_context = self._make_input_context()
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn=dataset_fn,
-        input_workers=self._input_workers,
+        input_workers=self._input_workers_with_options(options),
        input_contexts=[input_context],
        strategy=self._container_strategy())
 
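
Each strategy touched by this change gains the same `_input_workers_with_options` pattern: keep the (host, compute devices) pairing when prefetching to device, otherwise point every compute slot back at the host. A standalone sketch of that selection logic (plain Python; `host_for_device` below is a simplified stand-in, not the real `device_util.get_host_for_device`):

def host_for_device(device):
  # Simplified stand-in: map any device string to its host CPU device, e.g.
  # "/job:worker/task:0/device:GPU:1" -> "/job:worker/task:0/device:CPU:0".
  return device.rsplit("/device:", 1)[0] + "/device:CPU:0"


def input_worker_devices(worker_devices, prefetch_to_device=True):
  # worker_devices: compute devices of one worker, e.g. its GPUs.
  host = host_for_device(worker_devices[0])
  if prefetch_to_device:
    # Iterator lives on the host; elements are prefetched to each replica.
    return [(host, tuple(worker_devices))]
  # Prefetch-to-host: every replica slot reads from host memory instead.
  return [(host, (host,) * len(worker_devices))]


if __name__ == "__main__":
  gpus = ["/job:worker/task:0/device:GPU:0", "/job:worker/task:0/device:GPU:1"]
  print(input_worker_devices(gpus, prefetch_to_device=True))
  print(input_worker_devices(gpus, prefetch_to_device=False))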
@@ -29,7 +29,9 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import reduce_util
@@ -37,6 +39,7 @@ from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
@@ -281,6 +284,53 @@ class DistributedCollectiveAllReduceStrategyTest(
     self.assertEqual(2 * num_workers,
                      distribution.num_replicas_in_sync)
 
+  @combinations.generate(combinations.combine(
+      mode=['graph'],
+      prefetch_to_device=[None, True]))
+  def test_prefetch_to_device_dataset(self, prefetch_to_device):
+    distribution, _, _ = self._get_test_object(
+        task_type='worker',
+        task_id=0,
+        num_gpus=2)
+    if prefetch_to_device is None:
+      input_options = None
+    else:
+      input_options = distribute_lib.InputOptions(
+          experimental_prefetch_to_device=prefetch_to_device)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if isinstance(dataset, input_lib.DistributedDatasetV1):
+      item = dataset.make_initializable_iterator().get_next()
+    else:
+      self.skipTest('unsupported test combination')
+    device_types = {
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values}
+    self.assertAllEqual(list(device_types), ['GPU'])
+
+  @combinations.generate(combinations.combine(mode=['graph']))
+  def test_prefetch_to_host_dataset(self):
+    distribution, _, _ = self._get_test_object(
+        task_type='worker',
+        task_id=0,
+        num_gpus=2)
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=False)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if isinstance(dataset, input_lib.DistributedDatasetV1):
+      item = dataset.make_initializable_iterator().get_next()
+    else:
+      self.skipTest('unsupported test combination')
+    device_types = {
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values}
+    self.assertAllEqual(list(device_types), ['CPU'])
+
   @combinations.generate(
       combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
   def testMinimizeLossGraph(self, required_gpus):
@@ -629,11 +629,10 @@ class InputOptions(
   ```
 
   Attributes:
-    experimental_prefetch_to_device: Boolean. Currently only applies to
-      TPUStrategy. Defaults to True. If True, dataset elements will be
-      prefetched to accelerator device memory. When False, dataset elements are
-      prefetched to host device memory. Must be False when using TPUEmbedding
-      API.
+    experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
+      elements will be prefetched to accelerator device memory. When False,
+      dataset elements are prefetched to host device memory. Must be False when
+      using TPUEmbedding API.
   """
 
   def __new__(cls, experimental_prefetch_to_device=True):
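
A quick sketch of the attribute described above, assuming the public `tf.distribute.InputOptions` export:

import tensorflow as tf

default_options = tf.distribute.InputOptions()
# Prefetching to the accelerator is the default.
assert default_options.experimental_prefetch_to_device is True

host_options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)
assert host_options.experimental_prefetch_to_device is False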
@@ -331,16 +331,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _initialize_single_worker(self, devices):
     """Initializes the object for single-worker training."""
     self._devices = tuple(device_util.canonicalize(d) for d in devices)
-    self._input_workers = input_lib.InputWorkers(
-        ((device_util.canonicalize("/device:CPU:0", devices[0]), devices),))
+    self._input_workers_devices = (
+        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)
     self._inferred_cross_device_ops = None if self._cross_device_ops else (
         cross_device_ops_lib.choose_the_best(devices))
     self._host_input_device = numpy_dataset.SingleDevice(
-        self._input_workers.worker_devices[0])
+        self._input_workers_devices[0][0])
     self._is_multi_worker_training = False
     logging.info("Using MirroredStrategy with devices %r", devices)
     device_spec = tf_device.DeviceSpec.from_string(
-        self._input_workers.worker_devices[0])
+        self._input_workers_devices[0][0])
     # Ensures when we enter strategy.scope() we use the correct default device
     if device_spec.job is not None and device_spec.job != "localhost":
       self._default_device = "/job:%s/replica:%d/task:%d" % (
@@ -368,7 +368,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
     self._host_input_device = numpy_dataset.SingleDevice(workers[0])
 
     self._devices = tuple(devices)
-    self._input_workers = input_lib.InputWorkers(worker_devices)
+    self._input_workers_devices = worker_devices
     self._is_multi_worker_training = True
 
     if len(workers) > 1:
@@ -385,6 +385,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
     logging.info("Using MirroredStrategy with remote devices %r", devices)
 
+  def _input_workers_with_options(self, options=None):
+    if not options or options.experimental_prefetch_to_device:
+      return input_lib.InputWorkers(self._input_workers_devices)
+    else:
+      return input_lib.InputWorkers(
+          [(host_device, (host_device,) * len(compute_devices)) for
+           host_device, compute_devices in self._input_workers_devices])
+
+  @property
+  def _input_workers(self):
+    return self._input_workers_with_options()
+
   def _get_variable_creator_initial_value(self,
                                           replica_id,
                                           device,
@@ -478,7 +490,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _experimental_distribute_dataset(self, dataset, options):
     return input_lib.get_distributed_dataset(
         dataset,
-        self._input_workers,
+        self._input_workers_with_options(options),
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
 
@@ -489,7 +501,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
   def _experimental_distribute_datasets_from_function(self, dataset_fn,
                                                       options):
     input_contexts = []
-    num_workers = self._input_workers.num_workers
+    input_workers = self._input_workers_with_options(options)
+    num_workers = input_workers.num_workers
     for i in range(num_workers):
       input_contexts.append(distribute_lib.InputContext(
           num_input_pipelines=num_workers,
@@ -498,7 +511,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn,
-        self._input_workers,
+        input_workers,
         input_contexts,
         self._container_strategy())
 
@@ -30,8 +30,10 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
@@ -44,6 +46,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -230,6 +233,47 @@ class MirroredTwoDeviceDistributionTest(
   def testTrainableVariables(self, distribution):
     self._test_trainable_variable(distribution)
 
+  def test_prefetch_to_device_dataset(self, distribution):
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=True)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if context.executing_eagerly():
+      item = next(iter(dataset))
+    else:
+      if isinstance(dataset, input_lib.DistributedDatasetV1):
+        item = dataset.make_initializable_iterator().get_next()
+      else:
+        self.skipTest("unsupported test combination")
+    device_types = [
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values]
+    expected_device_types = [
+        tf_device.DeviceSpec.from_string(device).device_type for
+        device in distribution.extended.worker_devices]
+    self.assertAllEqual(device_types, expected_device_types)
+
+  def test_prefetch_to_host_dataset(self, distribution):
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=False)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if context.executing_eagerly():
+      item = next(iter(dataset))
+    else:
+      if isinstance(dataset, input_lib.DistributedDatasetV1):
+        item = dataset.make_initializable_iterator().get_next()
+      else:
+        self.skipTest("unsupported test combination")
+    device_types = {
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values}
+    self.assertAllEqual(list(device_types), ["CPU"])
+
 
 def one_device_combinations():
   return combinations.combine(
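
The assertions in these tests reduce to parsing each tensor's placement string; a tiny sketch of that check outside the test harness, using the public `tf.DeviceSpec` (the placement strings below are illustrative):

import tensorflow as tf

placements = [
    "/job:localhost/replica:0/task:0/device:GPU:0",
    "/job:localhost/replica:0/task:0/device:GPU:1",
]

# Mirrors the tf_device.DeviceSpec.from_string(...).device_type checks above.
device_types = {tf.DeviceSpec.from_string(d).device_type for d in placements}
print(device_types)  # {'GPU'}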
@@ -81,7 +81,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
         "OneDeviceStrategy")
 
-  def experimental_distribute_dataset(self, dataset):  # pylint: disable=useless-super-delegation
+  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
     """Distributes a tf.data.Dataset instance provided via dataset.
 
     In this case, there is only one device, so this is only a thin wrapper
@@ -102,14 +102,16 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     ```
     Args:
       dataset: `tf.data.Dataset` to be prefetched to device.
-
+      options: `tf.distribute.InputOptions` used to control options on how this
+        dataset is distributed.
     Returns:
       A "distributed `Dataset`" that the caller can iterate over.
     """
     return super(OneDeviceStrategy, self).experimental_distribute_dataset(
-        dataset)
+        dataset, options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn):  # pylint: disable=useless-super-delegation
+  def experimental_distribute_datasets_from_function(self, dataset_fn,  # pylint: disable=useless-super-delegation
+                                                     options=None):
     """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
 
     `dataset_fn` will be called once for each worker in the strategy. In this
@@ -140,6 +142,8 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     Args:
       dataset_fn: A function taking a `tf.distribute.InputContext` instance and
         returning a `tf.data.Dataset`.
+      options: `tf.distribute.InputOptions` used to control options on how this
+        dataset is distributed.
 
     Returns:
       A "distributed `Dataset`", which the caller can iterate over like regular
@@ -147,7 +151,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
     """
     return super(
         OneDeviceStrategy, self).experimental_distribute_datasets_from_function(
-            dataset_fn)
+            dataset_fn, options)
 
   def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
     """Returns the list of all local per-replica values contained in `value`.
@@ -254,10 +258,18 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
   def __init__(self, container_strategy, device):
     super(OneDeviceExtended, self).__init__(container_strategy)
     self._device = device_util.resolve(device)
-    suffix_loc = self._device.rfind("/")
-    self._input_device = self._device[:suffix_loc] + "/device:CPU:0"
-    worker_device_pairs = [(self._input_device, [self._device])]
-    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
+    self._input_device = device_util.get_host_for_device(self._device)
+
+  def _input_workers_with_options(self, options=None):
+    if not options or options.experimental_prefetch_to_device:
+      return input_lib.InputWorkers([(self._input_device, (self._device,))])
+    else:
+      return input_lib.InputWorkers([(self._input_device,
+                                      (self._input_device,))])
+
+  @property
+  def _input_workers(self):
+    return self._input_workers_with_options()
 
   def _create_variable(self, next_creator, **kwargs):
     colocate_with = kwargs.pop("colocate_with", None)
@@ -300,14 +312,16 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
   def _experimental_distribute_dataset(self, dataset, options):
     # Note that split_batch_by argument is not passed because it is always 1 in
     # this strategy, and adding it adds unnecessary overhead to the dataset.
-    return input_lib.get_distributed_dataset(dataset, self._input_workers,
-                                             self._container_strategy())
+    return input_lib.get_distributed_dataset(
+        dataset,
+        self._input_workers_with_options(options),
+        self._container_strategy())
 
   def _experimental_distribute_datasets_from_function(self, dataset_fn,
                                                       options):
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn,
-        self._input_workers,
+        self._input_workers_with_options(options),
         [distribute_lib.InputContext()],
         self._container_strategy())
 
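
Since OneDeviceStrategy has a single compute device, the prefetch-to-host case simply keeps the iterator and its elements on the host. A short sketch (assuming this change is applied; the device string is illustrative):

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")

dataset = tf.data.Dataset.range(10).batch(2)
options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)

dist_dataset = strategy.experimental_distribute_dataset(dataset, options=options)
for batch in dist_dataset:
  # With a single device there is one value per step; here it stays on the host.
  print(batch.device)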
@@ -20,10 +20,13 @@ from __future__ import print_function
 from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import test
+from tensorflow.python.framework import device as tf_device
 
 
 @combinations.generate(
@@ -116,6 +119,44 @@ class OneDeviceStrategyTest(
   def testTrainableVariables(self, distribution):
     self._test_trainable_variable(distribution)
 
+  def test_prefetch_to_device_dataset(self, distribution):
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=True)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if context.executing_eagerly():
+      item = next(iter(dataset))
+    else:
+      if isinstance(dataset, input_lib.DistributedDatasetV1):
+        item = dataset.make_initializable_iterator().get_next()
+      else:
+        self.skipTest("unsupported test combination")
+    device_types = (
+        tf_device.DeviceSpec.from_string(item.device).device_type)
+    expected_device_types = (
+        tf_device.DeviceSpec.from_string(
+            distribution.extended.worker_devices[0]).device_type)
+    self.assertAllEqual(device_types, expected_device_types)
+
+  def test_prefetch_to_host_dataset(self, distribution):
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=False)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if context.executing_eagerly():
+      item = next(iter(dataset))
+    else:
+      if isinstance(dataset, input_lib.DistributedDatasetV1):
+        item = dataset.make_initializable_iterator().get_next()
+      else:
+        self.skipTest("unsupported test combination")
+    self.assertAllEqual(
+        tf_device.DeviceSpec.from_string(item.device).device_type, "CPU")
+
 
 @combinations.generate(
     combinations.combine(
@@ -122,16 +122,18 @@ class ParameterServerStrategy(distribute_lib.Strategy):
     distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set(
         len(self.extended.parameter_devices))
 
-  def experimental_distribute_dataset(self, dataset):
+  def experimental_distribute_dataset(self, dataset, options=None):
     self._raise_pss_error_if_eager()
     super(ParameterServerStrategy,
-          self).experimental_distribute_dataset(dataset=dataset)
+          self).experimental_distribute_dataset(dataset=dataset,
+                                                options=options)
 
-  def experimental_distribute_datasets_from_function(self, dataset_fn):
+  def experimental_distribute_datasets_from_function(self, dataset_fn,
+                                                     options=None):
     self._raise_pss_error_if_eager()
     super(ParameterServerStrategy,
           self).experimental_distribute_datasets_from_function(
-              dataset_fn=dataset_fn)
+              dataset_fn=dataset_fn, options=options)
 
   def run(self, fn, args=(), kwargs=None, options=None):
     self._raise_pss_error_if_eager()
@@ -229,22 +231,21 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
     assert cluster_spec.as_dict()
 
-    worker_device = "/job:%s/task:%d" % (task_type, task_id)
-    self._input_host_device = numpy_dataset.SingleDevice(worker_device)
+    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
+    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)
 
     # Define compute devices which is a list of device strings and one for each
     # replica. When there are GPUs, replicate operations on these GPUs.
     # Otherwise, place operations on CPU.
     if num_gpus > 0:
       compute_devices = tuple(
-          "%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
+          "%s/device:GPU:%d" % (self._worker_device, i)
+          for i in range(num_gpus))
     else:
-      compute_devices = (worker_device,)
+      compute_devices = (self._worker_device,)
 
     self._compute_devices = [
         device_util.canonicalize(d) for d in compute_devices]
-    self._input_workers = input_lib.InputWorkers(
-        [(worker_device, compute_devices)])
 
     # In distributed mode, place variables on ps jobs in a round-robin fashion.
     # Note that devices returned from `replica_device_setter` are not
@@ -259,7 +260,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       raise ValueError("The cluster spec needs to have `ps` jobs.")
     self._variable_device = device_setter.replica_device_setter(
         ps_tasks=num_ps_replicas,
-        worker_device=worker_device,
+        worker_device=self._worker_device,
         merge_devices=True,
         cluster=cluster_spec)
 
@@ -271,7 +272,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
     # Add a default device so that ops without specified devices will not end up
     # on other workers.
-    self._default_device = worker_device
+    self._default_device = self._worker_device
 
     self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                 task_id)
@@ -294,8 +295,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
                         parameter_device,
                         cluster_resolver=None):
     """Initialize local devices for training."""
-    worker_device = device_util.canonicalize("/device:CPU:0")
-    self._input_host_device = numpy_dataset.SingleDevice(worker_device)
+    self._worker_device = device_util.canonicalize("/device:CPU:0")
+    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)
 
     if compute_devices is None:
       if not cluster_resolver:
@@ -318,9 +319,6 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
     else:
       parameter_device = _LOCAL_CPU
 
-    self._input_workers = input_lib.InputWorkers(
-        [(worker_device, compute_devices)])
-
     self._variable_device = parameter_device
     self._compute_devices = compute_devices
     self._parameter_devices = (parameter_device,)
@@ -334,13 +332,26 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         "single machine) with compute_devices = %r, variable_device = %r",
         compute_devices, self._variable_device)
 
+  def _input_workers_with_options(self, options=None):
+    if not options or options.experimental_prefetch_to_device:
+      return input_lib.InputWorkers(
+          [(self._worker_device, self._compute_devices)])
+    else:
+      return input_lib.InputWorkers(
+          [(self._worker_device,
+            (self._worker_device,) * len(self._compute_devices))])
+
+  @property
+  def _input_workers(self):
+    return self._input_workers_with_options()
+
   def _validate_colocate_with_variable(self, colocate_with_variable):
     distribute_utils.validate_colocate(colocate_with_variable, self)
 
   def _experimental_distribute_dataset(self, dataset, options):
     return input_lib.get_distributed_dataset(
         dataset,
-        self._input_workers,
+        self._input_workers_with_options(options),
         self._container_strategy(),
         split_batch_by=self._num_replicas_in_sync)
 
@@ -394,7 +405,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
     return input_lib.get_distributed_datasets_from_function(
         dataset_fn,
-        self._input_workers,
+        self._input_workers_with_options(options),
        [input_context],
        self._container_strategy())
 
@@ -497,7 +508,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       if d_spec.job == self._task_type and d_spec.task != self._task_id:
         raise ValueError(
             "Cannot reduce to another worker: %r, current worker is %r" %
-            (d, self._input_workers.worker_devices[0]))
+            (d, self._worker_device))
 
   def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
     self._verify_destinations_not_different_worker(destinations)
@@ -27,8 +27,10 @@ from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import central_storage_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
+from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.distribute import parameter_server_strategy
@@ -40,6 +42,7 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.estimator import run_config
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -765,6 +768,55 @@ class ParameterServerStrategyTest(
     self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                            strategy.run, train_step)
 
+  @combinations.generate(combinations.combine(
+      mode=['graph'],
+      prefetch_to_device=[None, True]))
+  def test_prefetch_to_device_dataset(self, prefetch_to_device):
+    distribution, _, _ = create_test_objects(
+        cluster_spec=self._cluster_spec,
+        task_type='worker',
+        task_id=0,
+        num_gpus=2)
+    if prefetch_to_device is None:
+      input_options = None
+    else:
+      input_options = distribute_lib.InputOptions(
+          experimental_prefetch_to_device=prefetch_to_device)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if isinstance(dataset, input_lib.DistributedDatasetV1):
+      item = dataset.make_initializable_iterator().get_next()
+    else:
+      self.skipTest('unsupported test combination')
+    device_types = {
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values}
+    self.assertAllEqual(list(device_types), ['GPU'])
+
+  @combinations.generate(combinations.combine(mode=['graph']))
+  def test_prefetch_to_host_dataset(self):
+    distribution, _, _ = create_test_objects(
+        cluster_spec=self._cluster_spec,
+        task_type='worker',
+        task_id=0,
+        num_gpus=2)
+    input_options = distribute_lib.InputOptions(
+        experimental_prefetch_to_device=False)
+    dataset = dataset_ops.Dataset.range(100)
+    dataset = dataset.batch(distribution.num_replicas_in_sync)
+    dataset = distribution.experimental_distribute_dataset(
+        dataset, options=input_options)
+    if isinstance(dataset, input_lib.DistributedDatasetV1):
+      item = dataset.make_initializable_iterator().get_next()
+    else:
+      self.skipTest('unsupported test combination')
+    device_types = {
+        tf_device.DeviceSpec.from_string(tensor.device).device_type for
+        tensor in item.values}
+    self.assertAllEqual(list(device_types), ['CPU'])
+
 
 class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                            parameterized.TestCase):
@@ -34,11 +34,11 @@ tf_class {
   }
   member_method {
     name: "experimental_distribute_dataset"
-    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "experimental_distribute_datasets_from_function"
-    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_values_from_function"
@@ -34,11 +34,11 @@ tf_class {
   }
   member_method {
     name: "experimental_distribute_dataset"
-    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_datasets_from_function"
-    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_values_from_function"
@@ -34,11 +34,11 @@ tf_class {
   }
   member_method {
     name: "experimental_distribute_dataset"
-    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_datasets_from_function"
-    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "experimental_distribute_values_from_function"