Remove the workaround that sets PerReplica spec to dynamic batch

It's no longer needed as we stopped reusing collective instance keys. Note that we still modifies element_spec to have a dynamic batch for multi worker strategies when partial batch is enabled, so that element_spec is compatible with the data produced.

PiperOrigin-RevId: 354132185
Change-Id: I3857b4bb25c825befdd1f7c667437dc3bbf4ba50
This commit is contained in:
Ran Chen 2021-01-27 11:25:19 -08:00 committed by TensorFlower Gardener
parent f81438cb02
commit 3db793ee03
4 changed files with 17 additions and 260 deletions

View File

@ -1026,24 +1026,22 @@ distribute_py_test(
"multi_and_single_gpu",
],
deps = [
":collective_all_reduce_strategy",
":combinations",
":distribute_lib",
":input_lib",
":mirrored_strategy",
":multi_worker_test_base",
":reduce_util",
":strategy_combinations",
":test_util",
":tpu_strategy",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:dtypes",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//third_party/py/numpy",

View File

@ -560,8 +560,7 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
flattened_data = []
for per_worker_data in replicas:
flattened_data.extend(per_worker_data)
replicas = _create_per_replica(
flattened_data, strategy, get_next_as_optional=True)
replicas = _create_per_replica(flattened_data, strategy)
# Run an all-reduce to see whether any worker has values.
# TODO(b/131423105): we should be able to short-cut the all-reduce in some
@ -661,8 +660,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(
self._iterators[i].get_next_as_list_static_shapes(new_name))
return _create_per_replica(
replicas, self._strategy, get_next_as_optional=False)
return _create_per_replica(replicas, self._strategy)
out_of_range_replicas = []
def out_of_range_fn(worker_index, device):
@ -696,8 +694,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
results.append(result)
replicas = results
return _create_per_replica(replicas, self._strategy,
self._enable_get_next_as_optional)
return _create_per_replica(replicas, self._strategy)
class DistributedIteratorV1(DistributedIteratorBase):
@ -908,9 +905,6 @@ class DistributedIterator(DistributedIteratorBase,
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batching handling we could always produce empty batches.
#
# TODO(b/163362689): avoid this once we have more elegant way to handle
# retracing and collectives.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
@ -919,9 +913,9 @@ class DistributedIterator(DistributedIteratorBase,
@property
def _type_spec(self):
# Note that we use actual element_spec to create DistributedIteratorSpec,
# to be consistent with the underlying iterators' specs.
# TODO(b/163362689): remove the comment after the bug if fixed.
# Note that we use actual element_spec instead of the rebatched-as-dynamic
# one to create DistributedIteratorSpec, to be consistent with the
# underlying iterators' specs.
return DistributedIteratorSpec(self._input_workers, self._element_spec,
self._strategy,
self._enable_get_next_as_optional,
@ -1140,9 +1134,6 @@ class DistributedDataset(_IterableInput):
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batching handling we could always produce empty batches.
#
# TODO(b/163362689): avoid this once we have more elegant way to handle
# retracing and collectives.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
@ -1321,9 +1312,6 @@ class DistributedDatasetsFromFunction(_IterableInput):
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batching handling we could always produce empty batches.
#
# TODO(b/163362689): avoid this once we have more elegant way to handle
# retracing and collectives.
if (self._enable_get_next_as_optional and
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
return nest.map_structure(
@ -2176,7 +2164,7 @@ def _enable_get_next_as_optional(strategy, dataset):
dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
def _create_per_replica(value_list, strategy, get_next_as_optional):
def _create_per_replica(value_list, strategy):
"""Creates a PerReplica.
For strategies other than OneDeviceStrategy, it creates a PerReplica whose
@ -2190,7 +2178,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional):
Args:
value_list: a list of values, one for each replica.
strategy: the `tf.distribute.Strategy`.
get_next_as_optional: whether last partial batch handling is enabled.
Returns:
a structure of PerReplica.
@ -2199,23 +2186,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional):
# TODO(b/166464552): always wrap for all one device strategies as well.
always_wrap = _always_wrap(strategy)
per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
# When partial batch handling is enabled, always set the batch dimension to
# None, otherwise we just follow element_spec of the underlying dataset
# (whose batch dimension may also be None). This is because with partial
# batching handling we could always produce empty batches.
#
# TODO(b/163362689): avoid this once we have more elegant way to handle
# retracing and collectives.
if (get_next_as_optional and strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
# Use expand_composites=False since we don't want to expand PerReplica,
# which is a CompositeTensor.
flat_per_replicas = nest.flatten(per_replicas, expand_composites=False)
flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas]
for per_replica, spec in zip(flat_per_replicas, flat_spec):
per_replica._type_spec_override = _rebatch_as_dynamic(spec) # pylint: disable=protected-access
per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas)
return per_replicas

View File

@ -120,9 +120,6 @@ class DistributedIteratorTest(test.TestCase,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False],
drop_remainder=[True, False],
@ -201,53 +198,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
for x in ds:
process_inputs(x)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
enable_get_next_as_optional=[True, False],
drop_remainder=[True, False],
))
def testFromFunctionInputSignatureForPerReplicaValues(
self, distribution, enable_get_next_as_optional, drop_remainder):
# Create files that produce partial/empty batches at different batch. Note
# that some worker will get empty batches even when drop_remainder=True.
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
_create_text_file(fname1, 5)
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
_create_text_file(fname2, 9)
def dataset_fn(input_context):
dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
return readers.TextLineDatasetV2(dataset).map(
string_ops.string_to_number).batch(
input_context.get_per_replica_batch_size(4),
drop_remainder=drop_remainder)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
ds = distribution.experimental_distribute_datasets_from_function(dataset_fn)
_check_type_spec_structure(iter(ds))
element_spec = ds.element_spec
iter_element_spec = iter(ds).element_spec
nest.assert_same_structure(element_spec, iter_element_spec)
self.assertAllEqual(
nest.flatten(element_spec), nest.flatten(iter_element_spec))
@def_function.function(input_signature=[element_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
for x in ds:
process_inputs(x)
@combinations.generate(
combinations.combine(
mode=["eager"],
@ -307,149 +257,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2))
self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1))
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
drop_remainder=[True, False],
))
def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution,
drop_remainder):
self.trace_count = 0
@def_function.function
def f(v):
del v
self.trace_count += 1
distribution.extended.experimental_enable_get_next_as_optional = True
# Total dataset size 5 allows us to have full batches, partial batches and
# empty batches.
dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch(
4, drop_remainder=drop_remainder)
dataset = distribution.experimental_distribute_dataset(dataset)
for v in iter(dataset):
f(v)
self.assertEqual(self.trace_count, 1)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
drop_remainder=[True, False],
))
def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
self, distribution, drop_remainder):
# Create files that produce partial/empty batches at different batch.
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
_create_text_file(fname1, 5)
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
_create_text_file(fname2, 9)
self.trace_count = 0
@def_function.function
def f(v):
del v
self.trace_count += 1
distribution.extended.experimental_enable_get_next_as_optional = True
dataset = readers.TextLineDatasetV2([fname1, fname2]).batch(
4, drop_remainder=drop_remainder)
dataset = distribution.experimental_distribute_dataset(dataset)
for v in iter(dataset):
f(v)
self.assertEqual(self.trace_count, 1)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
drop_remainder=[True, False],
))
def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution,
drop_remainder):
def dataset_fn(input_context):
# Total dataset size 5 allows us to have full batches, partial batches and
# empty batches.
dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3)))
dataset = dataset.batch(
input_context.get_per_replica_batch_size(4),
drop_remainder=drop_remainder)
return dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
self.trace_count = 0
@def_function.function
def f(v):
del v
self.trace_count += 1
distribution.extended.experimental_enable_get_next_as_optional = True
dataset = distribution.experimental_distribute_datasets_from_function(
dataset_fn)
for v in iter(dataset):
f(v)
self.assertEqual(self.trace_count, 1)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
drop_remainder=[True, False],
))
def testFromFunctionFileShardingDoesNotTriggerFunctionTracing(
self, distribution, drop_remainder):
# Create files that produce partial/empty batches at different batch.
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
_create_text_file(fname1, 5)
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
_create_text_file(fname2, 9)
def dataset_fn(input_context):
dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
return readers.TextLineDatasetV2(dataset).batch(
input_context.get_per_replica_batch_size(4),
drop_remainder=drop_remainder)
self.trace_count = 0
@def_function.function
def f(v):
del v
self.trace_count += 1
distribution.extended.experimental_enable_get_next_as_optional = True
dataset = distribution.experimental_distribute_datasets_from_function(
dataset_fn)
for v in iter(dataset):
f(v)
self.assertEqual(self.trace_count, 1)
@combinations.generate(
combinations.combine(
mode=["eager"],
@ -643,9 +450,6 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
def testDoesNotTriggerFunctionTracing(self, distribution,

View File

@ -362,23 +362,8 @@ class DistributedDelegate(DistributedValues):
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
"""Holds a map from replica to unsynchronized values."""
def __init__(self, values, type_spec_override=None):
super(PerReplica, self).__init__(values)
# Allow setting a type spec that can be different from the underlying
# values. This allows us avoid retracing for PerReplica from full, partial
# and empty batches. In a multi client setup, we need to avoid such
# retracing otherwise the collectives may mismatch since we assign new
# collective keys when retracing the function.
#
# TODO(b/166169298): remove after CrossDeviceOps is tracing safe.
self._type_spec_override = type_spec_override
@property
def _type_spec(self):
if self._type_spec_override is not None:
# Return a deep copy in case the caller changes it, since _type_spec()
# normally returns a temporary object.
return copy.deepcopy(self._type_spec_override)
return PerReplicaSpec(
*(type_spec.type_spec_from_value(v) for v in self._values))