From a7f8535480bb204939954ad31bf23b190920e1e6 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Wed, 30 Sep 2020 11:50:44 -0700 Subject: [PATCH] [tf.data + tf.distribute] Use RebatchDataset instead of LegacyRebatchDataset in distribution strategies when global batch size can be statically determined. PiperOrigin-RevId: 334646795 Change-Id: I6ed7a09e38577c1b1d3487473e12a48c92c2a2c7 --- .../data/experimental/ops/distribute.py | 4 +- .../collective_all_reduce_strategy.py | 4 +- tensorflow/python/distribute/input_lib.py | 148 +++++++---- .../python/distribute/input_lib_test.py | 248 +++++++++++++++--- tensorflow/python/distribute/input_ops.py | 12 +- .../python/distribute/mirrored_strategy.py | 4 +- .../distribute/parameter_server_strategy.py | 4 +- tensorflow/python/distribute/tpu_strategy.py | 4 +- 8 files changed, 332 insertions(+), 96 deletions(-) diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py index 5105f30fd07..568c01646de 100644 --- a/tensorflow/python/data/experimental/ops/distribute.py +++ b/tensorflow/python/data/experimental/ops/distribute.py @@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset): return self._element_spec -def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name +def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name return dataset_ops.DatasetV1Adapter( - _AutoShardDataset(input_dataset, num_workers, index)) + _AutoShardDataset(input_dataset, num_workers, index, num_replicas)) class _RebatchDataset(dataset_ops.UnaryDataset): diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index a3e63e8a6f1..363b84acc51 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -476,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync, + num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context) def _distribute_datasets_from_function(self, dataset_fn, options): @@ -505,7 +505,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync, + num_replicas_in_sync=self._num_replicas_in_sync, input_context=input_context) def _make_input_fn_iterator( diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index df81fee3e37..d01cedcead0 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls def get_distributed_dataset(dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Returns a distributed dataset from the given tf.data.Dataset instance. @@ -77,8 +77,10 @@ def get_distributed_dataset(dataset, iterators should be created. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. 
If this is not None, the value is + used to decide how to rebatch datasets into smaller batches so that + the total batch size for each step (across all workers and replicas) + adds up to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -92,14 +94,14 @@ def get_distributed_dataset(dataset, dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return DistributedDatasetV1( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) @@ -917,20 +919,24 @@ class DistributedDataset(_IterableInput): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Distribute the dataset on all workers. - If `split_batch_by` is not None, we "split" each batch of the dataset by - `split_batch_by` value. + If `num_replicas_in_sync` is not None, we split each batch of the dataset + into `num_replicas_in_sync` smaller batches, to be distributed among that + worker's replicas, so that the batch size for a global step (across all + workers and replicas) is as expected. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. If this is not None, the value + is used to decide how to rebatch datasets into smaller batches so that + the total batch size for each step (across all workers and replicas) + adds up to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. In these cases, we will shard based on the `input_pipeline_id` and @@ -942,36 +948,29 @@ class DistributedDataset(_IterableInput): # different subset of files. If that is not possible, will attempt to shard # the final input such that each worker will run the entire preprocessing # pipeline and only receive its own shard of the dataset. - if split_batch_by: - try: - # pylint: disable=protected-access - with ops.colocate_with(dataset._variant_tensor): - dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by) - # Add a prefetch to pipeline rebatching for performance. - # TODO(rachelim): Instead of inserting an extra prefetch stage here, - # leverage static graph rewrites to insert _RebatchDataset before - # the final `prefetch` if it exists. - dataset = dataset.prefetch(split_batch_by) - except errors.InvalidArgumentError as e: - if "without encountering a batch" in str(e): - six.reraise( - ValueError, - ValueError( - "Call the `batch` method on the input Dataset in order to be " - "able to split your input across {} replicas.\n Please " - "the tf.distribute.Strategy guide. {}".format( - split_batch_by, e)), - sys.exc_info()[2]) - else: - raise + + # Additionally, we rebatch the dataset on each worker into + # `num_replicas_in_sync` smaller batches to be distributed among that + # worker's replicas, so that the batch size for a global step (across all + # workers and replicas) adds up to the original dataset's batch size. 
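For illustration, the intended effect of `num_replicas_in_sync` can be sketched with public tf.data ops alone. This is a minimal approximation, not the code path added below (which uses the internal rebatch datasets); it assumes one worker with two replicas in sync and a statically known global batch size of 8:

    import tensorflow as tf

    global_batch = 8
    num_replicas_in_sync = 2
    per_replica_batch = global_batch // num_replicas_in_sync  # 4

    # The user-provided dataset is batched to the global batch size.
    ds = tf.data.Dataset.range(16).batch(global_batch)

    # Approximate rebatching with public ops: each global batch is split into
    # per-replica batches, so one global step (num_replicas_in_sync of these)
    # still consumes the original global batch.
    rebatched = ds.unbatch().batch(per_replica_batch)

    for batch in rebatched:
      print(batch.numpy())  # four batches of size 4
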
+ if num_replicas_in_sync is not None: + num_workers = input_context.num_input_pipelines if input_context else len( + input_workers.worker_devices) + rebatch_fn = self._make_rebatch_fn(dataset, num_workers, + num_replicas_in_sync) + else: + rebatch_fn = None self._cloned_datasets = [] if input_context: # Between-graph where we rely on the input_context for sharding assert input_workers.num_workers == 1 + if rebatch_fn is not None: + dataset = rebatch_fn(dataset, input_context.input_pipeline_id) dataset = input_ops.auto_shard_dataset(dataset, input_context.num_input_pipelines, - input_context.input_pipeline_id) + input_context.input_pipeline_id, + num_replicas_in_sync) self._cloned_datasets.append(dataset) else: replicated_ds = distribute.replicate(dataset, @@ -980,16 +979,73 @@ class DistributedDataset(_IterableInput): with ops.device(worker): cloned_dataset = replicated_ds[worker] cloned_dataset = cloned_dataset.with_options(dataset.options()) + if rebatch_fn is not None: + cloned_dataset = rebatch_fn(cloned_dataset, i) cloned_dataset = input_ops.auto_shard_dataset( - cloned_dataset, len(input_workers.worker_devices), i) + cloned_dataset, len(input_workers.worker_devices), i, + num_replicas_in_sync) self._cloned_datasets.append(cloned_dataset) self._input_workers = input_workers self._strategy = strategy self._enable_get_next_as_optional = _enable_get_next_as_optional( self._strategy, dataset.element_spec) - self._element_spec = _create_distributed_tensor_spec(self._strategy, - dataset.element_spec) # pylint: disable=protected-access + self._element_spec = _create_distributed_tensor_spec( + self._strategy, self._cloned_datasets[0].element_spec) + + def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): + """Returns a callable that rebatches the input dataset. + + Args: + dataset: A `tf.data.Dataset` representing the dataset to be distributed. + num_workers: An integer representing the number of workers to distribute + `dataset` among. + num_replicas_in_sync: An integer representing the number of replicas in + sync across all workers. + """ + if num_replicas_in_sync % num_workers: + raise ValueError( + "tf.distribute expects every worker to have the same number of " + "replicas. However, encountered `num_replicas_in_sync` ({}) that " + "cannot be divided by `num_workers` ({})".format( + num_replicas_in_sync, num_workers)) + + num_replicas_per_worker = num_replicas_in_sync // num_workers + with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access + batch_size = distribute.compute_batch_size(dataset) + + def rebatch_fn(dataset, worker_index): + try: + # pylint: disable=protected-access + def apply_rebatch(): + batch_sizes = distribute.batch_sizes_for_worker( + batch_size, num_workers, num_replicas_per_worker, worker_index) + return distribute._RebatchDataset( + dataset, batch_sizes).prefetch(num_replicas_per_worker) + + def apply_legacy_rebatch(): + return distribute._LegacyRebatchDataset( + dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) + + with ops.colocate_with(dataset._variant_tensor): + return control_flow_ops.cond( + math_ops.not_equal(batch_size, -1), + true_fn=apply_rebatch, + false_fn=apply_legacy_rebatch) + except errors.InvalidArgumentError as e: + if "without encountering a batch" in str(e): + six.reraise( + ValueError, + ValueError( + "Call the `batch` method on the input Dataset in order to be " + "able to split your input across {} replicas.\n Please see " + "the tf.distribute.Strategy guide. 
{}".format( + num_replicas_in_sync, e)), + sys.exc_info()[2]) + else: + raise + + return rebatch_fn def __iter__(self): if not (context.executing_eagerly() or @@ -1040,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): self._input_workers = input_workers super(DistributedDatasetV1, self).__init__( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) def make_one_shot_iterator(self): @@ -1305,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): """Make an iterator for the dataset on given devices. - If `split_batch_by` is not None, we "split" each batch of the - dataset by `split_batch_by` value. + If `num_replicas_in_sync` is not None, we split each batch of the dataset + into `num_replicas_in_sync` smaller batches, to be distributed among that + worker's replicas, so that the batch size for a global step (across all + workers and replicas) is as expected. Args: dataset: `tf.data.Dataset` that will be used as the input source. input_workers: an `InputWorkers` object. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to handle last partial batch. - split_batch_by: Optional integer. If present, we "split" each batch of the - dataset by `split_batch_by` value. + num_replicas_in_sync: Optional integer. If this is not None, the value is + used to decide how to rebatch datasets into smaller batches so that the + total batch size for each step (across all workers and replicas) adds up + to `dataset`'s batch size. input_context: `InputContext` for sharding. Only pass this in for between graph multi-worker cases where there is only one `input_worker`. 
In these cases, we will shard based on the `input_pipeline_id` and @@ -1328,7 +1388,7 @@ class DatasetIterator(DistributedIteratorV1): dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) worker_iterators = _create_iterators_per_worker( dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index ec0b591d710..b266dd25bc0 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -61,7 +61,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - split_batch_by, + num_replicas_in_sync, strategy, input_context=None): # The `input_context` passed in is to shard dataset for @@ -93,7 +93,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) return iterator @@ -101,7 +101,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset, input_workers, - split_batch_by, + num_replicas_in_sync, strategy, input_context=None): if input_type == "dataset": @@ -110,14 +110,14 @@ class DistributedIteratorTestBase(test.TestCase): dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return input_lib.DistributedDatasetV1( dataset, input_workers, strategy, - split_batch_by=split_batch_by, + num_replicas_in_sync=num_replicas_in_sync, input_context=input_context) else: return strategy.distribute_datasets_from_function(dataset) @@ -163,7 +163,7 @@ class DistributedIteratorTestBase(test.TestCase): expected_values, strategy, sess=None, - split_batch_by=None, + num_replicas_in_sync=None, input_context=None): if iteration_type == "for_loop" and not context.executing_eagerly(): self.skipTest("unsupported test combination.") @@ -183,7 +183,7 @@ class DistributedIteratorTestBase(test.TestCase): dataset_or_input_fn, input_workers, devices, - split_batch_by, + num_replicas_in_sync, strategy, input_context=input_context) else: @@ -192,7 +192,7 @@ class DistributedIteratorTestBase(test.TestCase): input_type, dataset_or_input_fn, input_workers, - split_batch_by, + num_replicas_in_sync, strategy, input_context=input_context) @@ -361,10 +361,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -419,10 +416,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -455,10 +449,7 @@ class 
DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs.setdefault(host_device, []) worker_device_pairs[host_device].append(tpu_device) worker_device_pairs = worker_device_pairs.items() - if tf2.enabled(): - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) - else: - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.Dataset.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -493,14 +484,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase, def dataset_fn(ctx): del ctx - if tf2.enabled(): - dataset1 = dataset_ops.DatasetV2.range(10) - dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2) - return dataset_ops.DatasetV2.zip((dataset1, dataset2)) - else: - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -563,7 +550,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] input_workers = input_lib.InputWorkers(worker_device_pairs) - dataset = dataset_ops.DatasetV2.range(10) + dataset = dataset_ops.Dataset.range(10) dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers, distribution) @@ -663,7 +650,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs, expected_values, distribution, - split_batch_by=distribution.num_replicas_in_sync, + num_replicas_in_sync=distribution.num_replicas_in_sync, input_context=distribution.extended._make_input_context()) @combinations.generate( @@ -724,7 +711,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, worker_device_pairs, expected_values, distribution, - split_batch_by=distribution.num_replicas_in_sync, + num_replicas_in_sync=distribution.num_replicas_in_sync, input_context=distribution.extended._make_input_context()) @combinations.generate( @@ -733,14 +720,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type=["dataset"], api_type=["wrap_into_iterator", "wrap_into_dataset"], iteration_type=["get_next", "for_loop"], - split_batch_by=[None, 2], + num_replicas_in_sync=[None, 2], distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.central_storage_strategy_with_gpu_and_cpu ], enable_get_next_as_optional=[True, False])) def testBatchSplitting(self, input_type, api_type, iteration_type, - split_batch_by, distribution, + num_replicas_in_sync, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:CPU:0"])] @@ -750,7 +737,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type, dataset_fn) updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) + batch_size // + num_replicas_in_sync if num_replicas_in_sync else batch_size) expected_values = [[range(i, i+updated_batch_size), range(i+updated_batch_size, i+2*updated_batch_size)] for i in range(0, 100, updated_batch_size*2)] @@ -766,7 +754,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, expected_values, distribution, sess=None, - split_batch_by=split_batch_by) + num_replicas_in_sync=num_replicas_in_sync) @combinations.generate( 
combinations.combine( @@ -774,13 +762,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type=["dataset"], api_type=["wrap_into_dataset"], iteration_type=["get_next", "for_loop"], - split_batch_by=[None, 2], + num_replicas_in_sync=[None, 2], distribution=[ strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False])) def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type, - split_batch_by, distribution, + num_replicas_in_sync, distribution, enable_get_next_as_optional): worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", "/device:GPU:1"])] @@ -796,7 +784,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase, input_type, dataset_fn) updated_batch_size = ( - batch_size // split_batch_by if split_batch_by else batch_size) + batch_size // + num_replicas_in_sync if num_replicas_in_sync else batch_size) expected_values = [ [ # pylint: disable=g-complex-comprehension range(i, i + updated_batch_size), @@ -815,7 +804,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase, expected_values, distribution, sess=None, - split_batch_by=split_batch_by) + num_replicas_in_sync=num_replicas_in_sync) @combinations.generate( combinations.combine( @@ -1249,6 +1238,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, expected_for_sum = 310. self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3) + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ])) + def testMWMSPartialBatch(self, input_type, api_type, iteration_type, + distribution): + # Test case: 2 workers, 1 replica each. + # This test simulates the sharded behavior when we have two files each with + # 12 elements and a global batch size of 8. When we consider the dataset in + # aggregate (non-distributed), there are 24 elements divided into 3 batches + # of size 8. Hence, the correct distributed behavior is for each replica to + # see sub-batches of size 4, over three steps. + def dataset_fn(ctx): + del ctx + dataset = dataset_ops.Dataset.range(12).batch(8) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as there is 1 local + # replica. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. Each worker should rebatch its dataset into + # smaller batches of size 4. 
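To make the expected values below easier to verify, here is a hypothetical pure-Python restatement of the arithmetic described in the comments above; it is only an illustration and assumes, as in this test, that the global batch size divides evenly by the number of replicas in sync:

    def expected_per_worker_batches(elements, global_batch, num_replicas_in_sync):
      per_replica = global_batch // num_replicas_in_sync
      return [elements[i:i + per_replica]
              for i in range(0, len(elements), per_replica)]

    # 12 elements, global batch size 8, 2 replicas in sync -> batches of 4.
    assert expected_per_worker_batches(list(range(12)), 8, 2) == [
        [0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

Because the batch size is statically known here, the final partial global batch (elements 8..11) is still emitted as a single per-replica batch of 4 rather than being split further.
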
+ expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]] + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ])) + def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type, + iteration_type, distribution): + # Test case: 2 workers, 1 replica each. + # This test simulates the sharded behavior when we have two files each with + # 12 elements and a global batch size of 8. When we consider the dataset in + # aggregate (non-distributed), there are 24 elements divided into 3 batches + # of size 8. Hence, the correct distributed behavior is for each replica to + # see sub-batches of size 4, over three steps. However, when we create a + # DistributedDataset and cannot statically infer the intended global batch + # size (e.g. if the user does not use a batching dataset), each worker will + # rebatch based on the dynamic batch size of the data encountered, even when + # it encounters partial batches. The last per-worker partial batch (size 4) + # ends up being split into two replicas, resulting in 4 steps in total, of + # (global) batch sizes 8, 8, 4, 4. + def dataset_fn(ctx): + del ctx + # The following dataset is equivalent to + # tf.data.Dataset.range(12).batch(8), but does not use a batching dataset. + # This causes DistributedDataset to use LegacyRebatch instead. + batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4]) + offsets = dataset_ops.Dataset.from_tensor_slices([0, 8]) + dataset = dataset_ops.Dataset.zip((offsets, batch_sizes)) + + def map_fn(offset, batch_size): + return math_ops.range(offset, offset + batch_size) + + dataset = dataset.map(map_fn) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is equivalent to + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as the number of global + # replicas is 2. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. Each worker should rebatch its dataset into + # smaller batches of size 4. 
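The legacy behavior described above can likewise be modeled in plain Python. `legacy_rebatch` below is a hypothetical stand-in for `_LegacyRebatchDataset`, and the ceil-based split is an assumption about how uneven batches are divided; it is included only to show why this test expects four steps, with the last two minibatches having size 2 rather than 4:

    def legacy_rebatch(batches, num_replicas_in_sync):
      # Split each encountered batch into `num_replicas_in_sync` minibatches,
      # whatever its dynamic size happens to be.
      out = []
      for batch in batches:
        size = -(-len(batch) // num_replicas_in_sync)  # ceiling division
        out.extend(batch[i:i + size] for i in range(0, len(batch), size))
      return out

    # Each worker sees batches of sizes 8 and 4 (the partial batch), so legacy
    # rebatching yields minibatches of sizes 4, 4, 2, 2.
    per_worker_batches = [list(range(0, 8)), list(range(8, 12))]
    assert legacy_rebatch(per_worker_batches, 2) == [
        [0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [10, 11]]
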
+ expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]] + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + distribution=[ + strategy_combinations.multi_worker_mirrored_2x1_cpu, + strategy_combinations.multi_worker_mirrored_2x1_gpu, + ], + auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA])) + def testMWMSWithDataSharding(self, input_type, api_type, iteration_type, + distribution, auto_shard_policy): + # Test case: 2 workers, 1 replica each. + # This test simulates the sharded behavior the dataset is sharded by data + # and the batch size is indivisible by the number of replicas. This checks + # that the elements are as expected and the batch size across all workers + # adds up to 3. This test will only pass if the autoshard rewrite rewrites + # RebatchDatasetV2 to legacy RebatchDataset when sharding by data. + def dataset_fn(ctx): + del ctx + dataset = dataset_ops.Dataset.range(8).batch(3) + + # Set the sharding behavior to OFF for simplicity of test setup; namely, + # `dataset` defines the per-worker dataset and will not be further + # sharded. Each worker will see a dataset that is + # tf.data.Dataset.range(12).batch(8).rebatch(...). + options = dataset_ops.Options() + options.experimental_distribute.auto_shard_policy = auto_shard_policy + dataset = dataset.with_options(options) + return dataset + + dataset = self._create_dataset_or_input_fn(input_type, dataset_fn) + + # Actual devices don't matter in this test as long as there is 1 local + # replica. + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + + # Each test runs individually on each worker, so we compare the + # values on each worker. We expect each worker to see different shards of + # data. + cr = distribution.cluster_resolver + worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type, + cr.task_id) + + if worker_id == 0: + expected_values = [[[0, 1]], [[3, 4]], [[6]]] + elif worker_id == 1: + expected_values = [[[2]], [[5]], [[7]]] + + self._test_input_iteration( + input_type, + api_type, + iteration_type, + dataset, + worker_device_pairs, + expected_values, + distribution, + num_replicas_in_sync=distribution.num_replicas_in_sync, + input_context=distribution.extended._make_input_context()) + if __name__ == "__main__": combinations.main() diff --git a/tensorflow/python/distribute/input_ops.py b/tensorflow/python/distribute/input_ops.py index 37a7ed642d0..de828f4bcd9 100644 --- a/tensorflow/python/distribute/input_ops.py +++ b/tensorflow/python/distribute/input_ops.py @@ -27,7 +27,7 @@ from tensorflow.python.framework import ops # pylint: disable=protected-access -def auto_shard_dataset(dataset, num_shards, index): +def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None): """Shard the input pipeline by sharding the underlying list of files. Args: @@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index): shards operating in parallel. Same usage as in `tf.data.Dataset.shard`. index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. Same usage as in `tf.data.Dataset.shard`. 
+ num_replicas_in_sync: An integer representing the total number of replicas + across all workers. This is used in the rewrite when sharding by data. Returns: A modified `Dataset` obtained by updating the pipeline sharded by the @@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index): """ if (dataset.options().experimental_distribute.auto_shard_policy != AutoShardPolicy.OFF): + if num_replicas_in_sync is None: + num_replicas_in_sync = 1 if isinstance(dataset, dataset_ops.DatasetV1): - return distribute._AutoShardDatasetV1(dataset, num_shards, index) + return distribute._AutoShardDatasetV1(dataset, num_shards, index, + num_replicas_in_sync) else: - return distribute._AutoShardDataset(dataset, num_shards, index) + return distribute._AutoShardDataset(dataset, num_shards, index, + num_replicas_in_sync) else: return dataset diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 523c71c4fb5..255c32ff50e 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _experimental_make_numpy_dataset(self, numpy_input, session): return numpy_dataset.one_host_numpy_dataset( diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index b60ea74dd04..b680e062ab1 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -351,14 +351,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): dataset, self._input_workers_with_options(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_dataset_iterator(self, dataset): return input_lib.DatasetIterator( dataset, self._input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index 1a3d49a2032..525d2ff5de1 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -752,7 +752,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, input_workers, self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _make_input_fn_iterator( self, @@ -809,7 +809,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): dataset, self._get_input_workers(options), self._container_strategy(), - split_batch_by=self._num_replicas_in_sync) + num_replicas_in_sync=self._num_replicas_in_sync) def _distribute_datasets_from_function(self, dataset_fn, options): input_workers = self._get_input_workers(options)
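Returning to `testMWMSWithDataSharding` above: the expected per-worker values follow from splitting each statically batched input batch into `num_replicas_in_sync` minibatches and then letting worker `i` keep every `num_workers`-th minibatch starting at `i`. The sketch below is a hypothetical model of that observable behavior, not the autoshard rewrite itself; the ceil-based split and the round-robin order are assumptions chosen to reproduce the test's expectations:

    def split_batch(batch, num_replicas_in_sync):
      size = -(-len(batch) // num_replicas_in_sync)  # ceiling division
      return [batch[i:i + size] for i in range(0, len(batch), size)]

    # tf.data.Dataset.range(8).batch(3) yields batches [0,1,2], [3,4,5], [6,7].
    batches = [[0, 1, 2], [3, 4, 5], [6, 7]]
    minibatches = [m for b in batches for m in split_batch(b, 2)]

    worker0 = minibatches[0::2]
    worker1 = minibatches[1::2]
    assert worker0 == [[0, 1], [3, 4], [6]]
    assert worker1 == [[2], [5], [7]]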
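
Finally, a usage-level view of where these changes take effect. The sketch below uses only public APIs and default device placement; the batch sizes are illustrative. The rebatching discussed in this patch happens inside `experimental_distribute_dataset`, so each replica receives per-replica batches of `global_batch // num_replicas_in_sync` when the global batch size can be statically determined:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()  # e.g. one worker, N replicas
    global_batch = 8

    dataset = tf.data.Dataset.range(32).batch(global_batch)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = (
        tf.data.experimental.AutoShardPolicy.DATA)
    dataset = dataset.with_options(options)

    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for per_replica_batch in dist_dataset:
      # Each component of the per-replica value has the per-replica batch size.
      print(strategy.experimental_local_results(per_replica_batch))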