[tf.data + tf.distribute] Use RebatchDataset instead of LegacyRebatchDataset in distribution strategies when global batch size can be statically determined.
PiperOrigin-RevId: 334646795
Change-Id: I6ed7a09e38577c1b1d3487473e12a48c92c2a2c7
parent 8e14fdb6a2
commit a7f8535480
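For context on the change: `_LegacyRebatchDataset` divides every batch it sees into `num_replicas_in_sync` minibatches based on the batch's runtime size, so a partial final batch produces smaller minibatches and an extra step, while `_RebatchDataset` re-slices the element stream into explicitly specified batch sizes, which handles partial batches exactly when the global batch size is statically known. A framework-free sketch of the two behaviors (the function names and logic below are illustrative, not the TF kernels):

```python
def legacy_rebatch(batches, num_replicas):
  # LegacyRebatchDataset-style: split each incoming batch by its *dynamic*
  # size, so a partial final batch yields smaller, uneven minibatches.
  for batch in batches:
    minibatch = -(-len(batch) // num_replicas)  # ceiling division
    for start in range(0, len(batch), minibatch):
      yield batch[start:start + minibatch]

def rebatch(batches, batch_sizes):
  # RebatchDataset-style: re-slice the stream into *exact*, statically known
  # batch sizes, cycling through `batch_sizes`.
  buffer, i = [], 0
  for batch in batches:
    buffer.extend(batch)
    while len(buffer) >= batch_sizes[i % len(batch_sizes)]:
      n = batch_sizes[i % len(batch_sizes)]
      yield buffer[:n]
      buffer, i = buffer[n:], i + 1
  if buffer:
    yield buffer

# 12 elements batched at a global batch size of 8, as in the tests below.
batches = [list(range(0, 8)), list(range(8, 12))]
print(list(legacy_rebatch(batches, 2)))  # [[0..3], [4..7], [8, 9], [10, 11]]
print(list(rebatch(batches, [4, 4])))    # [[0..3], [4..7], [8..11]]
```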
@@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
     return self._element_spec


-def _AutoShardDatasetV1(input_dataset, num_workers, index):  # pylint: disable=invalid-name
+def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
   return dataset_ops.DatasetV1Adapter(
-      _AutoShardDataset(input_dataset, num_workers, index))
+      _AutoShardDataset(input_dataset, num_workers, index, num_replicas))


 class _RebatchDataset(dataset_ops.UnaryDataset):
@@ -476,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync,
+        num_replicas_in_sync=self._num_replicas_in_sync,
         input_context=input_context)

   def _distribute_datasets_from_function(self, dataset_fn, options):
@@ -505,7 +505,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync,
+        num_replicas_in_sync=self._num_replicas_in_sync,
         input_context=input_context)

   def _make_input_fn_iterator(
@@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls
 def get_distributed_dataset(dataset,
                             input_workers,
                             strategy,
-                            split_batch_by=None,
+                            num_replicas_in_sync=None,
                             input_context=None):
   """Returns a distributed dataset from the given tf.data.Dataset instance.

@@ -77,8 +77,10 @@ def get_distributed_dataset(dataset,
       iterators should be created.
     strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
       handle last partial batch.
-    split_batch_by: Optional integer. If present, we "split" each batch of the
-      dataset by `split_batch_by` value.
+    num_replicas_in_sync: Optional integer. If this is not None, the value is
+      used to decide how to rebatch datasets into smaller batches so that
+      the total batch size for each step (across all workers and replicas)
+      adds up to `dataset`'s batch size.
     input_context: `InputContext` for sharding. Only pass this in for between
       graph multi-worker cases where there is only one `input_worker`. In
       these cases, we will shard based on the `input_pipeline_id` and
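To make the docstring's arithmetic concrete, a small hedged example (the numbers are illustrative, not taken from the commit):

```python
# With dataset.batch(8) feeding 2 workers of 2 replicas each
# (num_replicas_in_sync == 4), each replica should see batches of size 2,
# since 4 * 2 adds up to the dataset's batch size of 8.
global_batch_size = 8
num_replicas_in_sync = 4
per_replica = global_batch_size // num_replicas_in_sync
assert per_replica * num_replicas_in_sync == global_batch_size
print(per_replica)  # 2
```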
@@ -92,14 +94,14 @@ def get_distributed_dataset(dataset,
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
   else:
     return DistributedDatasetV1(
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)


@@ -917,20 +919,24 @@ class DistributedDataset(_IterableInput):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     """Distribute the dataset on all workers.

-    If `split_batch_by` is not None, we "split" each batch of the dataset by
-    `split_batch_by` value.
+    If `num_replicas_in_sync` is not None, we split each batch of the dataset
+    into `num_replicas_in_sync` smaller batches, to be distributed among that
+    worker's replicas, so that the batch size for a global step (across all
+    workers and replicas) is as expected.

     Args:
       dataset: `tf.data.Dataset` that will be used as the input source.
       input_workers: an `InputWorkers` object.
       strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
         handle last partial batch.
-      split_batch_by: Optional integer. If present, we "split" each batch of the
-        dataset by `split_batch_by` value.
+      num_replicas_in_sync: Optional integer. If this is not None, the value
+        is used to decide how to rebatch datasets into smaller batches so that
+        the total batch size for each step (across all workers and replicas)
+        adds up to `dataset`'s batch size.
       input_context: `InputContext` for sharding. Only pass this in for between
         graph multi-worker cases where there is only one `input_worker`. In
         these cases, we will shard based on the `input_pipeline_id` and
@@ -942,36 +948,29 @@ class DistributedDataset(_IterableInput):
     # different subset of files. If that is not possible, will attempt to shard
     # the final input such that each worker will run the entire preprocessing
     # pipeline and only receive its own shard of the dataset.
-    if split_batch_by:
-      try:
-        # pylint: disable=protected-access
-        with ops.colocate_with(dataset._variant_tensor):
-          dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by)
-          # Add a prefetch to pipeline rebatching for performance.
-          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
-          # leverage static graph rewrites to insert _RebatchDataset before
-          # the final `prefetch` if it exists.
-          dataset = dataset.prefetch(split_batch_by)
-      except errors.InvalidArgumentError as e:
-        if "without encountering a batch" in str(e):
-          six.reraise(
-              ValueError,
-              ValueError(
-                  "Call the `batch` method on the input Dataset in order to be "
-                  "able to split your input across {} replicas.\n Please "
-                  "the tf.distribute.Strategy guide. {}".format(
-                      split_batch_by, e)),
-              sys.exc_info()[2])
-        else:
-          raise
+    # Additionally, we rebatch the dataset on each worker into
+    # `num_replicas_in_sync` smaller batches to be distributed among that
+    # worker's replicas, so that the batch size for a global step (across all
+    # workers and replicas) adds up to the original dataset's batch size.
+    if num_replicas_in_sync is not None:
+      num_workers = input_context.num_input_pipelines if input_context else len(
+          input_workers.worker_devices)
+      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
+                                         num_replicas_in_sync)
+    else:
+      rebatch_fn = None

     self._cloned_datasets = []
     if input_context:
       # Between-graph where we rely on the input_context for sharding
       assert input_workers.num_workers == 1
+      if rebatch_fn is not None:
+        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
       dataset = input_ops.auto_shard_dataset(dataset,
                                              input_context.num_input_pipelines,
-                                             input_context.input_pipeline_id)
+                                             input_context.input_pipeline_id,
+                                             num_replicas_in_sync)
       self._cloned_datasets.append(dataset)
     else:
       replicated_ds = distribute.replicate(dataset,
@@ -980,16 +979,73 @@ class DistributedDataset(_IterableInput):
         with ops.device(worker):
           cloned_dataset = replicated_ds[worker]
           cloned_dataset = cloned_dataset.with_options(dataset.options())
+          if rebatch_fn is not None:
+            cloned_dataset = rebatch_fn(cloned_dataset, i)
           cloned_dataset = input_ops.auto_shard_dataset(
-              cloned_dataset, len(input_workers.worker_devices), i)
+              cloned_dataset, len(input_workers.worker_devices), i,
+              num_replicas_in_sync)
           self._cloned_datasets.append(cloned_dataset)

     self._input_workers = input_workers
     self._strategy = strategy
     self._enable_get_next_as_optional = _enable_get_next_as_optional(
         self._strategy, dataset.element_spec)
-    self._element_spec = _create_distributed_tensor_spec(self._strategy,
-                                                         dataset.element_spec)  # pylint: disable=protected-access
+    self._element_spec = _create_distributed_tensor_spec(
+        self._strategy, self._cloned_datasets[0].element_spec)
+
+  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
+    """Returns a callable that rebatches the input dataset.
+
+    Args:
+      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
+      num_workers: An integer representing the number of workers to distribute
+        `dataset` among.
+      num_replicas_in_sync: An integer representing the number of replicas in
+        sync across all workers.
+    """
+    if num_replicas_in_sync % num_workers:
+      raise ValueError(
+          "tf.distribute expects every worker to have the same number of "
+          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
+          "cannot be divided by `num_workers` ({})".format(
+              num_replicas_in_sync, num_workers))
+
+    num_replicas_per_worker = num_replicas_in_sync // num_workers
+    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
+      batch_size = distribute.compute_batch_size(dataset)
+
+    def rebatch_fn(dataset, worker_index):
+      try:
+        # pylint: disable=protected-access
+        def apply_rebatch():
+          batch_sizes = distribute.batch_sizes_for_worker(
+              batch_size, num_workers, num_replicas_per_worker, worker_index)
+          return distribute._RebatchDataset(
+              dataset, batch_sizes).prefetch(num_replicas_per_worker)
+
+        def apply_legacy_rebatch():
+          return distribute._LegacyRebatchDataset(
+              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
+
+        with ops.colocate_with(dataset._variant_tensor):
+          return control_flow_ops.cond(
+              math_ops.not_equal(batch_size, -1),
+              true_fn=apply_rebatch,
+              false_fn=apply_legacy_rebatch)
+      except errors.InvalidArgumentError as e:
+        if "without encountering a batch" in str(e):
+          six.reraise(
+              ValueError,
+              ValueError(
+                  "Call the `batch` method on the input Dataset in order to be "
+                  "able to split your input across {} replicas.\n Please see "
+                  "the tf.distribute.Strategy guide. {}".format(
+                      num_replicas_in_sync, e)),
+              sys.exc_info()[2])
+        else:
+          raise
+
+    return rebatch_fn

   def __iter__(self):
     if not (context.executing_eagerly() or
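In the new `_make_rebatch_fn` above, `distribute.compute_batch_size` yields -1 when the global batch size cannot be statically inferred, and the `cond` selects the exact `_RebatchDataset` path only when it can. A simplified, framework-free sketch of that decision; the helper names and the remainder handling below are assumptions for illustration, not the semantics of `distribute.batch_sizes_for_worker`:

```python
UNKNOWN = -1  # mirrors the sentinel returned by distribute.compute_batch_size

def per_replica_batch_sizes(global_batch_size, num_replicas_in_sync):
  # Exact sizes summing to the global batch size (static case). Assumed
  # policy: the first `remainder` replicas take one extra element.
  base, remainder = divmod(global_batch_size, num_replicas_in_sync)
  return [base + 1 if r < remainder else base
          for r in range(num_replicas_in_sync)]

def select_rebatch(static_batch_size, num_replicas_in_sync):
  if static_batch_size != UNKNOWN:
    # Static case: rebatch into exact per-replica sizes.
    return ("_RebatchDataset",
            per_replica_batch_sizes(static_batch_size, num_replicas_in_sync))
  # Dynamic case: fall back to dividing each batch as it arrives.
  return ("_LegacyRebatchDataset", None)

print(select_rebatch(8, 2))        # ('_RebatchDataset', [4, 4])
print(select_rebatch(UNKNOWN, 2))  # ('_LegacyRebatchDataset', None)
```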
@@ -1040,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     self._input_workers = input_workers
     super(DistributedDatasetV1, self).__init__(
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)

   def make_one_shot_iterator(self):
@@ -1305,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     """Make an iterator for the dataset on given devices.

-    If `split_batch_by` is not None, we "split" each batch of the
-    dataset by `split_batch_by` value.
+    If `num_replicas_in_sync` is not None, we split each batch of the dataset
+    into `num_replicas_in_sync` smaller batches, to be distributed among that
+    worker's replicas, so that the batch size for a global step (across all
+    workers and replicas) is as expected.

     Args:
       dataset: `tf.data.Dataset` that will be used as the input source.
       input_workers: an `InputWorkers` object.
       strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
         handle last partial batch.
-      split_batch_by: Optional integer. If present, we "split" each batch of the
-        dataset by `split_batch_by` value.
+      num_replicas_in_sync: Optional integer. If this is not None, the value is
+        used to decide how to rebatch datasets into smaller batches so that the
+        total batch size for each step (across all workers and replicas) adds up
+        to `dataset`'s batch size.
       input_context: `InputContext` for sharding. Only pass this in for between
         graph multi-worker cases where there is only one `input_worker`. In
         these cases, we will shard based on the `input_pipeline_id` and
@@ -1328,7 +1388,7 @@ class DatasetIterator(DistributedIteratorV1):
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
     worker_iterators = _create_iterators_per_worker(
         dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
@@ -61,7 +61,7 @@ class DistributedIteratorTestBase(test.TestCase):
                        dataset_or_input_fn,
                        input_workers,
                        devices,
-                       split_batch_by,
+                       num_replicas_in_sync,
                        strategy,
                        input_context=None):
     # The `input_context` passed in is to shard dataset for
@@ -93,7 +93,7 @@ class DistributedIteratorTestBase(test.TestCase):
           dataset_or_input_fn,
           input_workers,
           strategy,
-          split_batch_by=split_batch_by,
+          num_replicas_in_sync=num_replicas_in_sync,
          input_context=input_context)
     return iterator

@@ -101,7 +101,7 @@ class DistributedIteratorTestBase(test.TestCase):
                       input_type,
                       dataset,
                       input_workers,
-                      split_batch_by,
+                      num_replicas_in_sync,
                       strategy,
                       input_context=None):
     if input_type == "dataset":
@@ -110,14 +110,14 @@ class DistributedIteratorTestBase(test.TestCase):
             dataset,
             input_workers,
             strategy,
-            split_batch_by=split_batch_by,
+            num_replicas_in_sync=num_replicas_in_sync,
             input_context=input_context)
       else:
         return input_lib.DistributedDatasetV1(
             dataset,
             input_workers,
             strategy,
-            split_batch_by=split_batch_by,
+            num_replicas_in_sync=num_replicas_in_sync,
             input_context=input_context)
     else:
       return strategy.distribute_datasets_from_function(dataset)
@@ -163,7 +163,7 @@ class DistributedIteratorTestBase(test.TestCase):
                             expected_values,
                             strategy,
                             sess=None,
-                            split_batch_by=None,
+                            num_replicas_in_sync=None,
                             input_context=None):
     if iteration_type == "for_loop" and not context.executing_eagerly():
       self.skipTest("unsupported test combination.")
@@ -183,7 +183,7 @@ class DistributedIteratorTestBase(test.TestCase):
           dataset_or_input_fn,
           input_workers,
           devices,
-          split_batch_by,
+          num_replicas_in_sync,
           strategy,
           input_context=input_context)
     else:
@@ -192,7 +192,7 @@ class DistributedIteratorTestBase(test.TestCase):
           input_type,
           dataset_or_input_fn,
           input_workers,
-          split_batch_by,
+          num_replicas_in_sync,
           strategy,
           input_context=input_context)

@@ -361,10 +361,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
   def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                        enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)

@@ -419,10 +416,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
                             distribution, enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                               "/device:CPU:0"])]
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)

@@ -455,10 +449,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
       worker_device_pairs.setdefault(host_device, [])
       worker_device_pairs[host_device].append(tpu_device)
     worker_device_pairs = worker_device_pairs.items()
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)

@@ -493,14 +484,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase,

     def dataset_fn(ctx):
       del ctx
-      if tf2.enabled():
-        dataset1 = dataset_ops.DatasetV2.range(10)
-        dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
-        return dataset_ops.DatasetV2.zip((dataset1, dataset2))
-      else:
-        dataset1 = dataset_ops.Dataset.range(10)
-        dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
-        return dataset_ops.Dataset.zip((dataset1, dataset2))
+      dataset1 = dataset_ops.Dataset.range(10)
+      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
+      return dataset_ops.Dataset.zip((dataset1, dataset2))
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)

@@ -563,7 +550,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
     input_workers = input_lib.InputWorkers(worker_device_pairs)

-    dataset = dataset_ops.DatasetV2.range(10)
+    dataset = dataset_ops.Dataset.range(10)
     dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                      distribution)

@@ -663,7 +650,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         worker_device_pairs,
         expected_values,
         distribution,
-        split_batch_by=distribution.num_replicas_in_sync,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
         input_context=distribution.extended._make_input_context())

   @combinations.generate(
@@ -724,7 +711,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         worker_device_pairs,
         expected_values,
         distribution,
-        split_batch_by=distribution.num_replicas_in_sync,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
         input_context=distribution.extended._make_input_context())

   @combinations.generate(
@@ -733,14 +720,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
           input_type=["dataset"],
           api_type=["wrap_into_iterator", "wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          split_batch_by=[None, 2],
+          num_replicas_in_sync=[None, 2],
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.central_storage_strategy_with_gpu_and_cpu
           ],
           enable_get_next_as_optional=[True, False]))
   def testBatchSplitting(self, input_type, api_type, iteration_type,
-                         split_batch_by, distribution,
+                         num_replicas_in_sync, distribution,
                          enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                               "/device:CPU:0"])]
@@ -750,7 +737,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         input_type, dataset_fn)

     updated_batch_size = (
-        batch_size // split_batch_by if split_batch_by else batch_size)
+        batch_size //
+        num_replicas_in_sync if num_replicas_in_sync else batch_size)
     expected_values = [[range(i, i+updated_batch_size),
                         range(i+updated_batch_size, i+2*updated_batch_size)]
                        for i in range(0, 100, updated_batch_size*2)]
@@ -766,7 +754,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         expected_values,
         distribution,
         sess=None,
-        split_batch_by=split_batch_by)
+        num_replicas_in_sync=num_replicas_in_sync)

   @combinations.generate(
       combinations.combine(
@@ -774,13 +762,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
           input_type=["dataset"],
           api_type=["wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          split_batch_by=[None, 2],
+          num_replicas_in_sync=[None, 2],
           distribution=[
               strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False]))
   def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
-                                    split_batch_by, distribution,
+                                    num_replicas_in_sync, distribution,
                                     enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                               "/device:GPU:1"])]
@@ -796,7 +784,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         input_type, dataset_fn)

     updated_batch_size = (
-        batch_size // split_batch_by if split_batch_by else batch_size)
+        batch_size //
+        num_replicas_in_sync if num_replicas_in_sync else batch_size)
     expected_values = [
         [  # pylint: disable=g-complex-comprehension
             range(i, i + updated_batch_size),
@@ -815,7 +804,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         expected_values,
         distribution,
         sess=None,
-        split_batch_by=split_batch_by)
+        num_replicas_in_sync=num_replicas_in_sync)

   @combinations.generate(
       combinations.combine(
@@ -1249,6 +1238,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       expected_for_sum = 310.
     self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)

+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ]))
+  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
+                           distribution):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when we have two files each with
+    # 12 elements and a global batch size of 8. When we consider the dataset in
+    # aggregate (non-distributed), there are 24 elements divided into 3 batches
+    # of size 8. Hence, the correct distributed behavior is for each replica to
+    # see sub-batches of size 4, over three steps.
+    def dataset_fn(ctx):
+      del ctx
+      dataset = dataset_ops.Dataset.range(12).batch(8)
+
+      # Set the sharding behavior to OFF for simplicity of test setup; namely,
+      # `dataset` defines the per-worker dataset and will not be further
+      # sharded. Each worker will see a dataset that is
+      # tf.data.Dataset.range(12).batch(8).rebatch(...).
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as there is 1 local
+    # replica.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. Each worker should rebatch its dataset into
+    # smaller batches of size 4.
+    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ]))
+  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
+                                            iteration_type, distribution):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when we have two files each with
+    # 12 elements and a global batch size of 8. When we consider the dataset in
+    # aggregate (non-distributed), there are 24 elements divided into 3 batches
+    # of size 8. Hence, the correct distributed behavior is for each replica to
+    # see sub-batches of size 4, over three steps. However, when we create a
+    # DistributedDataset and cannot statically infer the intended global batch
+    # size (e.g. if the user does not use a batching dataset), each worker will
+    # rebatch based on the dynamic batch size of the data encountered, even when
+    # it encounters partial batches. The last per-worker partial batch (size 4)
+    # ends up being split into two replicas, resulting in 4 steps in total, of
+    # (global) batch sizes 8, 8, 4, 4.
+    def dataset_fn(ctx):
+      del ctx
+      # The following dataset is equivalent to
+      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
+      # This causes DistributedDataset to use LegacyRebatch instead.
+      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
+      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
+      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
+
+      def map_fn(offset, batch_size):
+        return math_ops.range(offset, offset + batch_size)
+
+      dataset = dataset.map(map_fn)
+
+      # Set the sharding behavior to OFF for simplicity of test setup; namely,
+      # `dataset` defines the per-worker dataset and will not be further
+      # sharded. Each worker will see a dataset that is equivalent to
+      # tf.data.Dataset.range(12).batch(8).rebatch(...).
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as the number of global
+    # replicas is 2.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. Each worker should rebatch its dataset into
+    # smaller batches of size 4.
+    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ],
+          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
+  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
+                               distribution, auto_shard_policy):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when the dataset is sharded by
+    # data and the batch size is indivisible by the number of replicas. This
+    # checks that the elements are as expected and that the batch size across
+    # all workers adds up to 3. This test will only pass if the autoshard
+    # rewrite rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding
+    # by data.
+    def dataset_fn(ctx):
+      del ctx
+      dataset = dataset_ops.Dataset.range(8).batch(3)
+
+      # Set the requested sharding policy; `dataset` defines the per-worker
+      # dataset, which the autoshard rewrite then shards by data across the
+      # two workers.
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = auto_shard_policy
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as there is 1 local
+    # replica.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. We expect each worker to see different shards of
+    # data.
+    cr = distribution.cluster_resolver
+    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
+                                                cr.task_id)
+
+    if worker_id == 0:
+      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
+    elif worker_id == 1:
+      expected_values = [[[2]], [[5]], [[7]]]
+
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
+

 if __name__ == "__main__":
   combinations.main()
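The expected values in `testMWMSWithDataSharding` can be reproduced with a short framework-free sketch. The assumptions (legacy rebatch splits each batch into per-replica minibatches by ceiling division, then data sharding deals minibatches round-robin to the two workers) mirror the test comments above, not the actual graph rewrite:

```python
def legacy_rebatch(batch, num_replicas):
  mini = -(-len(batch) // num_replicas)  # ceiling division
  return [batch[i:i + mini] for i in range(0, len(batch), mini)]

# Dataset.range(8).batch(3) produces global batches of sizes 3, 3, 2.
batches = [[0, 1, 2], [3, 4, 5], [6, 7]]
minis = [m for b in batches for m in legacy_rebatch(b, 2)]
worker0, worker1 = minis[0::2], minis[1::2]
print(worker0)  # [[0, 1], [3, 4], [6]]
print(worker1)  # [[2], [5], [7]]
```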
@@ -27,7 +27,7 @@ from tensorflow.python.framework import ops


 # pylint: disable=protected-access
-def auto_shard_dataset(dataset, num_shards, index):
+def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
   """Shard the input pipeline by sharding the underlying list of files.

   Args:
@@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index):
       shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
     index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
       Same usage as in `tf.data.Dataset.shard`.
+    num_replicas_in_sync: An integer representing the total number of replicas
+      across all workers. This is used in the rewrite when sharding by data.

   Returns:
     A modified `Dataset` obtained by updating the pipeline sharded by the
@@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index):
   """
   if (dataset.options().experimental_distribute.auto_shard_policy !=
       AutoShardPolicy.OFF):
+    if num_replicas_in_sync is None:
+      num_replicas_in_sync = 1
     if isinstance(dataset, dataset_ops.DatasetV1):
-      return distribute._AutoShardDatasetV1(dataset, num_shards, index)
+      return distribute._AutoShardDatasetV1(dataset, num_shards, index,
+                                            num_replicas_in_sync)
     else:
-      return distribute._AutoShardDataset(dataset, num_shards, index)
+      return distribute._AutoShardDataset(dataset, num_shards, index,
+                                          num_replicas_in_sync)
   else:
     return dataset

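A minimal sketch of the argument plumbing `auto_shard_dataset` now performs; the classes below are placeholders standing in for `distribute._AutoShardDataset` and `_AutoShardDatasetV1`, and only the defaulting and dispatch from the diff above are modeled:

```python
class _AutoShardV1:
  # Placeholder for distribute._AutoShardDatasetV1; records its arguments.
  def __init__(self, dataset, num_shards, index, num_replicas_in_sync):
    self.args = (dataset, num_shards, index, num_replicas_in_sync)

class _AutoShardV2(_AutoShardV1):
  # Placeholder for distribute._AutoShardDataset.
  pass

def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None,
                       is_v1=False, sharding_enabled=True):
  if not sharding_enabled:  # AutoShardPolicy.OFF: leave the pipeline as-is.
    return dataset
  if num_replicas_in_sync is None:
    num_replicas_in_sync = 1  # same defaulting as the diff above
  cls = _AutoShardV1 if is_v1 else _AutoShardV2
  return cls(dataset, num_shards, index, num_replicas_in_sync)

print(auto_shard_dataset("ds", num_shards=2, index=0).args)  # ('ds', 2, 0, 1)
```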
@@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _make_input_fn_iterator(
       self,
@@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _experimental_make_numpy_dataset(self, numpy_input, session):
     return numpy_dataset.one_host_numpy_dataset(
@@ -351,14 +351,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _make_dataset_iterator(self, dataset):
     return input_lib.DatasetIterator(
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _make_input_fn_iterator(
       self,
@@ -752,7 +752,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _make_input_fn_iterator(
       self,
@@ -809,7 +809,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._get_input_workers(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)

   def _distribute_datasets_from_function(self, dataset_fn, options):
     input_workers = self._get_input_workers(options)