[tf.data + tf.distribute] Use RebatchDataset instead of LegacyRebatchDataset in distribution strategies when global batch size can be statically determined.

PiperOrigin-RevId: 334646795
Change-Id: I6ed7a09e38577c1b1d3487473e12a48c92c2a2c7
This commit is contained in:
Rachel Lim 2020-09-30 11:50:44 -07:00 committed by TensorFlower Gardener
parent 8e14fdb6a2
commit a7f8535480
8 changed files with 332 additions and 96 deletions

View File

@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
return self._element_spec return self._element_spec
def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name
return dataset_ops.DatasetV1Adapter( return dataset_ops.DatasetV1Adapter(
_AutoShardDataset(input_dataset, num_workers, index)) _AutoShardDataset(input_dataset, num_workers, index, num_replicas))
class _RebatchDataset(dataset_ops.UnaryDataset): class _RebatchDataset(dataset_ops.UnaryDataset):

View File

@ -476,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
dataset, dataset,
self._input_workers_with_options(options), self._input_workers_with_options(options),
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync, num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context) input_context=input_context)
def _distribute_datasets_from_function(self, dataset_fn, options): def _distribute_datasets_from_function(self, dataset_fn, options):
@ -505,7 +505,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
dataset, dataset,
self._input_workers, self._input_workers,
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync, num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context) input_context=input_context)
def _make_input_fn_iterator( def _make_input_fn_iterator(

View File

@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls
def get_distributed_dataset(dataset, def get_distributed_dataset(dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=None, num_replicas_in_sync=None,
input_context=None): input_context=None):
"""Returns a distributed dataset from the given tf.data.Dataset instance. """Returns a distributed dataset from the given tf.data.Dataset instance.
@ -77,8 +77,10 @@ def get_distributed_dataset(dataset,
iterators should be created. iterators should be created.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch. handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the num_replicas_in_sync: Optional integer. If this is not None, the value is
dataset by `split_batch_by` value. used to decide how to rebatch datasets into smaller batches so that
the total batch size for each step (across all workers and replicas)
adds up to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and these cases, we will shard based on the `input_pipeline_id` and
@ -92,14 +94,14 @@ def get_distributed_dataset(dataset,
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
else: else:
return DistributedDatasetV1( return DistributedDatasetV1(
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
@ -917,20 +919,24 @@ class DistributedDataset(_IterableInput):
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=None, num_replicas_in_sync=None,
input_context=None): input_context=None):
"""Distribute the dataset on all workers. """Distribute the dataset on all workers.
If `split_batch_by` is not None, we "split" each batch of the dataset by If `num_replicas_in_sync` is not None, we split each batch of the dataset
`split_batch_by` value. into `num_replicas_in_sync` smaller batches, to be distributed among that
worker's replicas, so that the batch size for a global step (across all
workers and replicas) is as expected.
Args: Args:
dataset: `tf.data.Dataset` that will be used as the input source. dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object. input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch. handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the num_replicas_in_sync: Optional integer. If this is not None, the value
dataset by `split_batch_by` value. is used to decide how to rebatch datasets into smaller batches so that
the total batch size for each step (across all workers and replicas)
adds up to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and these cases, we will shard based on the `input_pipeline_id` and
@ -942,36 +948,29 @@ class DistributedDataset(_IterableInput):
# different subset of files. If that is not possible, will attempt to shard # different subset of files. If that is not possible, will attempt to shard
# the final input such that each worker will run the entire preprocessing # the final input such that each worker will run the entire preprocessing
# pipeline and only receive its own shard of the dataset. # pipeline and only receive its own shard of the dataset.
if split_batch_by:
try: # Additionally, we rebatch the dataset on each worker into
# pylint: disable=protected-access # `num_replicas_in_sync` smaller batches to be distributed among that
with ops.colocate_with(dataset._variant_tensor): # worker's replicas, so that the batch size for a global step (across all
dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by) # workers and replicas) adds up to the original dataset's batch size.
# Add a prefetch to pipeline rebatching for performance. if num_replicas_in_sync is not None:
# TODO(rachelim): Instead of inserting an extra prefetch stage here, num_workers = input_context.num_input_pipelines if input_context else len(
# leverage static graph rewrites to insert _RebatchDataset before input_workers.worker_devices)
# the final `prefetch` if it exists. rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
dataset = dataset.prefetch(split_batch_by) num_replicas_in_sync)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please "
"the tf.distribute.Strategy guide. {}".format(
split_batch_by, e)),
sys.exc_info()[2])
else: else:
raise rebatch_fn = None
self._cloned_datasets = [] self._cloned_datasets = []
if input_context: if input_context:
# Between-graph where we rely on the input_context for sharding # Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1 assert input_workers.num_workers == 1
if rebatch_fn is not None:
dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
dataset = input_ops.auto_shard_dataset(dataset, dataset = input_ops.auto_shard_dataset(dataset,
input_context.num_input_pipelines, input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id,
num_replicas_in_sync)
self._cloned_datasets.append(dataset) self._cloned_datasets.append(dataset)
else: else:
replicated_ds = distribute.replicate(dataset, replicated_ds = distribute.replicate(dataset,
@ -980,16 +979,73 @@ class DistributedDataset(_IterableInput):
with ops.device(worker): with ops.device(worker):
cloned_dataset = replicated_ds[worker] cloned_dataset = replicated_ds[worker]
cloned_dataset = cloned_dataset.with_options(dataset.options()) cloned_dataset = cloned_dataset.with_options(dataset.options())
if rebatch_fn is not None:
cloned_dataset = rebatch_fn(cloned_dataset, i)
cloned_dataset = input_ops.auto_shard_dataset( cloned_dataset = input_ops.auto_shard_dataset(
cloned_dataset, len(input_workers.worker_devices), i) cloned_dataset, len(input_workers.worker_devices), i,
num_replicas_in_sync)
self._cloned_datasets.append(cloned_dataset) self._cloned_datasets.append(cloned_dataset)
self._input_workers = input_workers self._input_workers = input_workers
self._strategy = strategy self._strategy = strategy
self._enable_get_next_as_optional = _enable_get_next_as_optional( self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, dataset.element_spec) self._strategy, dataset.element_spec)
self._element_spec = _create_distributed_tensor_spec(self._strategy, self._element_spec = _create_distributed_tensor_spec(
dataset.element_spec) # pylint: disable=protected-access self._strategy, self._cloned_datasets[0].element_spec)
def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
"""Returns a callable that rebatches the input dataset.
Args:
dataset: A `tf.data.Dataset` representing the dataset to be distributed.
num_workers: An integer representing the number of workers to distribute
`dataset` among.
num_replicas_in_sync: An integer representing the number of replicas in
sync across all workers.
"""
if num_replicas_in_sync % num_workers:
raise ValueError(
"tf.distribute expects every worker to have the same number of "
"replicas. However, encountered `num_replicas_in_sync` ({}) that "
"cannot be divided by `num_workers` ({})".format(
num_replicas_in_sync, num_workers))
num_replicas_per_worker = num_replicas_in_sync // num_workers
with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access
batch_size = distribute.compute_batch_size(dataset)
def rebatch_fn(dataset, worker_index):
try:
# pylint: disable=protected-access
def apply_rebatch():
batch_sizes = distribute.batch_sizes_for_worker(
batch_size, num_workers, num_replicas_per_worker, worker_index)
return distribute._RebatchDataset(
dataset, batch_sizes).prefetch(num_replicas_per_worker)
def apply_legacy_rebatch():
return distribute._LegacyRebatchDataset(
dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
with ops.colocate_with(dataset._variant_tensor):
return control_flow_ops.cond(
math_ops.not_equal(batch_size, -1),
true_fn=apply_rebatch,
false_fn=apply_legacy_rebatch)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please see "
"the tf.distribute.Strategy guide. {}".format(
num_replicas_in_sync, e)),
sys.exc_info()[2])
else:
raise
return rebatch_fn
def __iter__(self): def __iter__(self):
if not (context.executing_eagerly() or if not (context.executing_eagerly() or
@ -1040,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset):
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=None, num_replicas_in_sync=None,
input_context=None): input_context=None):
self._input_workers = input_workers self._input_workers = input_workers
super(DistributedDatasetV1, self).__init__( super(DistributedDatasetV1, self).__init__(
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
def make_one_shot_iterator(self): def make_one_shot_iterator(self):
@ -1305,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1):
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=None, num_replicas_in_sync=None,
input_context=None): input_context=None):
"""Make an iterator for the dataset on given devices. """Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the If `num_replicas_in_sync` is not None, we split each batch of the dataset
dataset by `split_batch_by` value. into `num_replicas_in_sync` smaller batches, to be distributed among that
worker's replicas, so that the batch size for a global step (across all
workers and replicas) is as expected.
Args: Args:
dataset: `tf.data.Dataset` that will be used as the input source. dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object. input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch. handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the num_replicas_in_sync: Optional integer. If this is not None, the value is
dataset by `split_batch_by` value. used to decide how to rebatch datasets into smaller batches so that the
total batch size for each step (across all workers and replicas) adds up
to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and these cases, we will shard based on the `input_pipeline_id` and
@ -1328,7 +1388,7 @@ class DatasetIterator(DistributedIteratorV1):
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
worker_iterators = _create_iterators_per_worker( worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access

View File

@ -61,7 +61,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn, dataset_or_input_fn,
input_workers, input_workers,
devices, devices,
split_batch_by, num_replicas_in_sync,
strategy, strategy,
input_context=None): input_context=None):
# The `input_context` passed in is to shard dataset for # The `input_context` passed in is to shard dataset for
@ -93,7 +93,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn, dataset_or_input_fn,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
return iterator return iterator
@ -101,7 +101,7 @@ class DistributedIteratorTestBase(test.TestCase):
input_type, input_type,
dataset, dataset,
input_workers, input_workers,
split_batch_by, num_replicas_in_sync,
strategy, strategy,
input_context=None): input_context=None):
if input_type == "dataset": if input_type == "dataset":
@ -110,14 +110,14 @@ class DistributedIteratorTestBase(test.TestCase):
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
else: else:
return input_lib.DistributedDatasetV1( return input_lib.DistributedDatasetV1(
dataset, dataset,
input_workers, input_workers,
strategy, strategy,
split_batch_by=split_batch_by, num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context) input_context=input_context)
else: else:
return strategy.distribute_datasets_from_function(dataset) return strategy.distribute_datasets_from_function(dataset)
@ -163,7 +163,7 @@ class DistributedIteratorTestBase(test.TestCase):
expected_values, expected_values,
strategy, strategy,
sess=None, sess=None,
split_batch_by=None, num_replicas_in_sync=None,
input_context=None): input_context=None):
if iteration_type == "for_loop" and not context.executing_eagerly(): if iteration_type == "for_loop" and not context.executing_eagerly():
self.skipTest("unsupported test combination.") self.skipTest("unsupported test combination.")
@ -183,7 +183,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn, dataset_or_input_fn,
input_workers, input_workers,
devices, devices,
split_batch_by, num_replicas_in_sync,
strategy, strategy,
input_context=input_context) input_context=input_context)
else: else:
@ -192,7 +192,7 @@ class DistributedIteratorTestBase(test.TestCase):
input_type, input_type,
dataset_or_input_fn, dataset_or_input_fn,
input_workers, input_workers,
split_batch_by, num_replicas_in_sync,
strategy, strategy,
input_context=input_context) input_context=input_context)
@ -361,10 +361,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
enable_get_next_as_optional): enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
if tf2.enabled(): dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn( dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn) input_type, dataset_fn)
@ -419,9 +416,6 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
distribution, enable_get_next_as_optional): distribution, enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])] "/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn( dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn) input_type, dataset_fn)
@ -455,9 +449,6 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs.setdefault(host_device, []) worker_device_pairs.setdefault(host_device, [])
worker_device_pairs[host_device].append(tpu_device) worker_device_pairs[host_device].append(tpu_device)
worker_device_pairs = worker_device_pairs.items() worker_device_pairs = worker_device_pairs.items()
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn( dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn) input_type, dataset_fn)
@ -493,14 +484,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
def dataset_fn(ctx): def dataset_fn(ctx):
del ctx del ctx
if tf2.enabled():
dataset1 = dataset_ops.DatasetV2.range(10)
dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
return dataset_ops.DatasetV2.zip((dataset1, dataset2))
else:
dataset1 = dataset_ops.Dataset.range(10) dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2)) return dataset_ops.Dataset.zip((dataset1, dataset2))
dataset_or_input_fn = self._create_dataset_or_input_fn( dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn) input_type, dataset_fn)
@ -563,7 +550,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs) input_workers = input_lib.InputWorkers(worker_device_pairs)
dataset = dataset_ops.DatasetV2.range(10) dataset = dataset_ops.Dataset.range(10)
dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
distribution) distribution)
@ -663,7 +650,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs, worker_device_pairs,
expected_values, expected_values,
distribution, distribution,
split_batch_by=distribution.num_replicas_in_sync, num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context()) input_context=distribution.extended._make_input_context())
@combinations.generate( @combinations.generate(
@ -724,7 +711,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs, worker_device_pairs,
expected_values, expected_values,
distribution, distribution,
split_batch_by=distribution.num_replicas_in_sync, num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context()) input_context=distribution.extended._make_input_context())
@combinations.generate( @combinations.generate(
@ -733,14 +720,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type=["dataset"], input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"], api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"], iteration_type=["get_next", "for_loop"],
split_batch_by=[None, 2], num_replicas_in_sync=[None, 2],
distribution=[ distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu strategy_combinations.central_storage_strategy_with_gpu_and_cpu
], ],
enable_get_next_as_optional=[True, False])) enable_get_next_as_optional=[True, False]))
def testBatchSplitting(self, input_type, api_type, iteration_type, def testBatchSplitting(self, input_type, api_type, iteration_type,
split_batch_by, distribution, num_replicas_in_sync, distribution,
enable_get_next_as_optional): enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])] "/device:CPU:0"])]
@ -750,7 +737,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type, dataset_fn) input_type, dataset_fn)
updated_batch_size = ( updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size) batch_size //
num_replicas_in_sync if num_replicas_in_sync else batch_size)
expected_values = [[range(i, i+updated_batch_size), expected_values = [[range(i, i+updated_batch_size),
range(i+updated_batch_size, i+2*updated_batch_size)] range(i+updated_batch_size, i+2*updated_batch_size)]
for i in range(0, 100, updated_batch_size*2)] for i in range(0, 100, updated_batch_size*2)]
@ -766,7 +754,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
expected_values, expected_values,
distribution, distribution,
sess=None, sess=None,
split_batch_by=split_batch_by) num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
@ -774,13 +762,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type=["dataset"], input_type=["dataset"],
api_type=["wrap_into_dataset"], api_type=["wrap_into_dataset"],
iteration_type=["get_next", "for_loop"], iteration_type=["get_next", "for_loop"],
split_batch_by=[None, 2], num_replicas_in_sync=[None, 2],
distribution=[ distribution=[
strategy_combinations.multi_worker_mirrored_2x2_gpu, strategy_combinations.multi_worker_mirrored_2x2_gpu,
], ],
enable_get_next_as_optional=[True, False])) enable_get_next_as_optional=[True, False]))
def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type, def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
split_batch_by, distribution, num_replicas_in_sync, distribution,
enable_get_next_as_optional): enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:GPU:1"])] "/device:GPU:1"])]
@ -796,7 +784,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type, dataset_fn) input_type, dataset_fn)
updated_batch_size = ( updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size) batch_size //
num_replicas_in_sync if num_replicas_in_sync else batch_size)
expected_values = [ expected_values = [
[ # pylint: disable=g-complex-comprehension [ # pylint: disable=g-complex-comprehension
range(i, i + updated_batch_size), range(i, i + updated_batch_size),
@ -815,7 +804,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
expected_values, expected_values,
distribution, distribution,
sess=None, sess=None,
split_batch_by=split_batch_by) num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
@ -1249,6 +1238,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
expected_for_sum = 310. expected_for_sum = 310.
self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3) self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
]))
def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
distribution):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior when we have two files each with
# 12 elements and a global batch size of 8. When we consider the dataset in
# aggregate (non-distributed), there are 24 elements divided into 3 batches
# of size 8. Hence, the correct distributed behavior is for each replica to
# see sub-batches of size 4, over three steps.
def dataset_fn(ctx):
del ctx
dataset = dataset_ops.Dataset.range(12).batch(8)
# Set the sharding behavior to OFF for simplicity of test setup; namely,
# `dataset` defines the per-worker dataset and will not be further
# sharded. Each worker will see a dataset that is
# tf.data.Dataset.range(12).batch(8).rebatch(...).
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as there is 1 local
# replica.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. Each worker should rebatch its dataset into
# smaller batches of size 4.
expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
]))
def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
iteration_type, distribution):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior when we have two files each with
# 12 elements and a global batch size of 8. When we consider the dataset in
# aggregate (non-distributed), there are 24 elements divided into 3 batches
# of size 8. Hence, the correct distributed behavior is for each replica to
# see sub-batches of size 4, over three steps. However, when we create a
# DistributedDataset and cannot statically infer the intended global batch
# size (e.g. if the user does not use a batching dataset), each worker will
# rebatch based on the dynamic batch size of the data encountered, even when
# it encounters partial batches. The last per-worker partial batch (size 4)
# ends up being split into two replicas, resulting in 4 steps in total, of
# (global) batch sizes 8, 8, 4, 4.
def dataset_fn(ctx):
del ctx
# The following dataset is equivalent to
# tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
# This causes DistributedDataset to use LegacyRebatch instead.
batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
def map_fn(offset, batch_size):
return math_ops.range(offset, offset + batch_size)
dataset = dataset.map(map_fn)
# Set the sharding behavior to OFF for simplicity of test setup; namely,
# `dataset` defines the per-worker dataset and will not be further
# sharded. Each worker will see a dataset that is equivalent to
# tf.data.Dataset.range(12).batch(8).rebatch(...).
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as the number of global
# replicas is 2.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. Each worker should rebatch its dataset into
# smaller batches of size 4.
expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
],
auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
distribution, auto_shard_policy):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior the dataset is sharded by data
# and the batch size is indivisible by the number of replicas. This checks
# that the elements are as expected and the batch size across all workers
# adds up to 3. This test will only pass if the autoshard rewrite rewrites
# RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
def dataset_fn(ctx):
del ctx
dataset = dataset_ops.Dataset.range(8).batch(3)
# Set the sharding behavior to OFF for simplicity of test setup; namely,
# `dataset` defines the per-worker dataset and will not be further
# sharded. Each worker will see a dataset that is
# tf.data.Dataset.range(12).batch(8).rebatch(...).
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = auto_shard_policy
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as there is 1 local
# replica.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. We expect each worker to see different shards of
# data.
cr = distribution.cluster_resolver
worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
cr.task_id)
if worker_id == 0:
expected_values = [[[0, 1]], [[3, 4]], [[6]]]
elif worker_id == 1:
expected_values = [[[2]], [[5]], [[7]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
if __name__ == "__main__": if __name__ == "__main__":
combinations.main() combinations.main()

View File

@ -27,7 +27,7 @@ from tensorflow.python.framework import ops
# pylint: disable=protected-access # pylint: disable=protected-access
def auto_shard_dataset(dataset, num_shards, index): def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
"""Shard the input pipeline by sharding the underlying list of files. """Shard the input pipeline by sharding the underlying list of files.
Args: Args:
@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index):
shards operating in parallel. Same usage as in `tf.data.Dataset.shard`. shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Same usage as in `tf.data.Dataset.shard`. Same usage as in `tf.data.Dataset.shard`.
num_replicas_in_sync: An integer representing the total number of replicas
across all workers. This is used in the rewrite when sharding by data.
Returns: Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the A modified `Dataset` obtained by updating the pipeline sharded by the
@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index):
""" """
if (dataset.options().experimental_distribute.auto_shard_policy != if (dataset.options().experimental_distribute.auto_shard_policy !=
AutoShardPolicy.OFF): AutoShardPolicy.OFF):
if num_replicas_in_sync is None:
num_replicas_in_sync = 1
if isinstance(dataset, dataset_ops.DatasetV1): if isinstance(dataset, dataset_ops.DatasetV1):
return distribute._AutoShardDatasetV1(dataset, num_shards, index) return distribute._AutoShardDatasetV1(dataset, num_shards, index,
num_replicas_in_sync)
else: else:
return distribute._AutoShardDataset(dataset, num_shards, index) return distribute._AutoShardDataset(dataset, num_shards, index,
num_replicas_in_sync)
else: else:
return dataset return dataset

