Merge pull request #38968 from kushanam:distribute_dali_ctl
PiperOrigin-RevId: 337869342 Change-Id: I3c34e90fe023dbefa8c66ac4331a251292ee547a
commit 8e8c010e95
@@ -102,6 +102,13 @@ class CentralStorageStrategy(distribute_lib.Strategy):
    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          'InputReplicationMode.PER_REPLICA '
          'is only supported in '
          '`experimental_distribute_datasets_from_function`.'
      )
    return super(CentralStorageStrategy, self).experimental_distribute_dataset(
        dataset, options)
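A minimal usage sketch of the guard above (assuming a local `tf.distribute.MirroredStrategy`; this snippet is illustrative and not part of the diff): the dataset-based entry point rejects `PER_REPLICA`, while the function-based one accepts it.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
per_replica_options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

dataset = tf.data.Dataset.range(8).batch(2)

try:
  # The dataset-based entry point rejects PER_REPLICA replication.
  strategy.experimental_distribute_dataset(dataset, per_replica_options)
except NotImplementedError as err:
  print(err)

# The function-based entry point accepts it and builds one pipeline per replica.
dist_ds = strategy.experimental_distribute_datasets_from_function(
    lambda ctx: tf.data.Dataset.range(8).batch(2), per_replica_options)
```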
@@ -469,6 +469,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
    return input_context

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    input_context = self._make_input_context()
    return input_lib.get_distributed_dataset(
        dataset,
@@ -478,6 +485,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
        input_context=input_context)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_context = self._make_input_context()
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn=dataset_fn,
@@ -439,8 +439,12 @@ class InputReplicationMode(enum.Enum):
    Replicas will dequeue from the local Dataset on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  * `PER_REPLICA`: The input function will be called on each replica separately.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  """
  PER_WORKER = "PER_WORKER"
  PER_REPLICA = "PER_REPLICA"


@tf_export("distribute.InputContext")
@@ -616,6 +620,8 @@ class RunOptions(
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_prefetch_to_device",
        "experimental_replication_mode",
        "experimental_place_dataset_on_device",
    ])):
  """Run options for `experimental_distribute_dataset(s_from_function)`.

@@ -633,19 +639,36 @@ class InputOptions(
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_prefetch_to_device=False)))
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False)))
  ```

  Attributes:
    experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
      elements will be prefetched to accelerator device memory. When False,
      dataset elements are prefetched to host device memory. Must be False when
      using TPUEmbedding API.
      using TPUEmbedding API. experimental_prefetch_to_device can only be used
      with experimental_replication_mode=PER_WORKER.
    experimental_replication_mode: Replication mode for the input function.
      Currently, InputReplicationMode.PER_REPLICA is only supported with
      tf.distribute.MirroredStrategy's
      experimental_distribute_datasets_from_function.
      The default value is InputReplicationMode.PER_WORKER.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
      dataset will be placed on the device, otherwise it will remain on the
      host. experimental_place_dataset_on_device=True can only be used with
      experimental_replication_mode=PER_REPLICA.
  """

  def __new__(cls, experimental_prefetch_to_device=True):
    return super(InputOptions, cls).__new__(cls,
                                            experimental_prefetch_to_device)
  def __new__(cls,
              experimental_prefetch_to_device=True,
              experimental_replication_mode=InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False):
    return super(InputOptions,
                 cls).__new__(cls, experimental_prefetch_to_device,
                              experimental_replication_mode,
                              experimental_place_dataset_on_device)

# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
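A hedged end-to-end sketch of the extended `InputOptions` introduced above (the strategy, dataset, and batch sizes are illustrative assumptions, not part of the diff):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
  # With PER_REPLICA replication this function runs once per replica, so each
  # replica gets its own, independently managed input pipeline.
  batch_size = input_context.get_per_replica_batch_size(64)
  return tf.data.Dataset.range(256).batch(batch_size)

options = tf.distribute.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,
    experimental_place_dataset_on_device=False)

dist_dataset = strategy.experimental_distribute_datasets_from_function(
    dataset_fn, options)
for batch in dist_dataset:
  strategy.run(lambda x: x + 1, args=(batch,))
```

As the attribute docs note, `experimental_place_dataset_on_device=True` additionally requires `PER_REPLICA` replication and prefetching disabled.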
@@ -35,6 +35,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
@@ -108,7 +109,8 @@ def get_distributed_dataset(dataset,
def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy):
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
@@ -126,22 +128,43 @@ def get_distributed_datasets_from_function(dataset_fn,
      `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_replication_mode` and
      `options.experimental_place_dataset_on_device` are not consistent
  """
  if (options is not None and
      options.experimental_replication_mode != InputReplicationMode.PER_REPLICA
      and options.experimental_place_dataset_on_device):
    raise ValueError(
        "When `experimental_place_dataset_on_device` is set for dataset "
        "placement, you must also specify `PER_REPLICA` for the "
        "replication mode")

  if (options is not None and
      options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
      and options.experimental_prefetch_to_device and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "`experimental_place_dataset_on_device` can not be set to True "
        "when experimental_prefetch_to_device is True and "
        "replication mode is set to `PER_REPLICA`")

  if tf2.enabled():
    return DistributedDatasetsFromFunction(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)
    return DistributedDatasetsFromFunction(dataset_fn, input_workers,
                                           input_contexts, strategy, options)
  else:
    return DistributedDatasetsFromFunctionV1(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy)
        strategy,
        options)


@tf_export("distribute.DistributedIterator", v1=[])
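The two consistency checks above can be restated as a small standalone helper; this is a hypothetical sketch (the name `check_input_options` and its standalone form are not part of TensorFlow):

```python
import tensorflow as tf

def check_input_options(options):
  """Hypothetical mirror of the validation in get_distributed_datasets_from_function."""
  per_replica = (options.experimental_replication_mode ==
                 tf.distribute.InputReplicationMode.PER_REPLICA)
  if options.experimental_place_dataset_on_device and not per_replica:
    # Placing the dataset on the replica device only makes sense when each
    # replica owns its own pipeline.
    raise ValueError("experimental_place_dataset_on_device requires "
                     "experimental_replication_mode=PER_REPLICA.")
  if (per_replica and options.experimental_place_dataset_on_device and
      options.experimental_prefetch_to_device):
    # Prefetching to the device and placing the dataset on the device are
    # mutually exclusive.
    raise ValueError("experimental_place_dataset_on_device=True cannot be "
                     "combined with experimental_prefetch_to_device=True.")
```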
@@ -1188,7 +1211,8 @@ class DistributedDatasetV1(DistributedDataset):
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
  def __init__(self, dataset_fn, input_workers, input_contexts, strategy,
               options):
    """Makes an iterable from datasets created by the given function.

    Args:
@@ -1199,6 +1223,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)
@@ -1212,10 +1238,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    self._options = options
    self._datasets, element_spec = (
        _create_datasets_per_worker_with_input_context(self._input_contexts,
                                                       self._input_workers,
                                                       dataset_fn))
        _create_datasets_from_function_with_input_context(
            self._input_contexts, self._input_workers, dataset_fn))
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, element_spec)
    # When partial batch handling is enabled, always set the batch dimension to
@@ -1239,11 +1265,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
    # out this change.
    enable_legacy_iterators = getattr(self._strategy,
                                      "_enable_legacy_iterators", False)

    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers,
                                             enable_legacy_iterators)

                                             enable_legacy_iterators,
                                             self._options)
    if enable_legacy_iterators:
      iterator = DistributedIteratorV1(
          self._input_workers,
@@ -1252,9 +1277,9 @@ class DistributedDatasetsFromFunction(_IterableInput):
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    else:
      iterator = DistributedIterator(
          self._input_workers,
          iterators,
          self._strategy,
          input_workers=self._input_workers,
          iterators=iterators,
          strategy=self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
@@ -1495,7 +1520,7 @@ def _recover_shape_fn(data, value_structure):
class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
  def __init__(self, dataset, worker, devices, options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
@@ -1505,16 +1530,36 @@ class _SingleWorkerDatasetIteratorBase(object):
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._options = options
    self._make_iterator()

  def _make_iterator(self):
    raise NotImplementedError("must be implemented in descendants")

  def _format_data_list_with_options(self, data_list):
    """Change the data into a list type if required.

    The OwnedMultiDeviceIterator returns the list data type,
    while the PER_REPLICA iterator (when used with prefetch disabled)
    returns without the enclosed list. This is to fix the inconsistency.
    Args:
      data_list: data returned by the underlying iterator's get_next call.
    Returns:
      list
    """
    if (self._options and self._options.experimental_replication_mode ==
        InputReplicationMode.PER_REPLICA and
        not self._options.experimental_prefetch_to_device):
      return [data_list]
    else:
      return data_list

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
@@ -1536,7 +1581,7 @@ class _SingleWorkerDatasetIteratorBase(object):
    """
    del name
    with ops.device(self._worker):
      return self._iterator.get_next()
      return self._format_data_list_with_options(self._iterator.get_next())

  def get_next_as_list(self, name=None):
    """Get next element from underlying iterator.
@@ -1556,7 +1601,8 @@ class _SingleWorkerDatasetIteratorBase(object):
    """
    del name
    with ops.device(self._worker):
      data_list = self._iterator.get_next_as_optional()
      data_list = self._format_data_list_with_options(
          self._iterator.get_next_as_optional())
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op in the same device as the data so the data
@@ -1636,8 +1682,13 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self, dataset=None, worker=None, devices=None, components=None,
               element_spec=None):
  def __init__(self,
               dataset=None,
               worker=None,
               devices=None,
               components=None,
               element_spec=None,
               options=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
@@ -1653,6 +1704,8 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents the
        type specification of elements of the iterator.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")
@@ -1660,6 +1713,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    self._options = options
    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
@@ -1670,18 +1724,25 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
                                                              devices)
      super(_SingleWorkerOwnedDatasetIterator,
            self).__init__(dataset, worker, devices, options)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    host_device = device_util.get_host_for_device(self._worker)
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          self._dataset, self._devices, source_device=host_device)
    if (self._options is None or self._options.experimental_replication_mode ==
        InputReplicationMode.PER_WORKER or
        (self._options.experimental_replication_mode == InputReplicationMode
         .PER_REPLICA and self._options.experimental_prefetch_to_device)):
      host_device = device_util.get_host_for_device(self._worker)
      with ops.device(self._worker):
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset, self._devices, source_device=host_device)
    else:
      with ops.device(self._worker):
        self._iterator = iter(self._dataset)

  @property
  def element_spec(self):
@@ -1802,19 +1863,23 @@ class _SingleWorkerCallableIterator(object):
    return []


def _create_iterators_per_worker(worker_datasets, input_workers,
                                 enable_legacy_iterators):
def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 enable_legacy_iterators,
                                 options=None):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)

  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker,
                                                     worker_devices)
        iterator = _SingleWorkerOwnedDatasetIterator(
            dataset=worker_datasets[i],
            worker=worker,
            devices=worker_devices,
            options=options)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                                worker_devices)
@@ -1822,8 +1887,9 @@ def _create_iterators_per_worker(worker_datasets, input_workers,
  return iterators


def _create_datasets_per_worker_with_input_context(input_contexts,
                                                    input_workers, dataset_fn):
def _create_datasets_from_function_with_input_context(input_contexts,
                                                       input_workers,
                                                       dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts):
@@ -1421,5 +1421,198 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
        input_context=distribution.extended._make_input_context())


class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
                                       parameterized.TestCase):
  """Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""

  def setUp(self):
    context._reset_context()
    strategy_combinations.set_virtual_cpus_to_at_least(3)
    super(DistributedIteratorPerDeviceTest, self).setUp()

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
                                                        input_options):

    def dataset_fn(input_context):  # pylint: disable=[unused-argument]
      return dataset_ops.Dataset.from_tensor_slices([1, 2, 3])

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      assert x.values[0].device == distribution.extended.worker_devices[0]
      assert x.values[0].backing_device == distribution.extended.worker_devices[
          0]
      assert x.values[1].device == distribution.extended.worker_devices[1]
      assert x.values[1].backing_device == distribution.extended.worker_devices[
          1]

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER)
          ],
          mode=["eager"],
      ))
  def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
      self, distribution, input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    for x in ds:
      x = distribution.run(lambda inputs: inputs, args=(x,))
      assert x.values[
          0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
      assert x.values[
          1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA)
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testDevicePlacementForInvalidCombinations(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.full(4, input_context.input_pipeline_id))

    with self.assertRaises(ValueError):
      distribution.experimental_distribute_datasets_from_function(
          dataset_fn, input_options)

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_WORKER),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerWorkerInputOptions(self, distribution,
                                               input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 11).reshape(
              (2, 5)) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)

    # validating the values
    x = next(iter(ds))
    assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
    assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))

  @combinations.generate(
      combinations.combine(
          input_options=[
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=True,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=False,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
              distribute_lib.InputOptions(
                  experimental_place_dataset_on_device=False,
                  experimental_prefetch_to_device=True,
                  experimental_replication_mode=distribute_lib
                  .InputReplicationMode.PER_REPLICA),
          ],
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_two_gpus,
              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ]))
  def testOutputValuesForPerReplicaInputOptions(self, distribution,
                                                input_options):

    def dataset_fn(input_context):
      return dataset_ops.Dataset.from_tensor_slices(
          np.arange(1, 10) * (input_context.input_pipeline_id + 1))

    ds = distribution.experimental_distribute_datasets_from_function(
        dataset_fn, input_options)
    expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
    for i, x in enumerate(ds):
      # validating the values
      assert x.values[0].numpy() == expected[i]
      assert x.values[1].numpy() == expected[i] * 2
      loop_num = i
    assert loop_num == len(expected) - 1


if __name__ == "__main__":
  test_util.main()
@@ -341,6 +341,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
@@ -396,12 +397,27 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_prefetch_to_device:
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      return input_lib.InputWorkers(
          [(host_device, (host_device,) * len(compute_devices)) for
           host_device, compute_devices in self._input_workers_devices])
      if not options.experimental_prefetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
@@ -499,6 +515,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
        self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
@@ -510,8 +533,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_contexts = []
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
@@ -520,10 +543,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        input_workers,
        input_contexts,
        self._container_strategy())
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
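To illustrate the worker/device pairings produced by `_input_workers_with_options` above, a hedged sketch (device strings depend on the host configuration; with `PER_REPLICA` and prefetching disabled, per-replica components stay on the host CPU of their input worker):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # e.g. ["/gpu:0", "/gpu:1"]
options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA,
    experimental_prefetch_to_device=False,
    experimental_place_dataset_on_device=False)

dist_ds = strategy.experimental_distribute_datasets_from_function(
    lambda ctx: tf.data.Dataset.from_tensor_slices([1, 2, 3]).batch(1), options)

for batch in dist_ds:
  # Each replica has its own component; inspect where it was produced.
  for value in batch.values:
    print(value.device, value.backing_device)
  break
```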
@@ -312,12 +312,26 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  def _experimental_distribute_dataset(self, dataset, options):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy())

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
@@ -119,12 +119,26 @@ class ParameterServerStrategy(distribute_lib.Strategy):
    super(ParameterServerStrategy, self).__init__(extended)

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    self._raise_pss_error_if_eager()
    super(ParameterServerStrategy,
          self).experimental_distribute_dataset(dataset=dataset,
                                                options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    super(ParameterServerStrategy, self).distribute_datasets_from_function(
        dataset_fn=dataset_fn, options=options)
@@ -803,6 +803,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
          "distribution function.".format(path, type(spec)))

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    if options is None or options.experimental_prefetch_to_device:
      self._check_spec(dataset.element_spec)

@@ -813,6 +820,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    input_workers = self._get_input_workers(options)
    input_contexts = []
    num_workers = input_workers.num_workers
@@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
  is_instance: "<enum \'InputReplicationMode\'>"
  member {
    name: "PER_REPLICA"
    mtype: "<enum \'InputReplicationMode\'>"
  }
  member {
    name: "PER_WORKER"
    mtype: "<enum \'InputReplicationMode\'>"
@@ -3,10 +3,18 @@ tf_class {
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
  is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
  is_instance: "<type \'tuple\'>"
  member {
    name: "experimental_place_dataset_on_device"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_prefetch_to_device"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_replication_mode"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
  }
@@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
  is_instance: "<enum \'InputReplicationMode\'>"
  member {
    name: "PER_REPLICA"
    mtype: "<enum \'InputReplicationMode\'>"
  }
  member {
    name: "PER_WORKER"
    mtype: "<enum \'InputReplicationMode\'>"