[tf.data + tf.distribute] Use RebatchDataset instead of LegacyRebatchDataset in distribution strategies when global batch size can be statically determined.

PiperOrigin-RevId: 334646795
Change-Id: I6ed7a09e38577c1b1d3487473e12a48c92c2a2c7
Rachel Lim 2020-09-30 11:50:44 -07:00 committed by TensorFlower Gardener
parent 8e14fdb6a2
commit a7f8535480
8 changed files with 332 additions and 96 deletions
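In short: distribution strategies previously always inserted _LegacyRebatchDataset, which splits whatever batch it encounters at runtime into `num_replicas_in_sync` pieces. With this change, when tf.data can statically infer the dataset's global batch size, the exact _RebatchDataset is used instead, with per-replica batch sizes computed up front; the legacy op remains the fallback. A minimal sketch of the selection logic this commit adds, using the private tf.data helpers the diff touches (the worker and replica counts are illustrative):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

num_workers = 2
num_replicas_per_worker = 2
worker_index = 0
num_replicas_in_sync = num_workers * num_replicas_per_worker

dataset = tf.data.Dataset.range(32).batch(8)  # global batch size 8

# compute_batch_size returns the statically inferred batch size,
# or -1 when it cannot be determined from the pipeline.
batch_size = distribute.compute_batch_size(dataset)

def apply_rebatch():
  # Exact per-replica batch sizes for this worker; with a global batch of
  # 8 over 4 replicas, every entry is 2.
  batch_sizes = distribute.batch_sizes_for_worker(
      batch_size, num_workers, num_replicas_per_worker, worker_index)
  return distribute._RebatchDataset(dataset, batch_sizes)

def apply_legacy_rebatch():
  # Fallback: divide each encountered batch by num_replicas_in_sync.
  return distribute._LegacyRebatchDataset(dataset, num_replicas_in_sync)

rebatched = control_flow_ops.cond(
    math_ops.not_equal(batch_size, -1),
    true_fn=apply_rebatch,
    false_fn=apply_legacy_rebatch)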

tensorflow/python/data/experimental/ops/distribute.py

@ -85,9 +85,9 @@ class _AutoShardDataset(dataset_ops.UnaryDataset):
return self._element_spec
def _AutoShardDatasetV1(input_dataset, num_workers, index): # pylint: disable=invalid-name
def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None): # pylint: disable=invalid-name
return dataset_ops.DatasetV1Adapter(
_AutoShardDataset(input_dataset, num_workers, index))
_AutoShardDataset(input_dataset, num_workers, index, num_replicas))
class _RebatchDataset(dataset_ops.UnaryDataset):

tensorflow/python/distribute/collective_all_reduce_strategy.py

@ -476,7 +476,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
dataset,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync,
num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context)
def _distribute_datasets_from_function(self, dataset_fn, options):
@ -505,7 +505,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync,
num_replicas_in_sync=self._num_replicas_in_sync,
input_context=input_context)
def _make_input_fn_iterator(

tensorflow/python/distribute/input_lib.py

@ -61,7 +61,7 @@ from tensorflow.tools.docs import doc_controls
def get_distributed_dataset(dataset,
input_workers,
strategy,
split_batch_by=None,
num_replicas_in_sync=None,
input_context=None):
"""Returns a distributed dataset from the given tf.data.Dataset instance.
@ -77,8 +77,10 @@ def get_distributed_dataset(dataset,
iterators should be created.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
num_replicas_in_sync: Optional integer. If this is not None, the value is
used to decide how to rebatch datasets into smaller batches so that
the total batch size for each step (across all workers and replicas)
adds up to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
@ -92,14 +94,14 @@ def get_distributed_dataset(dataset,
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
else:
return DistributedDatasetV1(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
@ -917,20 +919,24 @@ class DistributedDataset(_IterableInput):
dataset,
input_workers,
strategy,
split_batch_by=None,
num_replicas_in_sync=None,
input_context=None):
"""Distribute the dataset on all workers.
If `split_batch_by` is not None, we "split" each batch of the dataset by
`split_batch_by` value.
If `num_replicas_in_sync` is not None, we split each batch of the dataset
into `num_replicas_in_sync` smaller batches, to be distributed among that
worker's replicas, so that the batch size for a global step (across all
workers and replicas) is as expected.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
num_replicas_in_sync: Optional integer. If this is not None, the value
is used to decide how to rebatch datasets into smaller batches so that
the total batch size for each step (across all workers and replicas)
adds up to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
@ -942,36 +948,29 @@ class DistributedDataset(_IterableInput):
# different subset of files. If that is not possible, will attempt to shard
# the final input such that each worker will run the entire preprocessing
# pipeline and only receive its own shard of the dataset.
if split_batch_by:
try:
# pylint: disable=protected-access
with ops.colocate_with(dataset._variant_tensor):
dataset = distribute._LegacyRebatchDataset(dataset, split_batch_by)
# Add a prefetch to pipeline rebatching for performance.
# TODO(rachelim): Instead of inserting an extra prefetch stage here,
# leverage static graph rewrites to insert _RebatchDataset before
# the final `prefetch` if it exists.
dataset = dataset.prefetch(split_batch_by)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please see "
"the tf.distribute.Strategy guide. {}".format(
split_batch_by, e)),
sys.exc_info()[2])
else:
raise
# Additionally, we rebatch the dataset on each worker into
# `num_replicas_in_sync` smaller batches to be distributed among that
# worker's replicas, so that the batch size for a global step (across all
# workers and replicas) adds up to the original dataset's batch size.
if num_replicas_in_sync is not None:
num_workers = input_context.num_input_pipelines if input_context else len(
input_workers.worker_devices)
rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
num_replicas_in_sync)
else:
rebatch_fn = None
self._cloned_datasets = []
if input_context:
# Between-graph where we rely on the input_context for sharding
assert input_workers.num_workers == 1
if rebatch_fn is not None:
dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
dataset = input_ops.auto_shard_dataset(dataset,
input_context.num_input_pipelines,
input_context.input_pipeline_id)
input_context.input_pipeline_id,
num_replicas_in_sync)
self._cloned_datasets.append(dataset)
else:
replicated_ds = distribute.replicate(dataset,
@ -980,16 +979,73 @@ class DistributedDataset(_IterableInput):
with ops.device(worker):
cloned_dataset = replicated_ds[worker]
cloned_dataset = cloned_dataset.with_options(dataset.options())
if rebatch_fn is not None:
cloned_dataset = rebatch_fn(cloned_dataset, i)
cloned_dataset = input_ops.auto_shard_dataset(
cloned_dataset, len(input_workers.worker_devices), i)
cloned_dataset, len(input_workers.worker_devices), i,
num_replicas_in_sync)
self._cloned_datasets.append(cloned_dataset)
self._input_workers = input_workers
self._strategy = strategy
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, dataset.element_spec)
self._element_spec = _create_distributed_tensor_spec(self._strategy,
dataset.element_spec) # pylint: disable=protected-access
self._element_spec = _create_distributed_tensor_spec(
self._strategy, self._cloned_datasets[0].element_spec)
def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
"""Returns a callable that rebatches the input dataset.
Args:
dataset: A `tf.data.Dataset` representing the dataset to be distributed.
num_workers: An integer representing the number of workers to distribute
`dataset` among.
num_replicas_in_sync: An integer representing the number of replicas in
sync across all workers.
"""
if num_replicas_in_sync % num_workers:
raise ValueError(
"tf.distribute expects every worker to have the same number of "
"replicas. However, encountered `num_replicas_in_sync` ({}) that "
"cannot be divided by `num_workers` ({})".format(
num_replicas_in_sync, num_workers))
num_replicas_per_worker = num_replicas_in_sync // num_workers
with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access
batch_size = distribute.compute_batch_size(dataset)
def rebatch_fn(dataset, worker_index):
try:
# pylint: disable=protected-access
def apply_rebatch():
batch_sizes = distribute.batch_sizes_for_worker(
batch_size, num_workers, num_replicas_per_worker, worker_index)
return distribute._RebatchDataset(
dataset, batch_sizes).prefetch(num_replicas_per_worker)
def apply_legacy_rebatch():
return distribute._LegacyRebatchDataset(
dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)
with ops.colocate_with(dataset._variant_tensor):
return control_flow_ops.cond(
math_ops.not_equal(batch_size, -1),
true_fn=apply_rebatch,
false_fn=apply_legacy_rebatch)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(
ValueError,
ValueError(
"Call the `batch` method on the input Dataset in order to be "
"able to split your input across {} replicas.\n Please see "
"the tf.distribute.Strategy guide. {}".format(
num_replicas_in_sync, e)),
sys.exc_info()[2])
else:
raise
return rebatch_fn
def __iter__(self):
if not (context.executing_eagerly() or
@ -1040,14 +1096,14 @@ class DistributedDatasetV1(DistributedDataset):
dataset,
input_workers,
strategy,
split_batch_by=None,
num_replicas_in_sync=None,
input_context=None):
self._input_workers = input_workers
super(DistributedDatasetV1, self).__init__(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
def make_one_shot_iterator(self):
@ -1305,20 +1361,24 @@ class DatasetIterator(DistributedIteratorV1):
dataset,
input_workers,
strategy,
split_batch_by=None,
num_replicas_in_sync=None,
input_context=None):
"""Make an iterator for the dataset on given devices.
If `split_batch_by` is not None, we "split" each batch of the
dataset by `split_batch_by` value.
If `num_replicas_in_sync` is not None, we split each batch of the dataset
into `num_replicas_in_sync` smaller batches, to be distributed among that
worker's replicas, so that the batch size for a global step (across all
workers and replicas) is as expected.
Args:
dataset: `tf.data.Dataset` that will be used as the input source.
input_workers: an `InputWorkers` object.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
handle last partial batch.
split_batch_by: Optional integer. If present, we "split" each batch of the
dataset by `split_batch_by` value.
num_replicas_in_sync: Optional integer. If this is not None, the value is
used to decide how to rebatch datasets into smaller batches so that the
total batch size for each step (across all workers and replicas) adds up
to `dataset`'s batch size.
input_context: `InputContext` for sharding. Only pass this in for between
graph multi-worker cases where there is only one `input_worker`. In
these cases, we will shard based on the `input_pipeline_id` and
@ -1328,7 +1388,7 @@ class DatasetIterator(DistributedIteratorV1):
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access
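_make_rebatch_fn above first checks that replicas divide evenly across workers, then spreads the global batch as evenly as possible over all replicas. A small pure-Python illustration of that arithmetic (a simplification, not the TensorFlow implementation, which additionally rotates uneven assignments across steps for fairness):

def per_replica_batch_sizes(global_batch_size, num_workers,
                            num_replicas_in_sync):
  # Mirrors the validation in _make_rebatch_fn: every worker must host
  # the same number of replicas.
  if num_replicas_in_sync % num_workers:
    raise ValueError(
        "`num_replicas_in_sync` ({}) must be divisible by `num_workers` "
        "({})".format(num_replicas_in_sync, num_workers))
  # As-even-as-possible split: the first `remainder` replicas get one
  # extra element.
  base, remainder = divmod(global_batch_size, num_replicas_in_sync)
  return [base + 1 if i < remainder else base
          for i in range(num_replicas_in_sync)]

print(per_replica_batch_sizes(8, num_workers=2, num_replicas_in_sync=4))
# [2, 2, 2, 2]
print(per_replica_batch_sizes(7, num_workers=2, num_replicas_in_sync=4))
# [2, 2, 2, 1]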

tensorflow/python/distribute/input_lib_test.py

@ -61,7 +61,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn,
input_workers,
devices,
split_batch_by,
num_replicas_in_sync,
strategy,
input_context=None):
# The `input_context` passed in is to shard dataset for
@ -93,7 +93,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
return iterator
@ -101,7 +101,7 @@ class DistributedIteratorTestBase(test.TestCase):
input_type,
dataset,
input_workers,
split_batch_by,
num_replicas_in_sync,
strategy,
input_context=None):
if input_type == "dataset":
@ -110,14 +110,14 @@ class DistributedIteratorTestBase(test.TestCase):
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
else:
return input_lib.DistributedDatasetV1(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
num_replicas_in_sync=num_replicas_in_sync,
input_context=input_context)
else:
return strategy.distribute_datasets_from_function(dataset)
@ -163,7 +163,7 @@ class DistributedIteratorTestBase(test.TestCase):
expected_values,
strategy,
sess=None,
split_batch_by=None,
num_replicas_in_sync=None,
input_context=None):
if iteration_type == "for_loop" and not context.executing_eagerly():
self.skipTest("unsupported test combination.")
@ -183,7 +183,7 @@ class DistributedIteratorTestBase(test.TestCase):
dataset_or_input_fn,
input_workers,
devices,
split_batch_by,
num_replicas_in_sync,
strategy,
input_context=input_context)
else:
@ -192,7 +192,7 @@ class DistributedIteratorTestBase(test.TestCase):
input_type,
dataset_or_input_fn,
input_workers,
split_batch_by,
num_replicas_in_sync,
strategy,
input_context=input_context)
@ -361,10 +361,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
@ -419,10 +416,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
distribution, enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
@ -455,10 +449,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs.setdefault(host_device, [])
worker_device_pairs[host_device].append(tpu_device)
worker_device_pairs = worker_device_pairs.items()
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
@ -493,14 +484,10 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
def dataset_fn(ctx):
del ctx
if tf2.enabled():
dataset1 = dataset_ops.DatasetV2.range(10)
dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
return dataset_ops.DatasetV2.zip((dataset1, dataset2))
else:
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
return dataset_ops.Dataset.zip((dataset1, dataset2))
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
@ -563,7 +550,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs)
dataset = dataset_ops.DatasetV2.range(10)
dataset = dataset_ops.Dataset.range(10)
dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers,
distribution)
@ -663,7 +650,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs,
expected_values,
distribution,
split_batch_by=distribution.num_replicas_in_sync,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
@ -724,7 +711,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
worker_device_pairs,
expected_values,
distribution,
split_batch_by=distribution.num_replicas_in_sync,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
@ -733,14 +720,14 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
split_batch_by=[None, 2],
num_replicas_in_sync=[None, 2],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu
],
enable_get_next_as_optional=[True, False]))
def testBatchSplitting(self, input_type, api_type, iteration_type,
split_batch_by, distribution,
num_replicas_in_sync, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
@ -750,7 +737,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type, dataset_fn)
updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size)
batch_size //
num_replicas_in_sync if num_replicas_in_sync else batch_size)
expected_values = [[range(i, i+updated_batch_size),
range(i+updated_batch_size, i+2*updated_batch_size)]
for i in range(0, 100, updated_batch_size*2)]
@ -766,7 +754,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
expected_values,
distribution,
sess=None,
split_batch_by=split_batch_by)
num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate(
combinations.combine(
@ -774,13 +762,13 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type=["dataset"],
api_type=["wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
split_batch_by=[None, 2],
num_replicas_in_sync=[None, 2],
distribution=[
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
def testBatchSplittingMultiWorker(self, input_type, api_type, iteration_type,
split_batch_by, distribution,
num_replicas_in_sync, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:GPU:1"])]
@ -796,7 +784,8 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
input_type, dataset_fn)
updated_batch_size = (
batch_size // split_batch_by if split_batch_by else batch_size)
batch_size //
num_replicas_in_sync if num_replicas_in_sync else batch_size)
expected_values = [
[ # pylint: disable=g-complex-comprehension
range(i, i + updated_batch_size),
@ -815,7 +804,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
expected_values,
distribution,
sess=None,
split_batch_by=split_batch_by)
num_replicas_in_sync=num_replicas_in_sync)
@combinations.generate(
combinations.combine(
@ -1249,6 +1238,187 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
expected_for_sum = 310.
self.assertAllEqual(nest.flatten(sums), [expected_for_sum] * 3)
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
]))
def testMWMSPartialBatch(self, input_type, api_type, iteration_type,
distribution):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior when we have two files each with
# 12 elements and a global batch size of 8. When we consider the dataset in
# aggregate (non-distributed), there are 24 elements divided into 3 batches
# of size 8. Hence, the correct distributed behavior is for each replica to
# see sub-batches of size 4, over three steps.
def dataset_fn(ctx):
del ctx
dataset = dataset_ops.Dataset.range(12).batch(8)
# Set the sharding behavior to OFF for simplicity of test setup; namely,
# `dataset` defines the per-worker dataset and will not be further
# sharded. Each worker will see a dataset that is
# tf.data.Dataset.range(12).batch(8).rebatch(...).
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as there is 1 local
# replica.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. Each worker should rebatch its dataset into
# smaller batches of size 4.
expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9, 10, 11]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
]))
def testMWMSPartialBatchWithLegacyRebatch(self, input_type, api_type,
iteration_type, distribution):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior when we have two files each with
# 12 elements and a global batch size of 8. When we consider the dataset in
# aggregate (non-distributed), there are 24 elements divided into 3 batches
# of size 8. Hence, the correct distributed behavior is for each replica to
# see sub-batches of size 4, over three steps. However, when we create a
# DistributedDataset and cannot statically infer the intended global batch
# size (e.g. if the user does not use a batching dataset), each worker will
# rebatch based on the dynamic batch size of the data encountered, even when
# it encounters partial batches. The last per-worker partial batch (size 4)
# ends up being split into two replicas, resulting in 4 steps in total, of
# (global) batch sizes 8, 8, 4, 4.
def dataset_fn(ctx):
del ctx
# The following dataset is equivalent to
# tf.data.Dataset.range(12).batch(8), but does not use a batching dataset.
# This causes DistributedDataset to use LegacyRebatch instead.
batch_sizes = dataset_ops.Dataset.from_tensor_slices([8, 4])
offsets = dataset_ops.Dataset.from_tensor_slices([0, 8])
dataset = dataset_ops.Dataset.zip((offsets, batch_sizes))
def map_fn(offset, batch_size):
return math_ops.range(offset, offset + batch_size)
dataset = dataset.map(map_fn)
# Set the sharding behavior to OFF for simplicity of test setup; namely,
# `dataset` defines the per-worker dataset and will not be further
# sharded. Each worker will see a dataset that is equivalent to
# tf.data.Dataset.range(12).batch(8).rebatch(...).
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = AutoShardPolicy.OFF
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as the number of global
# replicas is 2.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. Each worker should rebatch its dataset into
# smaller batches of sizes 4, 4, 2, 2: the final partial batch of 4 is
# itself split into two batches of 2.
expected_values = [[[0, 1, 2, 3]], [[4, 5, 6, 7]], [[8, 9]], [[10, 11]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
api_type=["wrap_into_iterator", "wrap_into_dataset"],
iteration_type=["get_next", "for_loop"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
],
auto_shard_policy=[AutoShardPolicy.AUTO, AutoShardPolicy.DATA]))
def testMWMSWithDataSharding(self, input_type, api_type, iteration_type,
distribution, auto_shard_policy):
# Test case: 2 workers, 1 replica each.
# This test simulates the sharded behavior when the dataset is sharded by data
# and the batch size is indivisible by the number of replicas. This checks
# that the elements are as expected and the batch size across all workers
# adds up to 3. This test will only pass if the autoshard rewrite rewrites
# RebatchDatasetV2 to legacy RebatchDataset when sharding by data.
def dataset_fn(ctx):
del ctx
dataset = dataset_ops.Dataset.range(8).batch(3)
# Set the sharding behavior to the parameterized policy (AUTO or DATA);
# unlike the earlier tests, `dataset` is sharded by data here, so each
# worker sees its own elements of tf.data.Dataset.range(8).batch(3)
# after rebatching.
options = dataset_ops.Options()
options.experimental_distribute.auto_shard_policy = auto_shard_policy
dataset = dataset.with_options(options)
return dataset
dataset = self._create_dataset_or_input_fn(input_type, dataset_fn)
# Actual devices don't matter in this test as long as there is 1 local
# replica.
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
# Each test runs individually on each worker, so we compare the
# values on each worker. We expect each worker to see different shards of
# data.
cr = distribution.cluster_resolver
worker_id = multi_worker_util.id_in_cluster(cr.cluster_spec(), cr.task_type,
cr.task_id)
if worker_id == 0:
expected_values = [[[0, 1]], [[3, 4]], [[6]]]
elif worker_id == 1:
expected_values = [[[2]], [[5]], [[7]]]
self._test_input_iteration(
input_type,
api_type,
iteration_type,
dataset,
worker_device_pairs,
expected_values,
distribution,
num_replicas_in_sync=distribution.num_replicas_in_sync,
input_context=distribution.extended._make_input_context())
if __name__ == "__main__":
combinations.main()
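The two partial-batch tests above reduce to the following contrast: with a statically known batch size, tf.data.Dataset.range(12).batch(8) rebatched for two replicas yields batches of 4, 4, 4, whereas the legacy path can only split each batch it encounters (8, then the partial 4) in half, yielding 4, 4, 2, 2. A sketch using the same private ops the tests exercise (eager mode assumed):

import tensorflow as tf
from tensorflow.python.data.experimental.ops import distribute

dataset = tf.data.Dataset.range(12).batch(8)  # batches of 8 and 4 (partial)

# Exact rebatching: the 12 elements are resliced into batches of 4.
exact = distribute._RebatchDataset(dataset, [4])
print([t.numpy().tolist() for t in exact])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]

# Legacy rebatching: each encountered batch is split into 2 pieces.
legacy = distribute._LegacyRebatchDataset(dataset, 2)
print([t.numpy().tolist() for t in legacy])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [10, 11]]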

tensorflow/python/distribute/input_ops.py

@ -27,7 +27,7 @@ from tensorflow.python.framework import ops
# pylint: disable=protected-access
def auto_shard_dataset(dataset, num_shards, index):
def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
"""Shard the input pipeline by sharding the underlying list of files.
Args:
@ -37,6 +37,8 @@ def auto_shard_dataset(dataset, num_shards, index):
shards operating in parallel. Same usage as in `tf.data.Dataset.shard`.
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Same usage as in `tf.data.Dataset.shard`.
num_replicas_in_sync: An integer representing the total number of replicas
across all workers. This is used in the rewrite when sharding by data.
Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the
@ -45,10 +47,14 @@ def auto_shard_dataset(dataset, num_shards, index):
"""
if (dataset.options().experimental_distribute.auto_shard_policy !=
AutoShardPolicy.OFF):
if num_replicas_in_sync is None:
num_replicas_in_sync = 1
if isinstance(dataset, dataset_ops.DatasetV1):
return distribute._AutoShardDatasetV1(dataset, num_shards, index)
return distribute._AutoShardDatasetV1(dataset, num_shards, index,
num_replicas_in_sync)
else:
return distribute._AutoShardDataset(dataset, num_shards, index)
return distribute._AutoShardDataset(dataset, num_shards, index,
num_replicas_in_sync)
else:
return dataset
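A hedged usage sketch of the updated helper (the dataset and counts are illustrative). When the effective policy shards by data, the rewrite uses num_replicas_in_sync to replace RebatchDatasetV2 with the legacy rebatch op, as exercised by testMWMSWithDataSharding above:

import tensorflow as tf
from tensorflow.python.distribute import input_ops

dataset = tf.data.Dataset.range(24).batch(8)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA)
dataset = dataset.with_options(options)

# Worker 0 of 2, with 2 replicas in sync across the job.
sharded = input_ops.auto_shard_dataset(
    dataset, num_shards=2, index=0, num_replicas_in_sync=2)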

tensorflow/python/distribute/mirrored_strategy.py

@ -479,7 +479,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
@ -501,7 +501,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
dataset,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(

tensorflow/python/distribute/parameter_server_strategy.py

@ -351,14 +351,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
dataset,
self._input_workers_with_options(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _make_dataset_iterator(self, dataset):
return input_lib.DatasetIterator(
dataset,
self._input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,

tensorflow/python/distribute/tpu_strategy.py

@ -752,7 +752,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
dataset,
input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _make_input_fn_iterator(
self,
@ -809,7 +809,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
dataset,
self._get_input_workers(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
num_replicas_in_sync=self._num_replicas_in_sync)
def _distribute_datasets_from_function(self, dataset_fn, options):
input_workers = self._get_input_workers(options)
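The user-visible effect across all of these strategies: each one now forwards num_replicas_in_sync so that the per-replica batches always sum to the dataset's global batch size on each step. A minimal end-to-end sketch with the public API (two logical CPUs stand in for two replicas; the device setup is illustrative):

import tensorflow as tf

# Carve one physical CPU into two logical devices so MirroredStrategy
# gets two replicas on a single machine.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])
assert strategy.num_replicas_in_sync == 2

dataset = tf.data.Dataset.range(32).batch(8)  # global batch size 8
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for per_replica_batch in dist_dataset:
  # Each replica sees a batch of 4; 4 + 4 adds up to the global batch of 8.
  print([v.shape for v in per_replica_batch.values])
  break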