Add InputOptions support to all remaining strategies.

PiperOrigin-RevId: 318160598
Change-Id: I99ec2aa41a6e4e5fb4720b878fd82ad10802dcc9
Authored by Bruce Fontaine on 2020-06-24 16:10:38 -07:00; committed by TensorFlower Gardener
parent 9f292402a6
commit 4cdc8d04f3
13 changed files with 301 additions and 60 deletions

View File

@ -74,7 +74,7 @@ class CentralStorageStrategy(distribute_lib.Strategy):
def _from_num_gpus(cls, num_gpus):
return cls(device_util.local_devices_from_num_gpus(num_gpus))
def experimental_distribute_dataset(self, dataset): # pylint: disable=useless-super-delegation
def experimental_distribute_dataset(self, dataset, options=None): # pylint: disable=useless-super-delegation
"""Distributes a tf.data.Dataset instance provided via dataset.
The returned dataset is a wrapped strategy dataset which creates a
@ -96,14 +96,17 @@ class CentralStorageStrategy(distribute_lib.Strategy):
```
Args:
dataset: `tf.data.Dataset` to be prefetched to device.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`" that the caller can iterate over.
"""
return super(CentralStorageStrategy, self).experimental_distribute_dataset(
dataset)
dataset, options)
def experimental_distribute_datasets_from_function(self, dataset_fn): # pylint: disable=useless-super-delegation
def experimental_distribute_datasets_from_function(self, dataset_fn, # pylint: disable=useless-super-delegation
options=None):
"""Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
`dataset_fn` will be called once for each worker in the strategy. In this
@ -136,6 +139,8 @@ class CentralStorageStrategy(distribute_lib.Strategy):
Args:
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
returning a `tf.data.Dataset`.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`", which the caller can iterate over like regular
@ -143,7 +148,8 @@ class CentralStorageStrategy(distribute_lib.Strategy):
"""
return super(
CentralStorageStrategy,
self).experimental_distribute_datasets_from_function(dataset_fn)
self).experimental_distribute_datasets_from_function(dataset_fn,
options)
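
To make the new surface concrete, here is a minimal usage sketch of the `options` argument added to the two methods above (assuming a local CentralStorageStrategy; illustration only, not part of this commit):

```
import tensorflow as tf

# Variables live on a single (CPU) device; compute is replicated across the
# local GPUs, or runs on the CPU if none are available.
strategy = tf.distribute.experimental.CentralStorageStrategy()

# Keep dataset elements in host memory instead of prefetching them to the
# compute devices.
options = tf.distribute.InputOptions(experimental_prefetch_to_device=False)

dataset = tf.data.Dataset.range(8).batch(strategy.num_replicas_in_sync)
dist_dataset = strategy.experimental_distribute_dataset(dataset, options=options)
```
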
def experimental_local_results(self, value): # pylint: disable=useless-super-delegation
"""Returns the list of all local per-replica values contained in `value`.

View File

@ -358,9 +358,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
host_device = device_util.get_host_for_device(self._worker_device)
self._input_workers = input_lib.InputWorkers(
[(host_device, self.worker_devices)])
# Add a default device so that ops without specified devices will not end up
# on other workers.
@ -378,6 +375,20 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
task_id, self._num_workers, local_devices,
self._communication)
def _input_workers_with_options(self, options=None):
host_device = device_util.get_host_for_device(self._worker_device)
if not options or options.experimental_prefetch_to_device:
return input_lib.InputWorkers([(host_device, self.worker_devices)])
else:
return input_lib.InputWorkers([(
host_device,
[device_util.get_host_for_device(worker) for worker in
self.worker_devices])])
@property
def _input_workers(self):
return self._input_workers_with_options()
def _get_variable_creator_initial_value(self,
replica_id,
device,
@ -441,7 +452,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
input_context = self._make_input_context()
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync,
input_context=input_context)
@ -451,7 +462,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
input_context = self._make_input_context()
return input_lib.get_distributed_datasets_from_function(
dataset_fn=dataset_fn,
input_workers=self._input_workers,
input_workers=self._input_workers_with_options(options),
input_contexts=[input_context],
strategy=self._container_strategy())
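
The new `_input_workers_with_options` helper above picks between two `InputWorkers` layouts; an illustration with hypothetical device strings (plain Python, shown only to clarify the two placements):

```
# Hypothetical one-worker, two-GPU setup.
worker_devices = ("/job:worker/task:0/device:GPU:0",
                  "/job:worker/task:0/device:GPU:1")
host_device = "/job:worker/task:0/device:CPU:0"

# options is None, or experimental_prefetch_to_device=True:
#   InputWorkers([(host_device, worker_devices)])
#   -> dataset elements are prefetched onto the two GPUs.
prefetch_to_device_pairs = [(host_device, worker_devices)]

# experimental_prefetch_to_device=False:
#   InputWorkers([(host_device, (host_device, host_device))])
#   -> one per-replica element per GPU, but each stays in host (CPU) memory.
prefetch_to_host_pairs = [(host_device, (host_device,) * len(worker_devices))]
```
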

View File

@ -29,7 +29,9 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
@ -37,6 +39,7 @@ from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -281,6 +284,53 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertEqual(2 * num_workers,
distribution.num_replicas_in_sync)
@combinations.generate(combinations.combine(
mode=['graph'],
prefetch_to_device=[None, True]))
def test_prefetch_to_device_dataset(self, prefetch_to_device):
distribution, _, _ = self._get_test_object(
task_type='worker',
task_id=0,
num_gpus=2)
if prefetch_to_device is None:
input_options = None
else:
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=prefetch_to_device)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest('unsupported test combination')
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ['GPU'])
@combinations.generate(combinations.combine(mode=['graph']))
def test_prefetch_to_host_dataset(self):
distribution, _, _ = self._get_test_object(
task_type='worker',
task_id=0,
num_gpus=2)
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=False)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest('unsupported test combination')
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ['CPU'])
@combinations.generate(
combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
def testMinimizeLossGraph(self, required_gpus):

View File

@ -629,11 +629,10 @@ class InputOptions(
```
Attributes:
experimental_prefetch_to_device: Boolean. Currently only applies to
TPUStrategy. Defaults to True. If True, dataset elements will be
prefetched to accelerator device memory. When False, dataset elements are
prefetched to host device memory. Must be False when using TPUEmbedding
API.
experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
elements will be prefetched to accelerator device memory. When False,
dataset elements are prefetched to host device memory. Must be False when
using TPUEmbedding API.
"""
def __new__(cls, experimental_prefetch_to_device=True):
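
The reworded attribute doc drops the "only applies to TPUStrategy" caveat because, with this commit, the flag is honored by the remaining strategies as well. A hedged sketch of the documented pattern (multi-worker setup assumed):

```
import tensorflow as tf

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

dataset = tf.data.Dataset.range(16).batch(strategy.num_replicas_in_sync)
dist_dataset_on_host = strategy.experimental_distribute_dataset(
    dataset,
    options=tf.distribute.InputOptions(experimental_prefetch_to_device=False))
```
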

View File

@ -331,16 +331,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _initialize_single_worker(self, devices):
"""Initializes the object for single-worker training."""
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._input_workers = input_lib.InputWorkers(
((device_util.canonicalize("/device:CPU:0", devices[0]), devices),))
self._input_workers_devices = (
(device_util.canonicalize("/device:CPU:0", devices[0]), devices),)
self._inferred_cross_device_ops = None if self._cross_device_ops else (
cross_device_ops_lib.choose_the_best(devices))
self._host_input_device = numpy_dataset.SingleDevice(
self._input_workers.worker_devices[0])
self._input_workers_devices[0][0])
self._is_multi_worker_training = False
logging.info("Using MirroredStrategy with devices %r", devices)
device_spec = tf_device.DeviceSpec.from_string(
self._input_workers.worker_devices[0])
self._input_workers_devices[0][0])
# Ensures when we enter strategy.scope() we use the correct default device
if device_spec.job is not None and device_spec.job != "localhost":
self._default_device = "/job:%s/replica:%d/task:%d" % (
@ -368,7 +368,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._host_input_device = numpy_dataset.SingleDevice(workers[0])
self._devices = tuple(devices)
self._input_workers = input_lib.InputWorkers(worker_devices)
self._input_workers_devices = worker_devices
self._is_multi_worker_training = True
if len(workers) > 1:
@ -385,6 +385,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
logging.info("Using MirroredStrategy with remote devices %r", devices)
def _input_workers_with_options(self, options=None):
if not options or options.experimental_prefetch_to_device:
return input_lib.InputWorkers(self._input_workers_devices)
else:
return input_lib.InputWorkers(
[(host_device, (host_device,) * len(compute_devices)) for
host_device, compute_devices in self._input_workers_devices])
@property
def _input_workers(self):
return self._input_workers_with_options()
def _get_variable_creator_initial_value(self,
replica_id,
device,
@ -478,7 +490,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_dataset(self, dataset, options):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
@ -489,7 +501,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
input_contexts = []
num_workers = self._input_workers.num_workers
input_workers = self._input_workers_with_options(options)
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
@ -498,7 +511,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,
input_workers,
input_contexts,
self._container_strategy())
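
The same options plumbing now feeds the dataset-function path, where each `tf.distribute.InputContext` built above is handed to the user's `dataset_fn`. A minimal sketch of the resulting call pattern (hypothetical local GPUs; not part of this commit):

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])  # hypothetical devices

def dataset_fn(input_context):
  # Size the per-replica pipeline from the InputContext the strategy passes in.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=8)
  return tf.data.Dataset.range(64).batch(batch_size)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn,
    options=tf.distribute.InputOptions(experimental_prefetch_to_device=False))
```
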

View File

@ -30,8 +30,10 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
@ -44,6 +46,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
@ -230,6 +233,47 @@ class MirroredTwoDeviceDistributionTest(
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def test_prefetch_to_device_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=True)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
device_types = [
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values]
expected_device_types = [
tf_device.DeviceSpec.from_string(device).device_type for
device in distribution.extended.worker_devices]
self.assertAllEqual(device_types, expected_device_types)
def test_prefetch_to_host_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=False)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ["CPU"])
def one_device_combinations():
return combinations.combine(

View File

@ -81,7 +81,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
"OneDeviceStrategy")
def experimental_distribute_dataset(self, dataset): # pylint: disable=useless-super-delegation
def experimental_distribute_dataset(self, dataset, options=None): # pylint: disable=useless-super-delegation
"""Distributes a tf.data.Dataset instance provided via dataset.
In this case, there is only one device, so this is only a thin wrapper
@ -102,14 +102,16 @@ class OneDeviceStrategy(distribute_lib.Strategy):
```
Args:
dataset: `tf.data.Dataset` to be prefetched to device.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`" that the caller can iterate over.
"""
return super(OneDeviceStrategy, self).experimental_distribute_dataset(
dataset)
dataset, options)
def experimental_distribute_datasets_from_function(self, dataset_fn): # pylint: disable=useless-super-delegation
def experimental_distribute_datasets_from_function(self, dataset_fn, # pylint: disable=useless-super-delegation
options=None):
"""Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
`dataset_fn` will be called once for each worker in the strategy. In this
@ -140,6 +142,8 @@ class OneDeviceStrategy(distribute_lib.Strategy):
Args:
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
returning a `tf.data.Dataset`.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`", which the caller can iterate over like regular
@ -147,7 +151,7 @@ class OneDeviceStrategy(distribute_lib.Strategy):
"""
return super(
OneDeviceStrategy, self).experimental_distribute_datasets_from_function(
dataset_fn)
dataset_fn, options)
def experimental_local_results(self, value): # pylint: disable=useless-super-delegation
"""Returns the list of all local per-replica values contained in `value`.
@ -254,10 +258,18 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def __init__(self, container_strategy, device):
super(OneDeviceExtended, self).__init__(container_strategy)
self._device = device_util.resolve(device)
suffix_loc = self._device.rfind("/")
self._input_device = self._device[:suffix_loc] + "/device:CPU:0"
worker_device_pairs = [(self._input_device, [self._device])]
self._input_workers = input_lib.InputWorkers(worker_device_pairs)
self._input_device = device_util.get_host_for_device(self._device)
def _input_workers_with_options(self, options=None):
if not options or options.experimental_prefetch_to_device:
return input_lib.InputWorkers([(self._input_device, (self._device,))])
else:
return input_lib.InputWorkers([(self._input_device,
(self._input_device,))])
@property
def _input_workers(self):
return self._input_workers_with_options()
def _create_variable(self, next_creator, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
@ -300,14 +312,16 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_dataset(self, dataset, options):
# Note that split_batch_by argument is not passed because it is always 1 in
# this strategy, and adding it adds unnecessary overhead to the dataset.
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._container_strategy())
return input_lib.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
self._container_strategy())
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,
self._input_workers_with_options(options),
[distribute_lib.InputContext()],
self._container_strategy())
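
For OneDeviceStrategy the effect is easy to observe directly; a brief sketch (hypothetical GPU device) of what the new test later in this commit verifies:

```
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/gpu:0")  # hypothetical device

dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(2),
    options=tf.distribute.InputOptions(experimental_prefetch_to_device=False))

# With prefetching to the device disabled, the element comes back on the host CPU.
element = next(iter(dist_dataset))
print(element.device)  # expected to name a CPU device rather than the GPU
```
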

View File

@ -20,10 +20,13 @@ from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import device as tf_device
@combinations.generate(
@ -116,6 +119,44 @@ class OneDeviceStrategyTest(
def testTrainableVariables(self, distribution):
self._test_trainable_variable(distribution)
def test_prefetch_to_device_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=True)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
device_types = (
tf_device.DeviceSpec.from_string(item.device).device_type)
expected_device_types = (
tf_device.DeviceSpec.from_string(
distribution.extended.worker_devices[0]).device_type)
self.assertAllEqual(device_types, expected_device_types)
def test_prefetch_to_host_dataset(self, distribution):
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=False)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if context.executing_eagerly():
item = next(iter(dataset))
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest("unsupported test combination")
self.assertAllEqual(
tf_device.DeviceSpec.from_string(item.device).device_type, "CPU")
@combinations.generate(
combinations.combine(

View File

@ -122,16 +122,18 @@ class ParameterServerStrategy(distribute_lib.Strategy):
distribute_lib.distribution_strategy_replica_gauge.get_cell("num_ps").set(
len(self.extended.parameter_devices))
def experimental_distribute_dataset(self, dataset):
def experimental_distribute_dataset(self, dataset, options=None):
self._raise_pss_error_if_eager()
super(ParameterServerStrategy,
self).experimental_distribute_dataset(dataset=dataset)
self).experimental_distribute_dataset(dataset=dataset,
options=options)
def experimental_distribute_datasets_from_function(self, dataset_fn):
def experimental_distribute_datasets_from_function(self, dataset_fn,
options=None):
self._raise_pss_error_if_eager()
super(ParameterServerStrategy,
self).experimental_distribute_datasets_from_function(
dataset_fn=dataset_fn)
dataset_fn=dataset_fn, options=options)
def run(self, fn, args=(), kwargs=None, options=None):
self._raise_pss_error_if_eager()
@ -229,22 +231,21 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
assert cluster_spec.as_dict()
worker_device = "/job:%s/task:%d" % (task_type, task_id)
self._input_host_device = numpy_dataset.SingleDevice(worker_device)
self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)
# Define compute devices which is a list of device strings and one for each
# replica. When there are GPUs, replicate operations on these GPUs.
# Otherwise, place operations on CPU.
if num_gpus > 0:
compute_devices = tuple(
"%s/device:GPU:%d" % (worker_device, i) for i in range(num_gpus))
"%s/device:GPU:%d" % (self._worker_device, i)
for i in range(num_gpus))
else:
compute_devices = (worker_device,)
compute_devices = (self._worker_device,)
self._compute_devices = [
device_util.canonicalize(d) for d in compute_devices]
self._input_workers = input_lib.InputWorkers(
[(worker_device, compute_devices)])
# In distributed mode, place variables on ps jobs in a round-robin fashion.
# Note that devices returned from `replica_device_setter` are not
@ -259,7 +260,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
raise ValueError("The cluster spec needs to have `ps` jobs.")
self._variable_device = device_setter.replica_device_setter(
ps_tasks=num_ps_replicas,
worker_device=worker_device,
worker_device=self._worker_device,
merge_devices=True,
cluster=cluster_spec)
@ -271,7 +272,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
# Add a default device so that ops without specified devices will not end up
# on other workers.
self._default_device = worker_device
self._default_device = self._worker_device
self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
task_id)
@ -294,8 +295,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
parameter_device,
cluster_resolver=None):
"""Initialize local devices for training."""
worker_device = device_util.canonicalize("/device:CPU:0")
self._input_host_device = numpy_dataset.SingleDevice(worker_device)
self._worker_device = device_util.canonicalize("/device:CPU:0")
self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)
if compute_devices is None:
if not cluster_resolver:
@ -318,9 +319,6 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
else:
parameter_device = _LOCAL_CPU
self._input_workers = input_lib.InputWorkers(
[(worker_device, compute_devices)])
self._variable_device = parameter_device
self._compute_devices = compute_devices
self._parameter_devices = (parameter_device,)
@ -334,13 +332,26 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
"single machine) with compute_devices = %r, variable_device = %r",
compute_devices, self._variable_device)
def _input_workers_with_options(self, options=None):
if not options or options.experimental_prefetch_to_device:
return input_lib.InputWorkers(
[(self._worker_device, self._compute_devices)])
else:
return input_lib.InputWorkers(
[(self._worker_device,
(self._worker_device,) * len(self._compute_devices))])
@property
def _input_workers(self):
return self._input_workers_with_options()
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils.validate_colocate(colocate_with_variable, self)
def _experimental_distribute_dataset(self, dataset, options):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
@ -394,7 +405,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,
self._input_workers_with_options(options),
[input_context],
self._container_strategy())
@ -497,7 +508,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
if d_spec.job == self._task_type and d_spec.task != self._task_id:
raise ValueError(
"Cannot reduce to another worker: %r, current worker is %r" %
(d, self._input_workers.worker_devices[0]))
(d, self._worker_device))
def _reduce_to(self, reduce_op, value, destinations, experimental_hints):
self._verify_destinations_not_different_worker(destinations)

View File

@ -27,8 +27,10 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
@ -40,6 +42,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -765,6 +768,55 @@ class ParameterServerStrategyTest(
self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
strategy.run, train_step)
@combinations.generate(combinations.combine(
mode=['graph'],
prefetch_to_device=[None, True]))
def test_prefetch_to_device_dataset(self, prefetch_to_device):
distribution, _, _ = create_test_objects(
cluster_spec=self._cluster_spec,
task_type='worker',
task_id=0,
num_gpus=2)
if prefetch_to_device is None:
input_options = None
else:
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=prefetch_to_device)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest('unsupported test combination')
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ['GPU'])
@combinations.generate(combinations.combine(mode=['graph']))
def test_prefetch_to_host_dataset(self):
distribution, _, _ = create_test_objects(
cluster_spec=self._cluster_spec,
task_type='worker',
task_id=0,
num_gpus=2)
input_options = distribute_lib.InputOptions(
experimental_prefetch_to_device=False)
dataset = dataset_ops.Dataset.range(100)
dataset = dataset.batch(distribution.num_replicas_in_sync)
dataset = distribution.experimental_distribute_dataset(
dataset, options=input_options)
if isinstance(dataset, input_lib.DistributedDatasetV1):
item = dataset.make_initializable_iterator().get_next()
else:
self.skipTest('unsupported test combination')
device_types = {
tf_device.DeviceSpec.from_string(tensor.device).device_type for
tensor in item.values}
self.assertAllEqual(list(device_types), ['CPU'])
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
parameterized.TestCase):

View File

@ -34,11 +34,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -34,11 +34,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -34,11 +34,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"