Merge pull request #38968 from kushanam:distribute_dali_ctl

PiperOrigin-RevId: 337869342
Change-Id: I3c34e90fe023dbefa8c66ac4331a251292ee547a
TensorFlower Gardener 2020-10-19 09:25:22 -07:00
commit 8e8c010e95
12 changed files with 431 additions and 49 deletions


@ -102,6 +102,13 @@ class CentralStorageStrategy(distribute_lib.Strategy):
Returns:
A "distributed `Dataset`" that the caller can iterate over.
"""
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
'InputReplicationMode.PER_REPLICA '
'is only supported in '
'`experimental_distribute_datasets_from_function`.'
)
return super(CentralStorageStrategy, self).experimental_distribute_dataset(
dataset, options)
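A minimal sketch (not part of this commit) of how the new check surfaces to callers, assuming a TF build that includes this change; the strategy and dataset below are illustrative:

```python
import tensorflow as tf

# PER_REPLICA input is rejected by `experimental_distribute_dataset`; only the
# datasets-from-function path accepts it.
strategy = tf.distribute.experimental.CentralStorageStrategy()
options = tf.distribute.InputOptions(
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)
dataset = tf.data.Dataset.range(8).batch(2)

try:
  strategy.experimental_distribute_dataset(dataset, options)
except NotImplementedError as e:
  print(e)  # Directs callers to `experimental_distribute_datasets_from_function`.
```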


@ -469,6 +469,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
return input_context
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
input_context = self._make_input_context()
return input_lib.get_distributed_dataset(
dataset,
@ -478,6 +485,13 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
input_context=input_context)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_context = self._make_input_context()
return input_lib.get_distributed_datasets_from_function(
dataset_fn=dataset_fn,


@ -439,8 +439,12 @@ class InputReplicationMode(enum.Enum):
Replicas will dequeue from the local Dataset on their worker.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
* `PER_REPLICA`: The input function will be called on each replica separately.
`tf.distribute.Strategy` doesn't manage any state sharing between such
separate input pipelines.
"""
PER_WORKER = "PER_WORKER"
PER_REPLICA = "PER_REPLICA"
@tf_export("distribute.InputContext")
@ -616,6 +620,8 @@ class RunOptions(
class InputOptions(
collections.namedtuple("InputOptions", [
"experimental_prefetch_to_device",
"experimental_replication_mode",
"experimental_place_dataset_on_device",
])):
"""Run options for `experimental_distribute_dataset(s_from_function)`.
@ -633,19 +639,36 @@ class InputOptions(
strategy.experimental_distribute_dataset(
dataset,
tf.distribute.InputOptions(
experimental_prefetch_to_device=False)))
experimental_replication_mode=
tf.distribute.InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False)))
```
Attributes:
experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset
elements will be prefetched to accelerator device memory. When False,
dataset elements are prefetched to host device memory. Must be False when
using TPUEmbedding API.
using TPUEmbedding API. experimental_prefetch_to_device can only be used
with experimental_replication_mode=PER_WORKER.
experimental_replication_mode: Replication mode for the input function.
Currently, InputReplicationMode.PER_REPLICA is only supported with
tf.distribute.MirroredStrategy's
experimental_distribute_datasets_from_function.
The default value is InputReplicationMode.PER_WORKER.
experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
the dataset will be placed on the device, otherwise it will remain on the
host. experimental_place_dataset_on_device=True can only be used with
experimental_replication_mode=PER_REPLICA.
"""
def __new__(cls, experimental_prefetch_to_device=True):
return super(InputOptions, cls).__new__(cls,
experimental_prefetch_to_device)
def __new__(cls,
experimental_prefetch_to_device=True,
experimental_replication_mode=InputReplicationMode.PER_WORKER,
experimental_place_dataset_on_device=False):
return super(InputOptions,
cls).__new__(cls, experimental_prefetch_to_device,
experimental_replication_mode,
experimental_place_dataset_on_device)
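A short sketch of the defaults the expanded constructor provides; the field names come from this change, and the printed values are the documented defaults:

```python
import tensorflow as tf

opts = tf.distribute.InputOptions()
assert opts.experimental_prefetch_to_device is True
assert opts.experimental_replication_mode == (
    tf.distribute.InputReplicationMode.PER_WORKER)
assert opts.experimental_place_dataset_on_device is False
```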
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.


@ -35,6 +35,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
@ -108,7 +109,8 @@ def get_distributed_dataset(dataset,
def get_distributed_datasets_from_function(dataset_fn,
input_workers,
input_contexts,
strategy):
strategy,
options=None):
"""Returns a distributed dataset from the given input function.
This is a common function that is used by all strategies to return a
@ -126,22 +128,43 @@ def get_distributed_datasets_from_function(dataset_fn,
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed. Defaults to None.
Returns:
A distributed dataset instance.
Raises:
ValueError: if `options.experimental_replication_mode` and
`options.experimental_place_dataset_on_device` are not consistent.
"""
if (options is not None and
options.experimental_replication_mode != InputReplicationMode.PER_REPLICA
and options.experimental_place_dataset_on_device):
raise ValueError(
"When `experimental_place_dataset_on_device` is set for dataset "
"placement, you must also specify `PER_REPLICA` for the "
"replication mode")
if (options is not None and
options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
and options.experimental_prefetch_to_device and
options.experimental_place_dataset_on_device):
raise ValueError(
"`experimental_place_dataset_on_device` cannot be set to True "
"when experimental_prefetch_to_device is True and "
"replication mode is set to `PER_REPLICA`")
if tf2.enabled():
return DistributedDatasetsFromFunction(
dataset_fn,
input_workers,
input_contexts,
strategy)
return DistributedDatasetsFromFunction(dataset_fn, input_workers,
input_contexts, strategy, options)
else:
return DistributedDatasetsFromFunctionV1(
dataset_fn,
input_workers,
input_contexts,
strategy)
strategy,
options)
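For illustration, a hedged sketch of the two option combinations this validation rejects, mirroring the new tests further down; `MirroredStrategy` and the dataset function are assumptions, not part of this file:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
  return tf.data.Dataset.range(8)

# On-device placement without PER_REPLICA replication.
on_device_per_worker = tf.distribute.InputOptions(
    experimental_place_dataset_on_device=True,
    experimental_prefetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_WORKER)

# PER_REPLICA with both prefetch-to-device and on-device placement.
on_device_with_prefetch = tf.distribute.InputOptions(
    experimental_place_dataset_on_device=True,
    experimental_prefetch_to_device=True,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

for opts in (on_device_per_worker, on_device_with_prefetch):
  try:
    strategy.experimental_distribute_datasets_from_function(dataset_fn, opts)
  except ValueError as e:
    print(e)
```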
@tf_export("distribute.DistributedIterator", v1=[])
@ -1188,7 +1211,8 @@ class DistributedDatasetV1(DistributedDataset):
class DistributedDatasetsFromFunction(_IterableInput):
"""Inputs created from dataset function."""
def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
def __init__(self, dataset_fn, input_workers, input_contexts, strategy,
options):
"""Makes an iterable from datasets created by the given function.
Args:
@ -1199,6 +1223,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
`worker_device_pairs`.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
"""
super(DistributedDatasetsFromFunction, self).__init__(
input_workers=input_workers)
@ -1212,10 +1238,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
self._input_workers = input_workers
self._input_contexts = input_contexts
self._strategy = strategy
self._options = options
self._datasets, element_spec = (
_create_datasets_per_worker_with_input_context(self._input_contexts,
self._input_workers,
dataset_fn))
_create_datasets_from_function_with_input_context(
self._input_contexts, self._input_workers, dataset_fn))
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, element_spec)
# When partial batch handling is enabled, always set the batch dimension to
@ -1239,11 +1265,10 @@ class DistributedDatasetsFromFunction(_IterableInput):
# out this change.
enable_legacy_iterators = getattr(self._strategy,
"_enable_legacy_iterators", False)
iterators = _create_iterators_per_worker(self._datasets,
self._input_workers,
enable_legacy_iterators)
enable_legacy_iterators,
self._options)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(
self._input_workers,
@ -1252,9 +1277,9 @@ class DistributedDatasetsFromFunction(_IterableInput):
enable_get_next_as_optional=self._enable_get_next_as_optional)
else:
iterator = DistributedIterator(
self._input_workers,
iterators,
self._strategy,
input_workers=self._input_workers,
iterators=iterators,
strategy=self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
@ -1495,7 +1520,7 @@ def _recover_shape_fn(data, value_structure):
class _SingleWorkerDatasetIteratorBase(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices):
def __init__(self, dataset, worker, devices, options=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
@ -1505,16 +1530,36 @@ class _SingleWorkerDatasetIteratorBase(object):
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
"""
self._dataset = dataset
self._worker = worker
self._devices = devices
self._element_spec = dataset.element_spec
self._options = options
self._make_iterator()
def _make_iterator(self):
raise NotImplementedError("must be implemented in descendants")
def _format_data_list_with_options(self, data_list):
"""Change the data in to a list type if required.
The OwnedMultiDeviceIterator returns the list data type,
while the PER_REPLICA iterator (when used with prefetch disabled)
returns without the enclosed list. This is to fix the inconsistency.
Args:
data_list: The data returned by the underlying iterator.
Returns:
The data wrapped in a list when required, otherwise unchanged.
"""
if (self._options and self._options.experimental_replication_mode ==
InputReplicationMode.PER_REPLICA and
not self._options.experimental_prefetch_to_device):
return [data_list]
else:
return data_list
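A small illustration (option values assumed) of the normalization this helper performs:

```python
import tensorflow as tf

# Only this combination makes the underlying iterator yield a bare element,
# which the helper wraps as a single-item list; every other combination
# already yields a per-device list that passes through unchanged.
per_replica_no_prefetch = tf.distribute.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)
# element        -> [element]        with the options above
# [d0, d1, ...]  -> [d0, d1, ...]    otherwise
```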
def get_next(self, device, name=None):
"""Get next element for the given device."""
del name
@ -1536,7 +1581,7 @@ class _SingleWorkerDatasetIteratorBase(object):
"""
del name
with ops.device(self._worker):
return self._iterator.get_next()
return self._format_data_list_with_options(self._iterator.get_next())
def get_next_as_list(self, name=None):
"""Get next element from underlying iterator.
@ -1556,7 +1601,8 @@ class _SingleWorkerDatasetIteratorBase(object):
"""
del name
with ops.device(self._worker):
data_list = self._iterator.get_next_as_optional()
data_list = self._format_data_list_with_options(
self._iterator.get_next_as_optional())
result = []
for i, data in enumerate(data_list):
# Place the condition op in the same device as the data so the data
@ -1636,8 +1682,13 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
composite_tensor.CompositeTensor):
"""Iterator for a DistributedDataset instance."""
def __init__(self, dataset=None, worker=None, devices=None, components=None,
element_spec=None):
def __init__(self,
dataset=None,
worker=None,
devices=None,
components=None,
element_spec=None,
options=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
@ -1653,6 +1704,8 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
_SingleWorkerOwnedDatasetIterator from.
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of elements of the iterator.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
"""
if worker is None or devices is None:
raise ValueError("Both `worker` and `devices` should be provided")
@ -1660,6 +1713,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
error_message = ("Either `dataset` or both `components` and `element_spec` "
"need to be provided.")
self._options = options
if dataset is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
@ -1670,18 +1724,25 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
else:
if (components is not None or element_spec is not None):
raise ValueError(error_message)
super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
devices)
super(_SingleWorkerOwnedDatasetIterator,
self).__init__(dataset, worker, devices, options)
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
if not self._worker:
raise ValueError("Worked device must be specified when creating an "
"owned iterator.")
host_device = device_util.get_host_for_device(self._worker)
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset, self._devices, source_device=host_device)
if (self._options is None or self._options.experimental_replication_mode ==
InputReplicationMode.PER_WORKER or
(self._options.experimental_replication_mode == InputReplicationMode
.PER_REPLICA and self._options.experimental_prefetch_to_device)):
host_device = device_util.get_host_for_device(self._worker)
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset, self._devices, source_device=host_device)
else:
with ops.device(self._worker):
self._iterator = iter(self._dataset)
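A condensed restatement (not in the diff) of the iterator choice above, written as a standalone predicate for readability; the import mirrors the one added at the top of this file:

```python
from tensorflow.python.distribute.distribute_lib import InputReplicationMode


def _uses_multi_device_iterator(options):
  """True when the prefetching OwnedMultiDeviceIterator is created, i.e. for
  PER_WORKER input, or for PER_REPLICA input with prefetch-to-device enabled;
  otherwise the dataset is iterated in place on the worker."""
  if options is None:
    return True
  mode = options.experimental_replication_mode
  return (mode == InputReplicationMode.PER_WORKER or
          (mode == InputReplicationMode.PER_REPLICA and
           options.experimental_prefetch_to_device))
```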
@property
def element_spec(self):
@ -1802,19 +1863,23 @@ class _SingleWorkerCallableIterator(object):
return []
def _create_iterators_per_worker(worker_datasets, input_workers,
enable_legacy_iterators):
def _create_iterators_per_worker(worker_datasets,
input_workers,
enable_legacy_iterators,
options=None):
"""Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers)
assert len(worker_datasets) == len(input_workers.worker_devices)
iterators = []
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
if tf2.enabled() and not enable_legacy_iterators:
iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker,
worker_devices)
iterator = _SingleWorkerOwnedDatasetIterator(
dataset=worker_datasets[i],
worker=worker,
devices=worker_devices,
options=options)
else:
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
worker_devices)
@ -1822,8 +1887,9 @@ def _create_iterators_per_worker(worker_datasets, input_workers,
return iterators
def _create_datasets_per_worker_with_input_context(input_contexts,
input_workers, dataset_fn):
def _create_datasets_from_function_with_input_context(input_contexts,
input_workers,
dataset_fn):
"""Create device datasets per worker given a dataset function."""
datasets = []
for i, ctx in enumerate(input_contexts):


@ -1421,5 +1421,198 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
input_context=distribution.extended._make_input_context())
class DistributedIteratorPerDeviceTest(DistributedIteratorTestBase,
parameterized.TestCase):
"""Tests for PER_WORKER and PER_REPLICA's InputOptions variants."""
def setUp(self):
context._reset_context()
strategy_combinations.set_virtual_cpus_to_at_least(3)
super(DistributedIteratorPerDeviceTest, self).setUp()
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testDevicePlacementForPerWorkerValuesWithPrefetch(self, distribution,
input_options):
def dataset_fn(input_context):  # pylint: disable=unused-argument
return dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
for x in ds:
assert x.values[0].device == distribution.extended.worker_devices[0]
assert x.values[0].backing_device == distribution.extended.worker_devices[
0]
assert x.values[1].device == distribution.extended.worker_devices[1]
assert x.values[1].backing_device == distribution.extended.worker_devices[
1]
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER)
],
mode=["eager"],
))
def testDevicePlacementForPerWorkerValuesWithoutPrefetch(
self, distribution, input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.full(4, input_context.input_pipeline_id))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
for x in ds:
x = distribution.run(lambda inputs: inputs, args=(x,))
assert x.values[
0].device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
0].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
1].device == "/job:localhost/replica:0/task:0/device:CPU:0"
assert x.values[
1].backing_device == "/job:localhost/replica:0/task:0/device:CPU:0"
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA)
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testDevicePlacementForInvalidCombinations(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.full(4, input_context.input_pipeline_id))
with self.assertRaises(ValueError):
distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_WORKER),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testOutputValuesForPerWorkerInputOptions(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.arange(1, 11).reshape(
(2, 5)) * (input_context.input_pipeline_id + 1))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
# Validate the values.
x = next(iter(ds))
assert np.array_equal(x.values[0].numpy(), np.array([1, 2, 3, 4, 5]))
assert np.array_equal(x.values[1].numpy(), np.array([6, 7, 8, 9, 10]))
@combinations.generate(
combinations.combine(
input_options=[
distribute_lib.InputOptions(
experimental_place_dataset_on_device=True,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=False,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
distribute_lib.InputOptions(
experimental_place_dataset_on_device=False,
experimental_prefetch_to_device=True,
experimental_replication_mode=distribute_lib
.InputReplicationMode.PER_REPLICA),
],
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_two_gpus,
strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
]))
def testOutputValuesForPerReplicaInputOptions(self, distribution,
input_options):
def dataset_fn(input_context):
return dataset_ops.Dataset.from_tensor_slices(
np.arange(1, 10) * (input_context.input_pipeline_id + 1))
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn, input_options)
expected = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
for i, x in enumerate(ds):
# Validate the values.
assert x.values[0].numpy() == expected[i]
assert x.values[1].numpy() == expected[i] * 2
loop_num = i
assert loop_num == len(expected) - 1
if __name__ == "__main__":
test_util.main()


@ -341,6 +341,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._input_workers_devices = (
(device_util.canonicalize("/device:CPU:0", devices[0]), devices),)
self._inferred_cross_device_ops = None if self._cross_device_ops else (
cross_device_ops_lib.select_cross_device_ops(devices))
self._host_input_device = numpy_dataset.SingleDevice(
@ -396,12 +397,27 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
logging.info("Using MirroredStrategy with remote devices %r", devices)
def _input_workers_with_options(self, options=None):
if not options or options.experimental_prefetch_to_device:
if not options:
return input_lib.InputWorkers(self._input_workers_devices)
if (options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
if options.experimental_place_dataset_on_device:
self._input_workers_devices = (
tuple(
(device_util.canonicalize(d, d), (d,)) for d in self._devices))
else:
self._input_workers_devices = (
tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
for d in self._devices))
return input_lib.InputWorkers(self._input_workers_devices)
else:
return input_lib.InputWorkers(
[(host_device, (host_device,) * len(compute_devices)) for
host_device, compute_devices in self._input_workers_devices])
if not options.experimental_prefetch_to_device:
return input_lib.InputWorkers([
(host_device, (host_device,) * len(compute_devices))
for host_device, compute_devices in self._input_workers_devices
])
else:
return input_lib.InputWorkers(self._input_workers_devices)
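To make the branches above easier to scan, a hedged sketch of the worker/device pairs produced for the two new PER_REPLICA cases, assuming a local two-GPU setup (device strings abbreviated; in the real code they are canonicalized full device names):

```python
devices = ("/GPU:0", "/GPU:1")

# experimental_place_dataset_on_device=True: each replica device hosts and
# consumes its own dataset.
on_device_pairs = tuple((d, (d,)) for d in devices)
# -> (("/GPU:0", ("/GPU:0",)), ("/GPU:1", ("/GPU:1",)))

# experimental_place_dataset_on_device=False: the dataset stays on the host
# CPU, with one single-device pipeline per replica.
on_host_pairs = tuple(("/CPU:0", (d,)) for d in devices)
# -> (("/CPU:0", ("/GPU:0",)), ("/CPU:0", ("/GPU:1",)))
```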
@property
def _input_workers(self):
@ -499,6 +515,13 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._container_strategy())
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
return input_lib.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
@ -510,8 +533,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
numpy_input, self._host_input_device, session)
def _distribute_datasets_from_function(self, dataset_fn, options):
input_contexts = []
input_workers = self._input_workers_with_options(options)
input_contexts = []
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
@ -520,10 +543,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
input_workers,
input_contexts,
self._container_strategy())
dataset_fn, input_workers, input_contexts, self._container_strategy(),
options)
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []


@ -312,12 +312,26 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
def _experimental_distribute_dataset(self, dataset, options):
# Note that split_batch_by argument is not passed because it is always 1 in
# this strategy, and adding it adds unnecessary overhead to the dataset.
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
return input_lib.get_distributed_dataset(
dataset,
self._input_workers_with_options(options),
self._container_strategy())
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers_with_options(options),


@ -119,12 +119,26 @@ class ParameterServerStrategy(distribute_lib.Strategy):
super(ParameterServerStrategy, self).__init__(extended)
def experimental_distribute_dataset(self, dataset, options=None):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
self._raise_pss_error_if_eager()
super(ParameterServerStrategy,
self).experimental_distribute_dataset(dataset=dataset,
options=options)
def distribute_datasets_from_function(self, dataset_fn, options=None):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
self._raise_pss_error_if_eager()
super(ParameterServerStrategy, self).distribute_datasets_from_function(
dataset_fn=dataset_fn, options=options)


@ -803,6 +803,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
"distribution function.".format(path, type(spec)))
def _experimental_distribute_dataset(self, dataset, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function`."
)
if options is None or options.experimental_prefetch_to_device:
self._check_spec(dataset.element_spec)
@ -813,6 +820,13 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
num_replicas_in_sync=self._num_replicas_in_sync)
def _distribute_datasets_from_function(self, dataset_fn, options):
if (options and options.experimental_replication_mode ==
distribute_lib.InputReplicationMode.PER_REPLICA):
raise NotImplementedError(
"InputReplicationMode.PER_REPLICA "
"is only supported in "
"`experimental_distribute_datasets_from_function` "
"of tf.distribute.MirroredStrategy")
input_workers = self._get_input_workers(options)
input_contexts = []
num_workers = input_workers.num_workers


@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_REPLICA"
mtype: "<enum \'InputReplicationMode\'>"
}
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"


@ -3,10 +3,18 @@ tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "experimental_place_dataset_on_device"
mtype: "<type \'property\'>"
}
member {
name: "experimental_prefetch_to_device"
mtype: "<type \'property\'>"
}
member {
name: "experimental_replication_mode"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}


@ -1,6 +1,10 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_REPLICA"
mtype: "<enum \'InputReplicationMode\'>"
}
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"