View File

@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
dataset, dataset,
self._input_workers, self._input_workers,
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator( def _make_input_fn_iterator(
self, self,
@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
dataset, dataset,
self._input_workers_with_options(options), self._input_workers_with_options(options),
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _experimental_make_numpy_dataset(self, numpy_input, session): def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset( return numpy_dataset.one_host_numpy_dataset(

View File

@ -351,14 +351,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
dataset, dataset,
self._input_workers_with_options(options), self._input_workers_with_options(options),
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _make_dataset_iterator(self, dataset): def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator( return input_lib.DatasetIterator(
dataset, dataset,
self._input_workers, self._input_workers,
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator( def _make_input_fn_iterator(
self, self,

View File

@ -752,7 +752,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
dataset, dataset,
input_workers, input_workers,
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator( def _make_input_fn_iterator(
self, self,
@ -809,7 +809,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
dataset, dataset,
self._get_input_workers(options), self._get_input_workers(options),
self._container_strategy(), self._container_strategy(),
split_batch_by=self._num_replicas_in_sync) num_replicas_in_sync=self._num_replicas_in_sync)
def _distribute_datasets_from_function(self, dataset_fn, options): def _distribute_datasets_from_function(self, dataset_fn, options):
input_workers = self._get_input_workers(options) input_workers = self._get_input_workers(options)