[tf.data + tf.distribute] Use RebatchDataset instead of LegacyRebatchDataset in distribution strategies when global batch size can be statically determined.
PiperOrigin-RevId: 334646795
Change-Id: I6ed7a09e38577c1b1d3487473e12a48c92c2a2c7
Parent: 8e14fdb6a2
Commit: a7f8535480
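At a high level, the change keys off whether the global batch size of the input pipeline can be determined statically. A minimal eager-mode sketch of the two behaviors, using the internal tensorflow.python.data.experimental.ops.distribute module that the hunks below modify (illustrative only, not a supported API):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

ds = tf.data.Dataset.range(12).batch(8)  # global batch size 8; last batch is partial

# The batch size here is statically determinable, so compute_batch_size
# returns 8; it returns -1 when the size cannot be inferred.
print(distribute.compute_batch_size(ds).numpy())  # 8

# The legacy fallback divides each encountered batch by the replica count
# at runtime, splitting partial batches as it finds them:
legacy = distribute._LegacyRebatchDataset(ds, 2)
print([t.numpy().tolist() for t in legacy])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [10, 11]]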
@@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
     return self._element_spec
 
 
-def _AutoShardDatasetV1(input_dataset, num_workers, index):  # pylint: disable=invalid-name
+def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
   return dataset_ops.DatasetV1Adapter(
-      _AutoShardDataset(input_dataset, num_workers, index))
+      _AutoShardDataset(input_dataset, num_workers, index, num_replicas))
 
 
 class _RebatchDataset(dataset_ops.UnaryDataset):
@@ -476,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync,
+        num_replicas_in_sync=self._num_replicas_in_sync,
         input_context=input_context)
 
   def _distribute_datasets_from_function(self, dataset_fn, options):
@@ -505,7 +505,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync,
+        num_replicas_in_sync=self._num_replicas_in_sync,
         input_context=input_context)
 
   def _make_input_fn_iterator(
@@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls
 def get_distributed_dataset(dataset,
                             input_workers,
                             strategy,
-                            split_batch_by=None,
+                            num_replicas_in_sync=None,
                             input_context=None):
   """Returns a distributed dataset from the given tf.data.Dataset instance.
 
@@ -77,8 +77,10 @@ def get_distributed_dataset(dataset,
       iterators should be created.
     strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
       handle last partial batch.
-    split_batch_by: Optional integer. If present, we "split" each batch of the
-      dataset by `split_batch_by` value.
+    num_replicas_in_sync: Optional integer. If this is not None, the value is
+      used to decide how to rebatch datasets into smaller batches so that
+      the total batch size for each step (across all workers and replicas)
+      adds up to `dataset`'s batch size.
     input_context: `InputContext` for sharding. Only pass this in for between
       graph multi-worker cases where there is only one `input_worker`. In
       these cases, we will shard based on the `input_pipeline_id` and
@@ -92,14 +94,14 @@ def get_distributed_dataset(dataset,
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
   else:
     return DistributedDatasetV1(
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
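For callers, the visible change in get_distributed_dataset is the renamed keyword. A hedged, self-contained sketch, constructed the same way the tests in this commit build their inputs (InputWorkers and get_distributed_dataset are internal APIs):

import tensorflow as tf
from tensorflow.python.distribute import input_lib

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs)

dataset = tf.data.Dataset.range(8).batch(4)
dist_dataset = input_lib.get_distributed_dataset(
    dataset,
    input_workers,
    strategy,
    num_replicas_in_sync=strategy.num_replicas_in_sync)  # was: split_batch_by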
@@ -917,20 +919,24 @@ class DistributedDataset(_IterableInput):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     """Distribute the dataset on all workers.
 
-    If `split_batch_by` is not None, we "split" each batch of the dataset by
-    `split_batch_by` value.
+    If `num_replicas_in_sync` is not None, we split each batch of the dataset
+    into `num_replicas_in_sync` smaller batches, to be distributed among that
+    worker's replicas, so that the batch size for a global step (across all
+    workers and replicas) is as expected.
 
     Args:
       dataset: `tf.data.Dataset` that will be used as the input source.
       input_workers: an `InputWorkers` object.
       strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
-      split_batch_by: Optional integer. If present, we "split" each batch of the
-        dataset by `split_batch_by` value.
+      num_replicas_in_sync: Optional integer. If this is not None, the value
+        is used to decide how to rebatch datasets into smaller batches so that
+        the total batch size for each step (across all workers and replicas)
+        adds up to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
@@ -942,36 +948,29 @@ class DistributedDataset(_IterableInput):
     # different subset of files. If that is not possible, will attempt to shard
     # the final input such that each worker will run the entire preprocessing
     # pipeline and only receive its own shard of the dataset.
-    if split_batch_by:
-      try:
-        # pylint: disable=protected-access
-        with ops.colocate_with(dataset._variant_tensor):
-          dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by)
-          # Add a prefetch to pipeline rebatching for performance.
-          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
-          # leverage static graph rewrites to insert _RebatchDataset before
-          # the final `prefetch` if it exists.
-          dataset = dataset.prefetch(split_batch_by)
-      except errors.InvalidArgumentError as e:
-        if "without encountering a batch" in str(e):
-          six.reraise(
-              ValueError,
-              ValueError(
-                  "Call the `batch` method on the input Dataset in order to be "
-                  "able to split your input across {} replicas.\n Please "
-                  "the tf.distribute.Strategy guide. {}".format(
-                      split_batch_by, e)),
-              sys.exc_info()[2])
-        else:
-          raise
+
+    # Additionally, we rebatch the dataset on each worker into
+    # `num_replicas_in_sync` smaller batches to be distributed among that
+    # worker's replicas, so that the batch size for a global step (across all
+    # workers and replicas) adds up to the original dataset's batch size.
+    if num_replicas_in_sync is not None:
+      num_workers = input_context.num_input_pipelines if input_context else len(
+          input_workers.worker_devices)
+      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
+                                         num_replicas_in_sync)
+    else:
+      rebatch_fn = None
 
     self._cloned_datasets = []
     if input_context:
       # Between-graph where we rely on the input_context for sharding
       assert input_workers.num_workers == 1
+      if rebatch_fn is not None:
+        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
       dataset = input_ops.auto_shard_dataset(dataset,
                                              input_context.num_input_pipelines,
-                                             input_context.input_pipeline_id)
+                                             input_context.input_pipeline_id,
+                                             num_replicas_in_sync)
       self._cloned_datasets.append(dataset)
     else:
       replicated_ds = distribute.replicate(dataset,
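Reduced to plain arithmetic, the wiring above behaves as follows for a between-graph setup; the concrete numbers are an assumed example rather than values from the diff:

global_batch_size = 8
num_workers = 2                # input_context.num_input_pipelines
num_replicas_in_sync = 2       # one replica per worker
num_replicas_per_worker = num_replicas_in_sync // num_workers  # 1

# Each worker's pipeline is rebatched before auto-sharding, so every
# replica sees per-step batches that sum back to the global batch size.
per_replica_batch = global_batch_size // num_replicas_in_sync
assert per_replica_batch == 4
assert per_replica_batch * num_replicas_in_sync == global_batch_size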
@@ -980,16 +979,73 @@ class DistributedDataset(_IterableInput):
         with ops.device(worker):
           cloned_dataset = replicated_ds[worker]
           cloned_dataset = cloned_dataset.with_options(dataset.options())
+          if rebatch_fn is not None:
+            cloned_dataset = rebatch_fn(cloned_dataset, i)
           cloned_dataset = input_ops.auto_shard_dataset(
-              cloned_dataset, len(input_workers.worker_devices), i)
+              cloned_dataset, len(input_workers.worker_devices), i,
+              num_replicas_in_sync)
           self._cloned_datasets.append(cloned_dataset)
 
     self._input_workers = input_workers
     self._strategy = strategy
     self._enable_get_next_as_optional = _enable_get_next_as_optional(
         self._strategy, dataset.element_spec)
-    self._element_spec = _create_distributed_tensor_spec(self._strategy,
-                                                         dataset.element_spec)  # pylint: disable=protected-access
+    self._element_spec = _create_distributed_tensor_spec(
+        self._strategy, self._cloned_datasets[0].element_spec)
+
+  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
+    """Returns a callable that rebatches the input dataset.
+
+    Args:
+      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
+      num_workers: An integer representing the number of workers to distribute
+        `dataset` among.
+      num_replicas_in_sync: An integer representing the number of replicas in
+        sync across all workers.
+    """
+    if num_replicas_in_sync % num_workers:
+      raise ValueError(
+          "tf.distribute expects every worker to have the same number of "
+          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
+          "cannot be divided by `num_workers` ({})".format(
+              num_replicas_in_sync, num_workers))
+
+    num_replicas_per_worker = num_replicas_in_sync // num_workers
+    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
+      batch_size = distribute.compute_batch_size(dataset)
+
+    def rebatch_fn(dataset, worker_index):
+      try:
+        # pylint: disable=protected-access
+        def apply_rebatch():
+          batch_sizes = distribute.batch_sizes_for_worker(
+              batch_size, num_workers, num_replicas_per_worker, worker_index)
+          return distribute._RebatchDataset(
+              dataset, batch_sizes).prefetch(num_replicas_per_worker)
+
+        def apply_legacy_rebatch():
+          return distribute._LegacyRebatchDataset(
+              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
+
+        with ops.colocate_with(dataset._variant_tensor):
+          return control_flow_ops.cond(
+              math_ops.not_equal(batch_size, -1),
+              true_fn=apply_rebatch,
+              false_fn=apply_legacy_rebatch)
+      except errors.InvalidArgumentError as e:
+        if "without encountering a batch" in str(e):
+          six.reraise(
+              ValueError,
+              ValueError(
+                  "Call the `batch` method on the input Dataset in order to be "
+                  "able to split your input across {} replicas.\n Please see "
+                  "the tf.distribute.Strategy guide. {}".format(
+                      num_replicas_in_sync, e)),
+              sys.exc_info()[2])
+        else:
+          raise
+
+    return rebatch_fn
 
   def __iter__(self):
     if not (context.executing_eagerly() or
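Standalone, the dispatch inside rebatch_fn amounts to the sketch below (eager mode; same internal distribute module and the same -1 sentinel that compute_batch_size uses for an unknown batch size; the argument order of batch_sizes_for_worker follows the hunk above):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

ds = tf.data.Dataset.range(24).batch(8)
batch_size = distribute.compute_batch_size(ds)  # 8 here; -1 if not static

if batch_size.numpy() != -1:
  # Exact per-step batch sizes for worker 0 of 2, one replica per worker.
  sizes = distribute.batch_sizes_for_worker(batch_size, 2, 1, 0)
  rebatched = distribute._RebatchDataset(ds, sizes)
else:
  # Fallback: divide whatever batch is encountered across the replicas.
  rebatched = distribute._LegacyRebatchDataset(ds, 2)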
@@ -1040,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     self._input_workers = input_workers
     super(DistributedDatasetV1, self).__init__(
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
 
   def make_one_shot_iterator(self):
@@ -1305,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1):
                dataset,
                input_workers,
                strategy,
-               split_batch_by=None,
+               num_replicas_in_sync=None,
                input_context=None):
     """Make an iterator for the dataset on given devices.
 
-    If `split_batch_by` is not None, we "split" each batch of the
-    dataset by `split_batch_by` value.
+    If `num_replicas_in_sync` is not None, we split each batch of the dataset
+    into `num_replicas_in_sync` smaller batches, to be distributed among that
+    worker's replicas, so that the batch size for a global step (across all
+    workers and replicas) is as expected.
 
     Args:
       dataset: `tf.data.Dataset` that will be used as the input source.
       input_workers: an `InputWorkers` object.
       strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
-      split_batch_by: Optional integer. If present, we "split" each batch of the
-        dataset by `split_batch_by` value.
+      num_replicas_in_sync: Optional integer. If this is not None, the value is
+        used to decide how to rebatch datasets into smaller batches so that the
+        total batch size for each step (across all workers and replicas) adds up
+        to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
@@ -1328,7 +1388,7 @@ class DatasetIterator(DistributedIteratorV1):
         dataset,
         input_workers,
         strategy,
-        split_batch_by=split_batch_by,
+        num_replicas_in_sync=num_replicas_in_sync,
         input_context=input_context)
     worker_iterators = _create_iterators_per_worker(
         dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
@@ -61,7 +61,7 @@ class DistributedIteratorTestBase(test.TestCase):
                      dataset_or_input_fn,
                      input_workers,
                      devices,
-                     split_batch_by,
+                     num_replicas_in_sync,
                      strategy,
                      input_context=None):
     # The `input_context` passed in is to shard dataset for
@@ -93,7 +93,7 @@ class DistributedIteratorTestBase(test.TestCase):
           dataset_or_input_fn,
           input_workers,
           strategy,
-          split_batch_by=split_batch_by,
+          num_replicas_in_sync=num_replicas_in_sync,
           input_context=input_context)
       return iterator
 
@@ -101,7 +101,7 @@ class DistributedIteratorTestBase(test.TestCase):
                     input_type,
                     dataset,
                     input_workers,
-                    split_batch_by,
+                    num_replicas_in_sync,
                     strategy,
                     input_context=None):
     if input_type == "dataset":
@@ -110,14 +110,14 @@ class DistributedIteratorTestBase(test.TestCase):
             dataset,
            input_workers,
            strategy,
-            split_batch_by=split_batch_by,
+            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
-            split_batch_by=split_batch_by,
+            num_replicas_in_sync=num_replicas_in_sync,
            input_context=input_context)
    else:
      return strategy.distribute_datasets_from_function(dataset)
@@ -163,7 +163,7 @@ class DistributedIteratorTestBase(test.TestCase):
                            expected_values,
                            strategy,
                            sess=None,
-                           split_batch_by=None,
+                           num_replicas_in_sync=None,
                            input_context=None):
     if iteration_type == "for_loop" and not context.executing_eagerly():
       self.skipTest("unsupported test combination.")
@@ -183,7 +183,7 @@ class DistributedIteratorTestBase(test.TestCase):
          dataset_or_input_fn,
          input_workers,
          devices,
-          split_batch_by,
+          num_replicas_in_sync,
          strategy,
          input_context=input_context)
    else:
@@ -192,7 +192,7 @@ class DistributedIteratorTestBase(test.TestCase):
          input_type,
          dataset_or_input_fn,
          input_workers,
-          split_batch_by,
+          num_replicas_in_sync,
          strategy,
          input_context=input_context)
 
@@ -361,10 +361,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
   def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
                        enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)
 
@@ -419,10 +416,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
                                distribution, enable_get_next_as_optional):
     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                               "/device:CPU:0"])]
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)
 
@@ -455,10 +449,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
       worker_device_pairs.setdefault(host_device, [])
       worker_device_pairs[host_device].append(tpu_device)
     worker_device_pairs = worker_device_pairs.items()
-    if tf2.enabled():
-      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
-    else:
-      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
+    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)
 
@@ -493,14 +484,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
 
     def dataset_fn(ctx):
       del ctx
-      if tf2.enabled():
-        dataset1 = dataset_ops.DatasetV2.range(10)
-        dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
-        return dataset_ops.DatasetV2.zip((dataset1, dataset2))
-      else:
-        dataset1 = dataset_ops.Dataset.range(10)
-        dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
-        return dataset_ops.Dataset.zip((dataset1, dataset2))
+      dataset1 = dataset_ops.Dataset.range(10)
+      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
+      return dataset_ops.Dataset.zip((dataset1, dataset2))
 
     dataset_or_input_fn = self._create_dataset_or_input_fn(
         input_type, dataset_fn)
 
@@ -563,7 +550,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
     input_workers = input_lib.InputWorkers(worker_device_pairs)
 
-    dataset = dataset_ops.DatasetV2.range(10)
+    dataset = dataset_ops.Dataset.range(10)
     dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
                                                      distribution)
 
@@ -663,7 +650,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         worker_device_pairs,
         expected_values,
         distribution,
-        split_batch_by=distribution.num_replicas_in_sync,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
         input_context=distribution.extended._make_input_context())
 
   @combinations.generate(
@@ -724,7 +711,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         worker_device_pairs,
         expected_values,
         distribution,
-        split_batch_by=distribution.num_replicas_in_sync,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
         input_context=distribution.extended._make_input_context())
 
   @combinations.generate(
@@ -733,14 +720,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
           input_type=["dataset"],
           api_type=["wrap_into_iterator", "wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          split_batch_by=[None, 2],
+          num_replicas_in_sync=[None, 2],
           distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.central_storage_strategy_with_gpu_and_cpu
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplitting(self, input_type, api_type, iteration_type,
-                         split_batch_by, distribution,
+                         num_replicas_in_sync, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
@@ -750,7 +737,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         input_type, dataset_fn)
 
     updated_batch_size = (
-        batch_size // split_batch_by if split_batch_by else batch_size)
+        batch_size //
+        num_replicas_in_sync if num_replicas_in_sync else batch_size)
     expected_values = [[range(i, i+updated_batch_size),
                         range(i+updated_batch_size, i+2*updated_batch_size)]
                        for i in range(0, 100, updated_batch_size*2)]
@@ -766,7 +754,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         expected_values,
         distribution,
         sess=None,
-        split_batch_by=split_batch_by)
+        num_replicas_in_sync=num_replicas_in_sync)
 
   @combinations.generate(
       combinations.combine(
@@ -774,13 +762,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
           input_type=["dataset"],
           api_type=["wrap_into_dataset"],
           iteration_type=["get_next", "for_loop"],
-          split_batch_by=[None, 2],
+          num_replicas_in_sync=[None, 2],
           distribution=[
              strategy_combinations.multi_worker_mirrored_2x2_gpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
-                                    split_batch_by, distribution,
+                                    num_replicas_in_sync, distribution,
                                    enable_get_next_as_optional):
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:GPU:1"])]
@@ -796,7 +784,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         input_type, dataset_fn)
 
     updated_batch_size = (
-        batch_size // split_batch_by if split_batch_by else batch_size)
+        batch_size //
+        num_replicas_in_sync if num_replicas_in_sync else batch_size)
     expected_values = [
         [  # pylint: disable=g-complex-comprehension
             range(i, i + updated_batch_size),
@@ -815,7 +804,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
         expected_values,
         distribution,
         sess=None,
-        split_batch_by=split_batch_by)
+        num_replicas_in_sync=num_replicas_in_sync)
 
   @combinations.generate(
       combinations.combine(
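The updated_batch_size expression in both tests above, written out with assumed values (the tests' actual batch_size is defined earlier in the file and does not appear in these hunks):

batch_size = 10                 # assumed for illustration
num_replicas_in_sync = 2
updated_batch_size = (
    batch_size //
    num_replicas_in_sync if num_replicas_in_sync else batch_size)
assert updated_batch_size == 5  # each replica sees half of every global batch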
@@ -1249,6 +1238,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
     expected_for_sum = 310.
     self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ]))
+  def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
+                           distribution):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when we have two files each with
+    # 12 elements and a global batch size of 8. When we consider the dataset in
+    # aggregate (non-distributed), there are 24 elements divided into 3 batches
+    # of size 8. Hence, the correct distributed behavior is for each replica to
+    # see sub-batches of size 4, over three steps.
+    def dataset_fn(ctx):
+      del ctx
+      dataset = dataset_ops.Dataset.range(12).batch(8)
+
+      # Set the sharding behavior to OFF for simplicity of test setup; namely,
+      # `dataset` defines the per-worker dataset and will not be further
+      # sharded. Each worker will see a dataset that is
+      # tf.data.Dataset.range(12).batch(8).rebatch(...).
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as there is 1 local
+    # replica.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. Each worker should rebatch its dataset into
+    # smaller batches of size 4.
+    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ]))
+  def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
+                                            iteration_type, distribution):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when we have two files each with
+    # 12 elements and a global batch size of 8. When we consider the dataset in
+    # aggregate (non-distributed), there are 24 elements divided into 3 batches
+    # of size 8. Hence, the correct distributed behavior is for each replica to
+    # see sub-batches of size 4, over three steps. However, when we create a
+    # DistributedDataset and cannot statically infer the intended global batch
+    # size (e.g. if the user does not use a batching dataset), each worker will
+    # rebatch based on the dynamic batch size of the data encountered, even when
+    # it encounters partial batches. The last per-worker partial batch (size 4)
+    # ends up being split into two replicas, resulting in 4 steps in total, of
+    # (global) batch sizes 8, 8, 4, 4.
+    def dataset_fn(ctx):
+      del ctx
+      # The following dataset is equivalent to
+      # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
+      # This causes DistributedDataset to use LegacyRebatch instead.
+      batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
+      offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
+      dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
+
+      def map_fn(offset, batch_size):
+        return math_ops.range(offset, offset + batch_size)
+
+      dataset = dataset.map(map_fn)
+
+      # Set the sharding behavior to OFF for simplicity of test setup; namely,
+      # `dataset` defines the per-worker dataset and will not be further
+      # sharded. Each worker will see a dataset that is equivalent to
+      # tf.data.Dataset.range(12).batch(8).rebatch(...).
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as the number of global
+    # replicas is 2.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. Each worker should rebatch its dataset into
+    # smaller batches of size 4.
+    expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          input_type=["dataset"],
+          api_type=["wrap_into_iterator", "wrap_into_dataset"],
+          iteration_type=["get_next", "for_loop"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          ],
+          auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
+  def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
+                               distribution, auto_shard_policy):
+    # Test case: 2 workers, 1 replica each.
+    # This test simulates the sharded behavior when the dataset is sharded by
+    # data and the batch size is indivisible by the number of replicas. This
+    # checks that the elements are as expected and the batch size across all
+    # workers adds up to 3. This test will only pass if the autoshard rewrite
+    # rewrites RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
+    def dataset_fn(ctx):
+      del ctx
+      dataset = dataset_ops.Dataset.range(8).batch(3)
+
+      # Set the sharding behavior via `auto_shard_policy`. When sharding by
+      # data, each worker sees a different shard of
+      # tf.data.Dataset.range(8).batch(3).rebatch(...).
+      options = dataset_ops.Options()
+      options.experimental_distribute.auto_shard_policy = auto_shard_policy
+      dataset = dataset.with_options(options)
+      return dataset
+
+    dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
+
+    # Actual devices don't matter in this test as long as there is 1 local
+    # replica.
+    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
+
+    # Each test runs individually on each worker, so we compare the
+    # values on each worker. We expect each worker to see different shards of
+    # data.
+    cr = distribution.cluster_resolver
+    worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
+                                                cr.task_id)
+
+    if worker_id == 0:
+      expected_values = [[[0, 1]], [[3, 4]], [[6]]]
+    elif worker_id == 1:
+      expected_values = [[[2]], [[5]], [[7]]]
+
+    self._test_input_iteration(
+        input_type,
+        api_type,
+        iteration_type,
+        dataset,
+        worker_device_pairs,
+        expected_values,
+        distribution,
+        num_replicas_in_sync=distribution.num_replicas_in_sync,
+        input_context=distribution.extended._make_input_context())
 
 
 if __name__ == "__main__":
   combinations.main()
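The second test's pipeline can be reproduced standalone to see why the legacy path is taken: without a batching dataset at the end of the pipeline, the global batch size is not statically inferable. A hedged eager sketch emulating the two global replicas with a local rebatch (same internal module as above):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

offsets = tf.data.Dataset.from_tensor_slices([0, 8])
batch_sizes = tf.data.Dataset.from_tensor_slices([8, 4])
ds = tf.data.Dataset.zip((offsets, batch_sizes))
ds = ds.map(lambda offset, size: tf.range(offset, offset + size))

print(distribute.compute_batch_size(ds).numpy())  # -1: not statically known

# Legacy rebatch therefore splits each encountered batch across 2 replicas:
legacy = distribute._LegacyRebatchDataset(ds, 2)
print([t.numpy().tolist() for t in legacy])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [10, 11]] -> global steps of 8, 8, 4, 4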
@@ -27,7 +27,7 @@ from tensorflow.python.framework import ops
 
 
 # pylint: disable=protected-access
-def auto_shard_dataset(dataset, num_shards, index):
+def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
   """Shard the input pipeline by sharding the underlying list of files.
 
   Args:
@@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index):
       shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
     index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
       Same usage as in `tf.data.Dataset.shard`.
+    num_replicas_in_sync: An integer representing the total number of replicas
+      across all workers. This is used in the rewrite when sharding by data.
 
   Returns:
     A modified `Dataset` obtained by updating the pipeline sharded by the
@@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index):
   """
   if (dataset.options().experimental_distribute.auto_shard_policy !=
       AutoShardPolicy.OFF):
+    if num_replicas_in_sync is None:
+      num_replicas_in_sync = 1
     if isinstance(dataset, dataset_ops.DatasetV1):
-      return distribute._AutoShardDatasetV1(dataset, num_shards, index)
+      return distribute._AutoShardDatasetV1(dataset, num_shards, index,
+                                            num_replicas_in_sync)
     else:
-      return distribute._AutoShardDataset(dataset, num_shards, index)
+      return distribute._AutoShardDataset(dataset, num_shards, index,
+                                          num_replicas_in_sync)
   else:
     return dataset
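A hedged usage sketch for the updated helper, mirroring the data-sharding test above (internal API; the signature is the one introduced in this hunk, and AutoShardPolicy is the public tf.data enum):

import tensorflow as tf
from tensorflow.python.distribute import input_ops

dataset = tf.data.Dataset.range(8).batch(3)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)

# Worker 0 of 2, with 2 replicas in sync globally. When sharding by data,
# the rewrite uses the replica count to rewrite RebatchDatasetV2 back to
# the legacy RebatchDataset, as exercised by testMWMSWithDataSharding.
sharded = input_ops.auto_shard_dataset(
    dataset, num_shards=2, index=0, num_replicas_in_sync=2)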
@@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
       self,
@@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _experimental_make_numpy_dataset(self, numpy_input, session):
     return numpy_dataset.one_host_numpy_dataset(
@@ -351,14 +351,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._input_workers_with_options(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _make_dataset_iterator(self, dataset):
     return input_lib.DatasetIterator(
         dataset,
         self._input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
       self,
@@ -752,7 +752,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         input_workers,
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _make_input_fn_iterator(
       self,
@@ -809,7 +809,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         dataset,
         self._get_input_workers(options),
         self._container_strategy(),
-        split_batch_by=self._num_replicas_in_sync)
+        num_replicas_in_sync=self._num_replicas_in_sync)
 
   def _distribute_datasets_from_function(self, dataset_fn, options):
     input_workers = self._get_input_workers(options)