diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 8851420c5eb..bdbb18f8c33 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1026,24 +1026,22 @@ distribute_py_test( "multi_and_single_gpu", ], deps = [ + ":collective_all_reduce_strategy", ":combinations", - ":distribute_lib", + ":input_lib", + ":mirrored_strategy", + ":multi_worker_test_base", + ":reduce_util", ":strategy_combinations", - ":test_util", ":tpu_strategy", ":values", - "//tensorflow/python:array_ops", - "//tensorflow/python:composite_tensor", - "//tensorflow/python:dtypes", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:tf2", - "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/ops:readers", - "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/ops/ragged:ragged_tensor", "//third_party/py/numpy", diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 8a44f693c5d..553d1151df8 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -560,8 +560,7 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False): flattened_data = [] for per_worker_data in replicas: flattened_data.extend(per_worker_data) - replicas = _create_per_replica( - flattened_data, strategy, get_next_as_optional=True) + replicas = _create_per_replica(flattened_data, strategy) # Run an all-reduce to see whether any worker has values. # TODO(b/131423105): we should be able to short-cut the all-reduce in some @@ -661,8 +660,7 @@ class DistributedIteratorBase(DistributedIteratorInterface): # Make `replicas` a flat list of values across all replicas. replicas.extend( self._iterators[i].get_next_as_list_static_shapes(new_name)) - return _create_per_replica( - replicas, self._strategy, get_next_as_optional=False) + return _create_per_replica(replicas, self._strategy) out_of_range_replicas = [] def out_of_range_fn(worker_index, device): @@ -696,8 +694,7 @@ class DistributedIteratorBase(DistributedIteratorInterface): results.append(result) replicas = results - return _create_per_replica(replicas, self._strategy, - self._enable_get_next_as_optional) + return _create_per_replica(replicas, self._strategy) class DistributedIteratorV1(DistributedIteratorBase): @@ -908,9 +905,6 @@ class DistributedIterator(DistributedIteratorBase, # None, otherwise we just follow element_spec of the underlying dataset # (whose batch dimension may also be None). This is because with partial # batching handling we could always produce empty batches. - # - # TODO(b/163362689): avoid this once we have more elegant way to handle - # retracing and collectives. if (self._enable_get_next_as_optional and self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access return nest.map_structure( @@ -919,9 +913,9 @@ class DistributedIterator(DistributedIteratorBase, @property def _type_spec(self): - # Note that we use actual element_spec to create DistributedIteratorSpec, - # to be consistent with the underlying iterators' specs. - # TODO(b/163362689): remove the comment after the bug if fixed. + # Note that we use actual element_spec instead of the rebatched-as-dynamic + # one to create DistributedIteratorSpec, to be consistent with the + # underlying iterators' specs. return DistributedIteratorSpec(self._input_workers, self._element_spec, self._strategy, self._enable_get_next_as_optional, @@ -1140,9 +1134,6 @@ class DistributedDataset(_IterableInput): # None, otherwise we just follow element_spec of the underlying dataset # (whose batch dimension may also be None). This is because with partial # batching handling we could always produce empty batches. - # - # TODO(b/163362689): avoid this once we have more elegant way to handle - # retracing and collectives. if (self._enable_get_next_as_optional and self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access return nest.map_structure( @@ -1321,9 +1312,6 @@ class DistributedDatasetsFromFunction(_IterableInput): # None, otherwise we just follow element_spec of the underlying dataset # (whose batch dimension may also be None). This is because with partial # batching handling we could always produce empty batches. - # - # TODO(b/163362689): avoid this once we have more elegant way to handle - # retracing and collectives. if (self._enable_get_next_as_optional and self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access return nest.map_structure( @@ -2176,7 +2164,7 @@ def _enable_get_next_as_optional(strategy, dataset): dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access -def _create_per_replica(value_list, strategy, get_next_as_optional): +def _create_per_replica(value_list, strategy): """Creates a PerReplica. For strategies other than OneDeviceStrategy, it creates a PerReplica whose @@ -2190,7 +2178,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional): Args: value_list: a list of values, one for each replica. strategy: the `tf.distribute.Strategy`. - get_next_as_optional: whether last partial batch handling is enabled. Returns: a structure of PerReplica. @@ -2199,23 +2186,6 @@ def _create_per_replica(value_list, strategy, get_next_as_optional): # TODO(b/166464552): always wrap for all one device strategies as well. always_wrap = _always_wrap(strategy) per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap) - - # When partial batch handling is enabled, always set the batch dimension to - # None, otherwise we just follow element_spec of the underlying dataset - # (whose batch dimension may also be None). This is because with partial - # batching handling we could always produce empty batches. - # - # TODO(b/163362689): avoid this once we have more elegant way to handle - # retracing and collectives. - if (get_next_as_optional and strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access - # Use expand_composites=False since we don't want to expand PerReplica, - # which is a CompositeTensor. - flat_per_replicas = nest.flatten(per_replicas, expand_composites=False) - flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas] - for per_replica, spec in zip(flat_per_replicas, flat_spec): - per_replica._type_spec_override = _rebatch_as_dynamic(spec) # pylint: disable=protected-access - per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas) - return per_replicas diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 6aa662fa10a..d3830c3df62 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -120,9 +120,6 @@ class DistributedIteratorTest(test.TestCase, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False], drop_remainder=[True, False], @@ -201,53 +198,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): for x in ds: process_inputs(x) - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, - ], - tf_api_version=2, - enable_get_next_as_optional=[True, False], - drop_remainder=[True, False], - )) - def testFromFunctionInputSignatureForPerReplicaValues( - self, distribution, enable_get_next_as_optional, drop_remainder): - # Create files that produce partial/empty batches at different batch. Note - # that some worker will get empty batches even when drop_remainder=True. - fname1 = os.path.join(self.get_temp_dir(), "1.txt") - _create_text_file(fname1, 5) - fname2 = os.path.join(self.get_temp_dir(), "2.txt") - _create_text_file(fname2, 9) - - def dataset_fn(input_context): - dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) - dataset = dataset.shard(input_context.num_input_pipelines, - input_context.input_pipeline_id) - return readers.TextLineDatasetV2(dataset).map( - string_ops.string_to_number).batch( - input_context.get_per_replica_batch_size(4), - drop_remainder=drop_remainder) - - distribution.extended.experimental_enable_get_next_as_optional = ( - enable_get_next_as_optional) - ds = distribution.experimental_distribute_datasets_from_function(dataset_fn) - _check_type_spec_structure(iter(ds)) - element_spec = ds.element_spec - iter_element_spec = iter(ds).element_spec - nest.assert_same_structure(element_spec, iter_element_spec) - self.assertAllEqual( - nest.flatten(element_spec), nest.flatten(iter_element_spec)) - - @def_function.function(input_signature=[element_spec]) - def process_inputs(inputs): - distribution.run(lambda inputs: inputs, args=(inputs,)) - - for x in ds: - process_inputs(x) - @combinations.generate( combinations.combine( mode=["eager"], @@ -307,149 +257,6 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2)) self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1)) - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, - ], - tf_api_version=2, - drop_remainder=[True, False], - )) - def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution, - drop_remainder): - self.trace_count = 0 - - @def_function.function - def f(v): - del v - self.trace_count += 1 - - distribution.extended.experimental_enable_get_next_as_optional = True - # Total dataset size 5 allows us to have full batches, partial batches and - # empty batches. - dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch( - 4, drop_remainder=drop_remainder) - dataset = distribution.experimental_distribute_dataset(dataset) - for v in iter(dataset): - f(v) - self.assertEqual(self.trace_count, 1) - - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, - ], - tf_api_version=2, - drop_remainder=[True, False], - )) - def testFromDatasetFileShardingDoesNotTriggerFunctionTracing( - self, distribution, drop_remainder): - # Create files that produce partial/empty batches at different batch. - fname1 = os.path.join(self.get_temp_dir(), "1.txt") - _create_text_file(fname1, 5) - fname2 = os.path.join(self.get_temp_dir(), "2.txt") - _create_text_file(fname2, 9) - - self.trace_count = 0 - - @def_function.function - def f(v): - del v - self.trace_count += 1 - - distribution.extended.experimental_enable_get_next_as_optional = True - dataset = readers.TextLineDatasetV2([fname1, fname2]).batch( - 4, drop_remainder=drop_remainder) - dataset = distribution.experimental_distribute_dataset(dataset) - for v in iter(dataset): - f(v) - self.assertEqual(self.trace_count, 1) - - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, - ], - tf_api_version=2, - drop_remainder=[True, False], - )) - def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution, - drop_remainder): - - def dataset_fn(input_context): - # Total dataset size 5 allows us to have full batches, partial batches and - # empty batches. - dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))) - dataset = dataset.batch( - input_context.get_per_replica_batch_size(4), - drop_remainder=drop_remainder) - return dataset.shard(input_context.num_input_pipelines, - input_context.input_pipeline_id) - - self.trace_count = 0 - - @def_function.function - def f(v): - del v - self.trace_count += 1 - - distribution.extended.experimental_enable_get_next_as_optional = True - dataset = distribution.experimental_distribute_datasets_from_function( - dataset_fn) - for v in iter(dataset): - f(v) - self.assertEqual(self.trace_count, 1) - - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, - ], - tf_api_version=2, - drop_remainder=[True, False], - )) - def testFromFunctionFileShardingDoesNotTriggerFunctionTracing( - self, distribution, drop_remainder): - # Create files that produce partial/empty batches at different batch. - fname1 = os.path.join(self.get_temp_dir(), "1.txt") - _create_text_file(fname1, 5) - fname2 = os.path.join(self.get_temp_dir(), "2.txt") - _create_text_file(fname2, 9) - - def dataset_fn(input_context): - dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) - dataset = dataset.shard(input_context.num_input_pipelines, - input_context.input_pipeline_id) - return readers.TextLineDatasetV2(dataset).batch( - input_context.get_per_replica_batch_size(4), - drop_remainder=drop_remainder) - - self.trace_count = 0 - - @def_function.function - def f(v): - del v - self.trace_count += 1 - - distribution.extended.experimental_enable_get_next_as_optional = True - dataset = distribution.experimental_distribute_datasets_from_function( - dataset_fn) - for v in iter(dataset): - f(v) - self.assertEqual(self.trace_count, 1) - @combinations.generate( combinations.combine( mode=["eager"], @@ -643,9 +450,6 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, distribution=[ strategy_combinations.mirrored_strategy_with_gpu_and_cpu, strategy_combinations.tpu_strategy, - strategy_combinations.multi_worker_mirrored_2x1_cpu, - strategy_combinations.multi_worker_mirrored_2x1_gpu, - strategy_combinations.multi_worker_mirrored_2x2_gpu, ], enable_get_next_as_optional=[True, False])) def testDoesNotTriggerFunctionTracing(self, distribution, diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index a59164bb0d7..b90fd24b6e0 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -362,23 +362,8 @@ class DistributedDelegate(DistributedValues): class PerReplica(DistributedValues, composite_tensor.CompositeTensor): """Holds a map from replica to unsynchronized values.""" - def __init__(self, values, type_spec_override=None): - super(PerReplica, self).__init__(values) - # Allow setting a type spec that can be different from the underlying - # values. This allows us avoid retracing for PerReplica from full, partial - # and empty batches. In a multi client setup, we need to avoid such - # retracing otherwise the collectives may mismatch since we assign new - # collective keys when retracing the function. - # - # TODO(b/166169298): remove after CrossDeviceOps is tracing safe. - self._type_spec_override = type_spec_override - @property def _type_spec(self): - if self._type_spec_override is not None: - # Return a deep copy in case the caller changes it, since _type_spec() - # normally returns a temporary object. - return copy.deepcopy(self._type_spec_override) return PerReplicaSpec( *(type_spec.type_spec_from_value(v) for v in self._values))