Merge pull request #38968 from kushanam:distribute_dali_ctl

PiperOrigin-RevId: 337869342
Change-Id: I3c34e90fe023dbefa8c66ac4331a251292ee547a
TensorFlower Gardener 2020-10-19 09:25:22 -07:00
commit 8e8c010e95
12 changed files with 431 additions and 49 deletions


@ -102,6 +102,13 @@ class CentralStorageStrategy(distribute_lib.Strategy):
Returns:
A "distributed `Dataset`" that the caller can iterate over.
"""
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
'InputReplicationMode.PER_REPLICA '
'is only supported in '
'`experimental_distribute_datasets_from_function`.'
)
return super(CentralStorageStrategy, self).experimental_distribute_dataset(
dataset, options)
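The same guard is repeated for every strategy touched below: requesting PER_REPLICA replication through experimental_distribute_dataset is rejected up front. A minimal sketch of how a caller would hit it, assuming a TF build that contains this change (MirroredStrategy is used here only for illustration):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_replica_options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)
dataset = tf.data.Dataset.range(8).batch(2)
try:
  strategy.experimental_distribute_dataset(dataset, per_replica_options)
except NotImplementedError as err:
  # PER_REPLICA is only supported via experimental_distribute_datasets_from_function.
  print(err)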


@ -469,6 +469,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
return input_context
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
input_context = self._make_input_context()
return input_lib.get_distributed_dataset(
dataset,
@ -478,6 +485,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
input_context=input_context)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
" `experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_context = self._make_input_context()
return input_lib.get_distributed_datasets_from_function(
dataset_fn=dataset_fn,


@ -439,8 +439,12 @@ class InputReplicationMode(enum.Enum):
Replicas will dequeue from the local Dataset on their worker.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
* `PER_REPLICA`: The input function will be called on each replica separately.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
""" """
PER_WORKER = "PER_WORKER" PER_WORKER = "PER_WORKER"
PER_REPLICA = "PER_REPLICA"
@tf_export("distribute.InputContext") @tf_export("distribute.InputContext")
@ -616,6 +620,8 @@ class RunOptions(
class InputOptions(
collections.namedtuple("InputOptions", [
"experimental_prefetch_to_device",
"experimental_replication_mode",
"experimental_place_dataset_on_device",
])):
"""Run options for `experimental_distribute_dataset(s_from_function)`.
@ -633,19 +639,36 @@ class InputOptions(
strategy.experimental_distribute_dataset(
dataset,
tf.distribute.InputOptions(
experimental_replication_mode=
tf.distribute.InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False)))
```
Attributes:
experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
elements will be prefetched to accelerator device memory. When False,
dataset elements are prefetched to host device memory. Must be False when
using TPUEmbedding API. experimental_prefetch_to_device can only be used
with experimental_replication_mode=PER_WORKER.
experimental_replication_mode: Replication mode for the input function.
Currently, InputReplicationMode.PER_REPLICA is only supported with
tf.distribute.MirroredStrategy.experimental_distribute_datasets_from_function.
The default value is InputReplicationMode.PER_WORKER.
experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
the dataset will be placed on the device, otherwise it will remain on the
host. experimental_place_dataset_on_device=True can only be used with
experimental_replication_mode=PER_REPLICA.
""" """
def __new__(cls,
experimental_prefetch_to_device=True,
experimental_replication_mode=InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False):
return super(InputOptions,
cls).__new__(cls, experimental_prefetch_to_device,
experimental_replication_mode,
experimental_place_dataset_on_device)
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
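Putting the new fields together: experimental_replication_mode chooses whether one input pipeline feeds a whole worker or each replica gets its own, and experimental_place_dataset_on_device additionally moves the per-replica dataset onto the replica's device. A hedged sketch of the PER_REPLICA variant, assuming a TF build with this change (device list and dataset are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

per_replica_options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,
    experimental_prefetch_to_device=False,
    experimental_place_dataset_on_device=False)

def dataset_fn(input_context):
  # Under PER_REPLICA the function runs once per replica;
  # input_pipeline_id tells the pipelines apart.
  return tf.data.Dataset.range(8).shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)

ds = strategy.experimental_distribute_datasets_from_function(
    dataset_fn, per_replica_options)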


@ -35,6 +35,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
@ -108,7 +109,8 @@ def get_distributed_dataset(dataset,
def get_distributed_datasets_from_function(dataset_fn,
input_workers,
input_contexts,
strategy,
options=None):
"""Returns a distributed dataset from the given input function. """Returns a distributed dataset from the given input function.
This is a common function that is used by all strategies to return a This is a common function that is used by all strategies to return a
@ -126,22 +128,43 @@ def get_distributed_datasets_from_function(dataset_fn,
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
options: Default is None. `tf.distribute.InputOptions` used to control
options on how this dataset is distributed.
Returns:
A distributed dataset instance.
Raises:
ValueError: if `options.experimental_replication_mode` and
`options.experimental_place_dataset_on_device` are not consistent
""" """
if (options is not None and
options.experimental_replication_mode != InputReplicationMode.PER_REPLICA
and options.experimental_place_dataset_on_device):
raise ValueError(
"When `experimental_place_dataset_on_device` is set for dataset "
"placement, you must also specify `PER_REPLICA` for the "
"replication mode")
if (options is not None and
options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
and options.experimental_prefetch_to_device and
options.experimental_place_dataset_on_device):
raise ValueError(
"`experimental_place_dataset_on_device` can not be set to True "
"when experimental_prefetch_to_device is True and "
"replication mode is set to `PER_REPLICA`")
if tf2.enabled():
return DistributedDatasetsFromFunction(dataset_fn, input_workers,
input_contexts, strategy, options)
else:
return DistributedDatasetsFromFunctionV1(
dataset_fn,
input_workers,
input_contexts,
strategy,
options)
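The two ValueError checks above encode a small compatibility table for the new options. A standalone sketch of the rule (plain strings stand in for the enum; this is not the TF API itself):

def validate_input_options(replication_mode, prefetch_to_device,
                           place_dataset_on_device):
  # Placing the dataset on the device only makes sense per replica.
  if place_dataset_on_device and replication_mode != "PER_REPLICA":
    raise ValueError("experimental_place_dataset_on_device requires "
                     "PER_REPLICA replication")
  # A dataset that already lives on the device cannot also be prefetched to it.
  if (replication_mode == "PER_REPLICA" and place_dataset_on_device
      and prefetch_to_device):
    raise ValueError("cannot combine experimental_place_dataset_on_device "
                     "with experimental_prefetch_to_device under PER_REPLICA")

validate_input_options("PER_REPLICA", False, True)   # ok: dataset on device, no prefetch
validate_input_options("PER_WORKER", True, False)    # ok: the default combination
# validate_input_options("PER_WORKER", True, True)   # raises ValueError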
@tf_export("distribute.DistributedIterator", v1=[]) @tf_export("distribute.DistributedIterator", v1=[])
@ -1188,7 +1211,8 @@ class DistributedDatasetV1(DistributedDataset):
class DistributedDatasetsFromFunction(_IterableInput):
"""Inputs created from dataset function."""
def __init__(self, dataset_fn, input_workers, input_contexts, strategy,
options):
"""Makes an iterable from datasets created by the given function. """Makes an iterable from datasets created by the given function.
Args: Args:
@ -1199,6 +1223,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
""" """
super(DistributedDatasetsFromFunction, self).__init__( super(DistributedDatasetsFromFunction, self).__init__(
input_workers=input_workers) input_workers=input_workers)
@ -1212,10 +1238,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
self._input_workers = input_workers
self._input_contexts = input_contexts
self._strategy = strategy
self._options = options
self._datasets, element_spec = (
_create_datasets_from_function_with_input_context(
self._input_contexts, self._input_workers, dataset_fn))
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, element_spec)
# When partial batch handling is enabled, always set the batch dimension to
@ -1239,11 +1265,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
# out this change.
enable_legacy_iterators = getattr(self._strategy,
"_enable_legacy_iterators", False)
iterators = _create_iterators_per_worker(self._datasets,
self._input_workers,
enable_legacy_iterators,
self._options)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(
self._input_workers,
@ -1252,9 +1277,9 @@ class DistributedDatasetsFromFunction(_IterableInput):
enable_get_next_as_optional=self._enable_get_next_as_optional)
else:
iterator = DistributedIterator(
input_workers=self._input_workers,
iterators=iterators,
strategy=self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
@ -1495,7 +1520,7 @@ def _recover_shape_fn(data, value_structure):
class _SingleWorkerDatasetIteratorBase(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices, options=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
@ -1505,16 +1530,36 @@ class _SingleWorkerDatasetIteratorBase(object):
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
options: `tf.distribute.InputOptions` used to control how this dataset is
distributed.
""" """
self._dataset = dataset self._dataset = dataset
self._worker = worker self._worker = worker
self._devices = devices self._devices = devices
self._element_spec = dataset.element_spec self._element_spec = dataset.element_spec
self._options = options
self._make_iterator()
def _make_iterator(self):
raise NotImplementedError("must be implemented in descendants")
def _format_data_list_with_options(self, data_list):
"""Change the data in to a list type if required.
The OwnedMultiDeviceIterator returns the list data type,
while the PER_REPLICA iterator (when used with prefetch disabled)
returns without the enclosed list. This is to fix the inconsistency.
Args:
data_list: data_list
Returns:
list
"""
if (self._options and self._options.experimental_replication_mode ==
InputReplicationMode.PER_REPLICA and
not self._options.experimental_prefetch_to_device):
return [data_list]
else:
return data_list
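A standalone sketch of what this helper normalizes, using hypothetical stand-in names rather than the TF classes (values are made up):

from collections import namedtuple

FakeOptions = namedtuple(
    "FakeOptions",
    ["experimental_replication_mode", "experimental_prefetch_to_device"])

def format_data_list_with_options(data, options):
  # PER_REPLICA without prefetch yields a bare element; wrap it so callers
  # always see a per-device list, matching OwnedMultiDeviceIterator.
  if (options and options.experimental_replication_mode == "PER_REPLICA"
      and not options.experimental_prefetch_to_device):
    return [data]
  return data

print(format_data_list_with_options("t0", FakeOptions("PER_REPLICA", False)))        # ['t0']
print(format_data_list_with_options(["t0", "t1"], FakeOptions("PER_WORKER", True)))  # ['t0', 't1']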
def get_next(self, device, name=None):
"""Get next element for the given device."""
del name
@ -1536,7 +1581,7 @@ class _SingleWorkerDatasetIteratorBase(object):
""" """
del name del name
with ops.device(self._worker): with ops.device(self._worker):
return self._iterator.get_next() return self._format_data_list_with_options(self._iterator.get_next())
def get_next_as_list(self, name=None):
"""Get next element from underlying iterator.
@ -1556,7 +1601,8 @@ class _SingleWorkerDatasetIteratorBase(object):
""" """
del name del name
with ops.device(self._worker): with ops.device(self._worker):
data_list = self._iterator.get_next_as_optional() data_list = self._format_data_list_with_options(
self._iterator.get_next_as_optional())
result = []
for i, data in enumerate(data_list):
# Place the condition op in the same device as the data so the data
@ -1636,8 +1682,13 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
composite_tensor.CompositeTensor):
"""Iterator for a DistributedDataset instance."""
def __init__(self,
dataset=None,
worker=None,
devices=None,
components=None,
element_spec=None,
options=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` . """Create iterator for the `dataset` to fetch data to worker's `devices` .
`OwnedMultiDeviceIterator` is used to prefetch input to the devices on the `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
@ -1653,6 +1704,8 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
_SingleWorkerOwnedDatasetIterator from.
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of elements of the iterator.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
""" """
if worker is None or devices is None: if worker is None or devices is None:
raise ValueError("Both `worker` and `devices` should be provided") raise ValueError("Both `worker` and `devices` should be provided")
@ -1660,6 +1713,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
error_message = ("Either `dataset` or both `components` and `element_spec` " error_message = ("Either `dataset` or both `components` and `element_spec` "
"need to be provided.") "need to be provided.")
self._options = options
if dataset is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
@ -1670,18 +1724,25 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
else:
if (components is not None or element_spec is not None):
raise ValueError(error_message)
super(_SingleWorkerOwnedDatasetIterator,
self).__init__(dataset, worker, devices, options)
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
if not self._worker:
raise ValueError("Worker device must be specified when creating an "
"owned iterator.")
if (self._options is None or self._options.experimental_replication_mode ==
InputReplicationMode.PER_WORKER or
(self._options.experimental_replication_mode == InputReplicationMode
.PER_REPLICA and self._options.experimental_prefetch_to_device)):
host_device = device_util.get_host_for_device(self._worker)
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset, self._devices, source_device=host_device)
else:
with ops.device(self._worker):
self._iterator = iter(self._dataset)
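In other words, the multi-device iterator is kept whenever data still flows through the host (PER_WORKER, or PER_REPLICA with prefetch enabled), and a plain dataset iterator is created directly on the worker device only for PER_REPLICA without prefetch. A hypothetical helper restating the branch (strings stand in for the enum; not TF API):

def choose_iterator_kind(replication_mode, prefetch_to_device):
  # Mirrors the branch above: host-fed multi-device iterator vs. a plain
  # per-replica iterator created on the worker device.
  if replication_mode == "PER_WORKER" or (
      replication_mode == "PER_REPLICA" and prefetch_to_device):
    return "OwnedMultiDeviceIterator"
  return "iter(dataset)"

assert choose_iterator_kind("PER_WORKER", True) == "OwnedMultiDeviceIterator"
assert choose_iterator_kind("PER_REPLICA", False) == "iter(dataset)"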
@property
def element_spec(self):
@ -1802,19 +1863,23 @@ class _SingleWorkerCallableIterator(object):
return []
def _create_iterators_per_worker(worker_datasets,
input_workers,
enable_legacy_iterators,
options=None):
"""Create a multidevice iterator on each of the workers.""" """Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers) assert isinstance(input_workers, InputWorkers)
assert len(worker_datasets) == len(input_workers.worker_devices) assert len(worker_datasets) == len(input_workers.worker_devices)
iterators = [] iterators = []
for i, worker in enumerate(input_workers.worker_devices): for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker): with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i) worker_devices = input_workers.compute_devices_for_worker(i)
if tf2.enabled() and not enable_legacy_iterators: if tf2.enabled() and not enable_legacy_iterators:
iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker, iterator = _SingleWorkerOwnedDatasetIterator(
worker_devices) dataset=worker_datasets[i],
worker=worker,
devices=worker_devices,
options=options)
else:
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
worker_devices)
@ -1822,8 +1887,9 @@ def _create_iterators_per_worker(worker_datasets, input_workers,
return iterators
def _create_datasets_from_function_with_input_context(input_contexts,
input_workers,
dataset_fn):
"""Create device datasets per worker given a dataset function.""" """Create device datasets per worker given a dataset function."""
datasets = [] datasets = []
for i, ctx in enumerate(input_contexts): for i, ctx in enumerate(input_contexts):


@ -1421,5 +1421,198 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
input_context=distribution.extended._make_input_context())
class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
parameterized.TestCase):
"""Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""
def setUp(self):
context._reset_context()
strategy_combinations.set_virtual_cpus_to_at_least(3)
super(DistributedIteratorPerDeviceTest, self).setUp()
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
input_options):
def dataset_fn(input_context): # pylint: disable=[unused-argument]
return dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
for x in ds:
assert x.values[0].device == distribution.extended.worker_devices[0]
assert x.values[0].backing_device == distribution.extended.worker_devices[
0]
assert x.values[1].device == distribution.extended.worker_devices[1]
assert x.values[1].backing_device == distribution.extended.worker_devices[
1]
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER)
],
mode=["eager"],
))
def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
self, distribution, input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.full(4, input_context.input_pipeline_id))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
for x in ds:
x = distribution.run(lambda inputs: inputs, args=(x,))
assert x.values[
0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA)
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testDevicePlacementForInvalidCombinations(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.full(4, input_context.input_pipeline_id))
with self.assertRaises(ValueError):
distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testOutputValuesForPerWorkerInputOptions(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.arange(1, 11).reshape(
(2, 5)) * (input_context.input_pipeline_id + 1))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
# validating the values
x = next(iter(ds))
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testOutputValuesForPerReplicaInputOptions(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.arange(1, 10) * (input_context.input_pipeline_id + 1))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
for i, x in enumerate(ds):
# validating the values
assert x.values[0].numpy() == expected[i]
assert x.values[1].numpy() == expected[i] * 2
loop_num = i
assert loop_num == len(expected) - 1
if __name__ == "__main__": if __name__ == "__main__":
test_util.main() test_util.main()


@ -341,6 +341,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._input_workers_devices = (
(device_util.canonicalize("/device:CPU:0", devices[0]), devices),)
self._inferred_cross_device_ops = None if self._cross_device_ops else (
cross_device_ops_lib.select_cross_device_ops(devices))
self._host_input_device = numpy_dataset.SingleDevice(
@ -396,12 +397,27 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
logging.info("Using MirroredStrategy with remote devices %r", devices) logging.info("Using MirroredStrategy with remote devices %r", devices)
def _input_workers_with_options(self, options=None): def _input_workers_with_options(self, options=None):
if not options or options.experimental_prefetch_to_device: if not options:
return input_lib.InputWorkers(self._input_workers_devices)
if (options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
if options.experimental_place_dataset_on_device:
self._input_workers_devices = (
tuple(
(device_util.canonicalize(d, d), (d,)) for d in self._devices))
else:
self._input_workers_devices = (
tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
for d in self._devices))
return input_lib.InputWorkers(self._input_workers_devices)
else:
if not options.experimental_prefetch_to_device:
return input_lib.InputWorkers([
(host_device, (host_device,) * len(compute_devices))
for host_device, compute_devices in self._input_workers_devices
])
else:
return input_lib.InputWorkers(self._input_workers_devices)
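A hedged illustration of the (input worker device, compute devices) pairs produced above for a hypothetical two-GPU MirroredStrategy; device strings are abbreviated and illustrative, not copied from the commit:

# PER_WORKER (default): one host-side pipeline feeds both replicas.
per_worker_pairs = [
    ("/device:CPU:0", ("/device:GPU:0", "/device:GPU:1")),
]
# PER_REPLICA with the dataset kept on the host: one pipeline per replica.
per_replica_on_host_pairs = [
    ("/device:CPU:0", ("/device:GPU:0",)),
    ("/device:CPU:0", ("/device:GPU:1",)),
]
# PER_REPLICA with experimental_place_dataset_on_device=True: each pipeline
# is created on its replica's own device.
per_replica_on_device_pairs = [
    ("/device:GPU:0", ("/device:GPU:0",)),
    ("/device:GPU:1", ("/device:GPU:1",)),
]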
@property
def _input_workers(self):
@ -499,6 +515,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._container_strategy())
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
return input_lib.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
@ -510,8 +533,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
numpy_input, self._host_input_device, session)
def _distribute_datasets_from_function(self, dataset_fn, options):
input_workers = self._input_workers_with_options(options)
input_contexts = []
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
@ -520,10 +543,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.get_distributed_datasets_from_function(
dataset_fn, input_workers, input_contexts, self._container_strategy(),
options)
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []


@ -312,12 +312,26 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_dataset(self, dataset, options):
# Note that split_batch_by argument is not passed because it is always 1 in
# this strategy, and adding it adds unnecessary overhead to the dataset.
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
return input_lib.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
self._container_strategy())
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers_with_options(options),


@ -119,12 +119,26 @@ class ParameterServerStrategy(distribute_lib.Strategy):
super(ParameterServerStrategy, self).__init__(extended)
def experimental_distribute_dataset(self, dataset, options=None):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
self._raise_pss_error_if_eager()
super(ParameterServerStrategy,
self).experimental_distribute_dataset(dataset=dataset,
options=options)
def distribute_datasets_from_function(self, dataset_fn, options=None):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
self._raise_pss_error_if_eager()
super(ParameterServerStrategy, self).distribute_datasets_from_function(
dataset_fn=dataset_fn, options=options)


@ -803,6 +803,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
"distribution function.".format(path, type(spec))) "distribution function.".format(path, type(spec)))
def _experimental_distribute_dataset(self, dataset, options): def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
if options is None or options.experimental_prefetch_to_device:
self._check_spec(dataset.element_spec)
@ -813,6 +820,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
num_replicas_in_sync=self._num_replicas_in_sync)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
" `experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_workers = self._get_input_workers(options)
input_contexts = []
num_workers = input_workers.num_workers


@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode" path: "tensorflow.distribute.InputReplicationMode"
tf_class { tf_class {
is_instance: "<enum \'InputReplicationMode\'>" is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_REPLICA"
mtype: "<enum \'InputReplicationMode\'>"
}
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"


@ -3,10 +3,18 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>" is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<type \'tuple\'>" is_instance: "<type \'tuple\'>"
member {
name: "experimental_place_dataset_on_device"
mtype: "<type \'property\'>"
}
member {
name: "experimental_prefetch_to_device"
mtype: "<type \'property\'>"
}
member {
name: "experimental_replication_mode"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}


@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode" path: "tensorflow.distribute.InputReplicationMode"
tf_class { tf_class {
is_instance: "<enum \'InputReplicationMode\'>" is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_REPLICA"
mtype: "<enum \'InputReplicationMode\'>"
}
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"