Remove the workaround that sets PerReplica spec to dynamic batch
It's no longer needed since we stopped reusing collective instance keys. Note that we still modify element_spec to have a dynamic batch dimension for multi-worker strategies when partial batch handling is enabled, so that element_spec stays compatible with the data produced.

PiperOrigin-RevId: 354132185
Change-Id: I3857b4bb25c825befdd1f7c667437dc3bbf4ba50
This commit is contained in:
parent f81438cb02
commit 3db793ee03
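For background, "dynamic batch" here means relaxing the leading batch dimension of each component spec to None, so that specs computed from full, partial, and empty batches all compare equal. Below is a minimal sketch of the idea using public TF 2.x APIs; the helper name rebatch_spec_as_dynamic is illustrative, standing in for the internal _rebatch_as_dynamic:

import tensorflow as tf

def rebatch_spec_as_dynamic(spec):
  # Relax the leading (batch) dimension to None so that a full batch of 4,
  # a partial batch of 2, and an empty batch of 0 all match the same spec.
  return tf.TensorSpec([None] + spec.shape.as_list()[1:], spec.dtype, spec.name)

static_spec = tf.TensorSpec([4, 3], tf.float32)
dynamic_spec = rebatch_spec_as_dynamic(static_spec)
assert dynamic_spec.shape.as_list() == [None, 3]
assert dynamic_spec.is_compatible_with(tf.ones([2, 3]))   # partial batch
assert dynamic_spec.is_compatible_with(tf.zeros([0, 3]))  # empty batch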
@@ -1026,24 +1026,22 @@ distribute_py_test(
        "multi_and_single_gpu",
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":distribute_lib",
        ":input_lib",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":test_util",
        ":tpu_strategy",
        ":values",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:composite_tensor",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python:string_ops",
        "//tensorflow/python:tensor_shape",
        "//tensorflow/python:tensor_spec",
        "//tensorflow/python:tf2",
        "//tensorflow/python:util",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/ops:readers",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
@@ -560,8 +560,7 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
     flattened_data = []
     for per_worker_data in replicas:
       flattened_data.extend(per_worker_data)
-    replicas = _create_per_replica(
-        flattened_data, strategy, get_next_as_optional=True)
+    replicas = _create_per_replica(flattened_data, strategy)

   # Run an all-reduce to see whether any worker has values.
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
@@ -661,8 +660,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
         # Make `replicas` a flat list of values across all replicas.
         replicas.extend(
             self._iterators[i].get_next_as_list_static_shapes(new_name))
-      return _create_per_replica(
-          replicas, self._strategy, get_next_as_optional=False)
+      return _create_per_replica(replicas, self._strategy)

     out_of_range_replicas = []
     def out_of_range_fn(worker_index, device):
@@ -696,8 +694,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
         results.append(result)
       replicas = results

-    return _create_per_replica(replicas, self._strategy,
-                               self._enable_get_next_as_optional)
+    return _create_per_replica(replicas, self._strategy)


 class DistributedIteratorV1(DistributedIteratorBase):
@@ -908,9 +905,6 @@ class DistributedIterator(DistributedIteratorBase,
     # None, otherwise we just follow element_spec of the underlying dataset
     # (whose batch dimension may also be None). This is because with partial
     # batching handling we could always produce empty batches.
-    #
-    # TODO(b/163362689): avoid this once we have more elegant way to handle
-    # retracing and collectives.
     if (self._enable_get_next_as_optional and
         self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
       return nest.map_structure(
@@ -919,9 +913,9 @@ class DistributedIterator(DistributedIteratorBase,

   @property
   def _type_spec(self):
-    # Note that we use actual element_spec to create DistributedIteratorSpec,
-    # to be consistent with the underlying iterators' specs.
-    # TODO(b/163362689): remove the comment after the bug if fixed.
+    # Note that we use actual element_spec instead of the rebatched-as-dynamic
+    # one to create DistributedIteratorSpec, to be consistent with the
+    # underlying iterators' specs.
     return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                    self._strategy,
                                    self._enable_get_next_as_optional,
@@ -1140,9 +1134,6 @@ class DistributedDataset(_IterableInput):
     # None, otherwise we just follow element_spec of the underlying dataset
     # (whose batch dimension may also be None). This is because with partial
     # batching handling we could always produce empty batches.
-    #
-    # TODO(b/163362689): avoid this once we have more elegant way to handle
-    # retracing and collectives.
     if (self._enable_get_next_as_optional and
         self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
       return nest.map_structure(
@@ -1321,9 +1312,6 @@ class DistributedDatasetsFromFunction(_IterableInput):
     # None, otherwise we just follow element_spec of the underlying dataset
     # (whose batch dimension may also be None). This is because with partial
     # batching handling we could always produce empty batches.
-    #
-    # TODO(b/163362689): avoid this once we have more elegant way to handle
-    # retracing and collectives.
     if (self._enable_get_next_as_optional and
         self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
       return nest.map_structure(
@@ -2176,7 +2164,7 @@ def _enable_get_next_as_optional(strategy, dataset):
       dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


-def _create_per_replica(value_list, strategy, get_next_as_optional):
+def _create_per_replica(value_list, strategy):
   """Creates a PerReplica.

   For strategies other than OneDeviceStrategy, it creates a PerReplica whose
@@ -2190,7 +2178,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional):
   Args:
     value_list: a list of values, one for each replica.
     strategy: the `tf.distribute.Strategy`.
-    get_next_as_optional: whether last partial batch handling is enabled.

   Returns:
     a structure of PerReplica.
@@ -2199,23 +2186,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional):
   # TODO(b/166464552): always wrap for all one device strategies as well.
   always_wrap = _always_wrap(strategy)
   per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
-
-  # When partial batch handling is enabled, always set the batch dimension to
-  # None, otherwise we just follow element_spec of the underlying dataset
-  # (whose batch dimension may also be None). This is because with partial
-  # batching handling we could always produce empty batches.
-  #
-  # TODO(b/163362689): avoid this once we have more elegant way to handle
-  # retracing and collectives.
-  if (get_next_as_optional and strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
-    # Use expand_composites=False since we don't want to expand PerReplica,
-    # which is a CompositeTensor.
-    flat_per_replicas = nest.flatten(per_replicas, expand_composites=False)
-    flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas]
-    for per_replica, spec in zip(flat_per_replicas, flat_spec):
-      per_replica._type_spec_override = _rebatch_as_dynamic(spec)  # pylint: disable=protected-access
-    per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas)
-
   return per_replicas

@@ -120,9 +120,6 @@ class DistributedIteratorTest(test.TestCase,
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False],
           drop_remainder=[True, False],
@@ -201,53 +198,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
     for x in ds:
       process_inputs(x)

-  @combinations.generate(
-      combinations.combine(
-          mode=["eager"],
-          distribution=[
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
-          ],
-          tf_api_version=2,
-          enable_get_next_as_optional=[True, False],
-          drop_remainder=[True, False],
-      ))
-  def testFromFunctionInputSignatureForPerReplicaValues(
-      self, distribution, enable_get_next_as_optional, drop_remainder):
-    # Create files that produce partial/empty batches at different batch. Note
-    # that some worker will get empty batches even when drop_remainder=True.
-    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
-    _create_text_file(fname1, 5)
-    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
-    _create_text_file(fname2, 9)
-
-    def dataset_fn(input_context):
-      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
-      return readers.TextLineDatasetV2(dataset).map(
-          string_ops.string_to_number).batch(
-              input_context.get_per_replica_batch_size(4),
-              drop_remainder=drop_remainder)
-
-    distribution.extended.experimental_enable_get_next_as_optional = (
-        enable_get_next_as_optional)
-    ds = distribution.experimental_distribute_datasets_from_function(dataset_fn)
-    _check_type_spec_structure(iter(ds))
-    element_spec = ds.element_spec
-    iter_element_spec = iter(ds).element_spec
-    nest.assert_same_structure(element_spec, iter_element_spec)
-    self.assertAllEqual(
-        nest.flatten(element_spec), nest.flatten(iter_element_spec))
-
-    @def_function.function(input_signature=[element_spec])
-    def process_inputs(inputs):
-      distribution.run(lambda inputs: inputs, args=(inputs,))
-
-    for x in ds:
-      process_inputs(x)
-
   @combinations.generate(
       combinations.combine(
           mode=["eager"],
@@ -307,149 +257,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2))
     self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1))

-  @combinations.generate(
-      combinations.combine(
-          mode=["eager"],
-          distribution=[
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
-          ],
-          tf_api_version=2,
-          drop_remainder=[True, False],
-      ))
-  def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution,
-                                                   drop_remainder):
-    self.trace_count = 0
-
-    @def_function.function
-    def f(v):
-      del v
-      self.trace_count += 1
-
-    distribution.extended.experimental_enable_get_next_as_optional = True
-    # Total dataset size 5 allows us to have full batches, partial batches and
-    # empty batches.
-    dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch(
-        4, drop_remainder=drop_remainder)
-    dataset = distribution.experimental_distribute_dataset(dataset)
-    for v in iter(dataset):
-      f(v)
-    self.assertEqual(self.trace_count, 1)
-
-  @combinations.generate(
-      combinations.combine(
-          mode=["eager"],
-          distribution=[
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
-          ],
-          tf_api_version=2,
-          drop_remainder=[True, False],
-      ))
-  def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
-      self, distribution, drop_remainder):
-    # Create files that produce partial/empty batches at different batch.
-    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
-    _create_text_file(fname1, 5)
-    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
-    _create_text_file(fname2, 9)
-
-    self.trace_count = 0
-
-    @def_function.function
-    def f(v):
-      del v
-      self.trace_count += 1
-
-    distribution.extended.experimental_enable_get_next_as_optional = True
-    dataset = readers.TextLineDatasetV2([fname1, fname2]).batch(
-        4, drop_remainder=drop_remainder)
-    dataset = distribution.experimental_distribute_dataset(dataset)
-    for v in iter(dataset):
-      f(v)
-    self.assertEqual(self.trace_count, 1)
-
-  @combinations.generate(
-      combinations.combine(
-          mode=["eager"],
-          distribution=[
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
-          ],
-          tf_api_version=2,
-          drop_remainder=[True, False],
-      ))
-  def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution,
-                                                    drop_remainder):
-
-    def dataset_fn(input_context):
-      # Total dataset size 5 allows us to have full batches, partial batches and
-      # empty batches.
-      dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3)))
-      dataset = dataset.batch(
-          input_context.get_per_replica_batch_size(4),
-          drop_remainder=drop_remainder)
-      return dataset.shard(input_context.num_input_pipelines,
-                           input_context.input_pipeline_id)
-
-    self.trace_count = 0
-
-    @def_function.function
-    def f(v):
-      del v
-      self.trace_count += 1
-
-    distribution.extended.experimental_enable_get_next_as_optional = True
-    dataset = distribution.experimental_distribute_datasets_from_function(
-        dataset_fn)
-    for v in iter(dataset):
-      f(v)
-    self.assertEqual(self.trace_count, 1)
-
-  @combinations.generate(
-      combinations.combine(
-          mode=["eager"],
-          distribution=[
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
-          ],
-          tf_api_version=2,
-          drop_remainder=[True, False],
-      ))
-  def testFromFunctionFileShardingDoesNotTriggerFunctionTracing(
-      self, distribution, drop_remainder):
-    # Create files that produce partial/empty batches at different batch.
-    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
-    _create_text_file(fname1, 5)
-    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
-    _create_text_file(fname2, 9)
-
-    def dataset_fn(input_context):
-      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
-      return readers.TextLineDatasetV2(dataset).batch(
-          input_context.get_per_replica_batch_size(4),
-          drop_remainder=drop_remainder)
-
-    self.trace_count = 0
-
-    @def_function.function
-    def f(v):
-      del v
-      self.trace_count += 1
-
-    distribution.extended.experimental_enable_get_next_as_optional = True
-    dataset = distribution.experimental_distribute_datasets_from_function(
-        dataset_fn)
-    for v in iter(dataset):
-      f(v)
-    self.assertEqual(self.trace_count, 1)
-
   @combinations.generate(
       combinations.combine(
           mode=["eager"],
@@ -643,9 +450,6 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
           distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
-              strategy_combinations.multi_worker_mirrored_2x1_cpu,
-              strategy_combinations.multi_worker_mirrored_2x1_gpu,
-              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False]))
   def testDoesNotTriggerFunctionTracing(self, distribution,
@@ -362,23 +362,8 @@ class DistributedDelegate(DistributedValues):
 class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
   """Holds a map from replica to unsynchronized values."""

-  def __init__(self, values, type_spec_override=None):
-    super(PerReplica, self).__init__(values)
-    # Allow setting a type spec that can be different from the underlying
-    # values. This allows us avoid retracing for PerReplica from full, partial
-    # and empty batches. In a multi client setup, we need to avoid such
-    # retracing otherwise the collectives may mismatch since we assign new
-    # collective keys when retracing the function.
-    #
-    # TODO(b/166169298): remove after CrossDeviceOps is tracing safe.
-    self._type_spec_override = type_spec_override
-
   @property
   def _type_spec(self):
-    if self._type_spec_override is not None:
-      # Return a deep copy in case the caller changes it, since _type_spec()
-      # normally returns a temporary object.
-      return copy.deepcopy(self._type_spec_override)
     return PerReplicaSpec(
         *(type_spec.type_spec_from_value(v) for v in self._values))
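With the override gone, a PerReplica's type spec is always derived from its concrete component values, so a tf.function that first saw a full batch retraces when it later receives a partial batch of a different static shape; per the commit message this is now safe because collective instance keys are no longer reused across traces. A standalone sketch of that retracing behavior using only plain tf.function (no distribution strategy; all names are illustrative):

import tensorflow as tf

trace_count = 0

@tf.function
def step(x):
  global trace_count
  trace_count += 1  # Executes once per trace, not once per call.
  return tf.reduce_sum(x)

step(tf.ones([4, 3]))  # Full batch: traces with static shape [4, 3].
step(tf.ones([2, 3]))  # Partial batch: new static shape, second trace.
assert trace_count == 2

# A signature with a dynamic batch dimension lets one trace serve all batch
# sizes, which is what rebatching element_spec as dynamic provides when the
# spec is used as a tf.function input_signature.
@tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
def step_dynamic(x):
  return tf.reduce_sum(x)

step_dynamic(tf.ones([4, 3]))
step_dynamic(tf.ones([2, 3]))  # Reuses the [None, 3] trace.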