diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 4072e7ad230..93f28e7f4bc 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -1008,22 +1008,24 @@ distribute_py_test(
         "multi_and_single_gpu",
     ],
     deps = [
-        ":collective_all_reduce_strategy",
         ":combinations",
-        ":input_lib",
-        ":mirrored_strategy",
-        ":multi_worker_test_base",
-        ":reduce_util",
+        ":distribute_lib",
         ":strategy_combinations",
+        ":test_util",
         ":tpu_strategy",
         ":values",
-        "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:errors",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:sparse_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:composite_tensor",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python:util",
         "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/eager:context",
+        "//tensorflow/python/data/ops:readers",
+        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/ops/ragged:ragged_tensor",
         "//third_party/py/numpy",
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index d01cedcead0..f99c978b6ff 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -535,7 +535,7 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
     flattened_data = []
     for per_worker_data in replicas:
       flattened_data.extend(per_worker_data)
-    replicas = distribute_utils.regroup(flattened_data)
+    replicas = _create_per_replica(flattened_data, strategy)

   # Run an all-reduce to see whether any worker has values.
   # TODO(b/131423105): we should be able to short-cut the all-reduce in some
@@ -635,7 +635,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
         # Make `replicas` a flat list of values across all replicas.
         replicas.extend(
             self._iterators[i].get_next_as_list_static_shapes(new_name))
-      return distribute_utils.regroup(replicas)
+      return _create_per_replica(replicas, self._strategy)

     out_of_range_replicas = []
     def out_of_range_fn(worker_index, device):
@@ -669,7 +669,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
          results.append(result)
      replicas = results

-    return distribute_utils.regroup(replicas)
+    return _create_per_replica(replicas, self._strategy)


 class DistributedIteratorV1(DistributedIteratorBase):
@@ -988,10 +988,20 @@ class DistributedDataset(_IterableInput):
     self._input_workers = input_workers
     self._strategy = strategy
+    element_spec = self._cloned_datasets[0].element_spec
     self._enable_get_next_as_optional = _enable_get_next_as_optional(
-        self._strategy, dataset.element_spec)
+        self._strategy, element_spec)
+    # When partial batch handling is enabled, always set the batch dimension
+    # to None; otherwise we just follow the element_spec of the underlying
+    # dataset (whose batch dimension may also be None). This is because with
+    # partial batch handling we could always produce empty batches.
+    #
+    # TODO(b/163362689): avoid this once we have a more elegant way to handle
+    # retracing and collectives.
+    if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
+      element_spec = nest.map_structure(_rebatch_as_dynamic, element_spec)
     self._element_spec = _create_distributed_tensor_spec(
-        self._strategy, self._cloned_datasets[0].element_spec)
+        self._strategy, element_spec)

   def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
     """Returns a callable that rebatches the input dataset.
@@ -1208,6 +1218,15 @@ class DistributedDatasetsFromFunction(_IterableInput):
             dataset_fn))
     self._enable_get_next_as_optional = _enable_get_next_as_optional(
         self._strategy, element_spec)
+    # When partial batch handling is enabled, always set the batch dimension
+    # to None; otherwise we just follow the element_spec of the underlying
+    # dataset (whose batch dimension may also be None). This is because with
+    # partial batch handling we could always produce empty batches.
+    #
+    # TODO(b/163362689): avoid this once we have a more elegant way to handle
+    # retracing and collectives.
+    if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
+      element_spec = nest.map_structure(_rebatch_as_dynamic, element_spec)
     self._element_spec = _create_distributed_tensor_spec(
         self._strategy, element_spec)

@@ -1350,6 +1369,7 @@ class InputFunctionIterator(DistributedIteratorV1):
     super(InputFunctionIterator, self).__init__(
         input_workers, iterators, strategy,
         enable_get_next_as_optional=False)
+    self._enable_get_next_as_optional = False


 # TODO(anjalisridhar): This class will soon be removed and users should move
@@ -1993,13 +2013,14 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
   """
   num_replicas = len(strategy.extended.worker_devices)

-  # If the number of devices used in the strategy is just 1 then we return
-  # the tensor_spec as is.
-  if num_replicas == 1:
+  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
+  # the tensor_spec as is, since we don't wrap the output with PerReplica in
+  # this case.
+  # TODO(b/166464552): remove after we always wrap for all strategies.
+  if not _always_wrap(strategy):
     return tensor_spec

-  # If the number of devices is greater than 1 then we assume the input to
-  # tf.function is a per replica type.
+  # For other cases we assume the input to tf.function is a per-replica type.
   def _get_value_per_replica(tensor_spec_per_input):
     value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
     return values.PerReplicaSpec(*value_specs)

@@ -2029,3 +2050,63 @@ def _enable_get_next_as_optional(strategy, element_spec):
     return False
   return not _is_statically_shaped(
       element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
+
+
+def _create_per_replica(value_list, strategy):
+  """Creates a PerReplica.
+
+  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
+  type spec is set to the element spec of the dataset. This helps avoid
+  retracing for partial batches. Retracing is problematic in a multi-client
+  setup when different clients retrace at different times, since retracing
+  changes the collective keys in the tf.function and causes mismatches among
+  clients.
+
+  For single-client strategies, this simply calls distribute_utils.regroup().
+
+  Args:
+    value_list: a list of values, one for each replica.
+    strategy: the `tf.distribute.Strategy`.
+
+  Returns:
+    a structure of PerReplica.
+  """
+  # TODO(b/166464552): always wrap for all one-device strategies as well.
+  always_wrap = _always_wrap(strategy)
+  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
+
+  # When partial batch handling is enabled, always set the batch dimension to
+  # None; otherwise we just follow the element_spec of the underlying dataset
+  # (whose batch dimension may also be None). This is because with partial
+  # batch handling we could always produce empty batches.
+  #
+  # TODO(b/163362689): avoid this once we have a more elegant way to handle
+  # retracing and collectives.
+  if strategy.extended._in_multi_worker_mode():  # pylint: disable=protected-access
+    # Use expand_composites=False since we don't want to expand PerReplica,
+    # which is a CompositeTensor.
+    flat_per_replicas = nest.flatten(per_replicas, expand_composites=False)
+    flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas]
+    for per_replica, spec in zip(flat_per_replicas, flat_spec):
+      # pylint: disable=protected-access
+      per_replica._type_spec_override = values.PerReplicaSpec(
+          *nest.map_structure(_rebatch_as_dynamic, spec._value_specs))
+      # pylint: enable=protected-access
+    per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas)
+
+  return per_replicas
+
+
+def _always_wrap(strategy):
+  """Returns whether to always wrap the values in a DistributedValues."""
+  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
+      strategy.extended.worker_devices) > 1
+
+
+def _rebatch_as_dynamic(spec):
+  """Rebatches the spec to have a dynamic batch dimension."""
+  # pylint: disable=protected-access
+  if isinstance(spec, type_spec.BatchableTypeSpec) and spec._shape.ndims > 0:
+    return spec._unbatch()._batch(None)
+  # pylint: enable=protected-access
+  return spec
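Note: the following is an illustrative sketch, not part of the patch. It shows what `_rebatch_as_dynamic` does to a batched spec: the `_unbatch()._batch(None)` round trip keeps the dtype and inner shape but relabels the leading batch dimension as dynamic, so specs derived from full, partial, and empty batches all compare equal and a traced tf.function accepts any of them.

import tensorflow as tf

spec = tf.TensorSpec(shape=[4, 3], dtype=tf.float32)
# _unbatch() strips the leading dimension; _batch(None) re-adds it as dynamic.
dynamic = spec._unbatch()._batch(None)  # pylint: disable=protected-access
print(dynamic)  # TensorSpec(shape=(None, 3), dtype=tf.float32, name=None)
# A partial batch produces the same dynamic spec:
assert dynamic == tf.TensorSpec([2, 3], tf.float32)._unbatch()._batch(None)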
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 5abd6f483d3..42f4c1a621f 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -557,7 +557,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,

       iterator = iter(dist_dataset)
       for i, element in enumerate(iterator):
-        self.assertEqual(i, element.numpy())
+        self.assertAllEqual(distribution.experimental_local_results(element), [i])

   @combinations.generate(
       combinations.combine(
diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py
index bc6ac811bbb..df52e0eaabf 100644
--- a/tensorflow/python/distribute/input_lib_type_spec_test.py
+++ b/tensorflow/python/distribute/input_lib_type_spec_test.py
@@ -18,15 +18,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import os
 from absl.testing import parameterized
 import numpy as np

 from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import strategy_combinations
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import tpu_strategy
 from tensorflow.python.distribute import values
 from tensorflow.python.eager import def_function
@@ -37,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
 from tensorflow.python.util import nest

@@ -116,14 +120,17 @@ class DistributedIteratorTest(test.TestCase,
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
-          enable_get_next_as_optional=[True, False]))
-  def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
-                                        enable_get_next_as_optional):
-    if not tf2.enabled():
-      self.skipTest("DistributedIterator CompositeTensor support is only "
-                    "present in TF 2.0 only.")
-
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+          tf_api_version=2,
+      ))
+  def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
+                                        enable_get_next_as_optional,
+                                        drop_remainder):
     trace_count = [0]

     @def_function.function
@@ -135,7 +142,8 @@ class DistributedIteratorTest(test.TestCase,
         counter += 1
       return counter

-    dataset = dataset_ops.DatasetV2.range(10).batch(2)
+    dataset = dataset_ops.DatasetV2.range(10).batch(
+        2, drop_remainder=drop_remainder)

     distribution.extended.experimental_enable_get_next_as_optional = (
         enable_get_next_as_optional)
@@ -161,27 +169,84 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
           strategy_combinations.mirrored_strategy_with_one_cpu,
           strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
           strategy_combinations.tpu_strategy,
-          strategy_combinations.central_storage_strategy_with_two_gpus,
+          strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+          strategy_combinations.multi_worker_mirrored_2x1_cpu,
+          strategy_combinations.multi_worker_mirrored_2x1_gpu,
+          strategy_combinations.multi_worker_mirrored_2x2_gpu,
       ],
-      input_type=["dataset", "dataset_fn"],
+      tf_api_version=2,
+      enable_get_next_as_optional=[True, False],
+      drop_remainder=[True, False],
   ))
-  def testInputSignatureForPerReplicaValues(self, distribution, input_type):
-    def dataset_fn(ctx):
-      del ctx  # unused
-      return dataset_ops.DatasetV2.from_tensor_slices(
-          np.ones([10, 12]).astype(np.float32)).batch(4)
+  def testInputSignatureForPerReplicaValues(self, distribution,
+                                            enable_get_next_as_optional,
+                                            drop_remainder):
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    ds = dataset_ops.DatasetV2.from_tensor_slices(
+        np.ones([9, 12]).astype(np.float32)).batch(
+            4, drop_remainder=drop_remainder)
+    ds = distribution.experimental_distribute_dataset(ds)
+    _check_type_spec_structure(iter(ds))
+    element_spec = ds.element_spec
+    iter_element_spec = iter(ds).element_spec
+    nest.assert_same_structure(element_spec, iter_element_spec)
+    self.assertAllEqual(
+        nest.flatten(element_spec), nest.flatten(iter_element_spec))

-    if input_type == "dataset":
-      ds = distribution.experimental_distribute_dataset(
-          dataset_fn(distribute_lib.InputContext()))
-      type_spec = ds.element_spec
-    else:
-      ds = distribution.distribute_datasets_from_function(dataset_fn)
-      iterator = iter(ds)
-      _check_type_spec_structure(iterator)
-      type_spec = iterator.element_spec
+    @def_function.function(input_signature=[element_spec])
+    def process_inputs(inputs):
+      distribution.run(lambda inputs: inputs, args=(inputs,))

-    @def_function.function(input_signature=[type_spec])
+    for x in ds:
+      process_inputs(x)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          ],
+          tf_api_version=2,
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+      ))
+  def testFromFunctionInputSignatureForPerReplicaValues(
+      self, distribution, enable_get_next_as_optional, drop_remainder):
+    # Create files that produce partial/empty batches at different batches.
+    # Note that some workers will get empty batches even when
+    # drop_remainder=True.
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    def dataset_fn(input_context):
+      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      return readers.TextLineDatasetV2(dataset).map(
+          string_ops.string_to_number).batch(
+              input_context.get_per_replica_batch_size(4),
+              drop_remainder=drop_remainder)
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    ds = distribution.experimental_distribute_datasets_from_function(
+        dataset_fn)
+    _check_type_spec_structure(iter(ds))
+    element_spec = ds.element_spec
+    iter_element_spec = iter(ds).element_spec
+    nest.assert_same_structure(element_spec, iter_element_spec)
+    self.assertAllEqual(
+        nest.flatten(element_spec), nest.flatten(iter_element_spec))
+
+    @def_function.function(input_signature=[element_spec])
     def process_inputs(inputs):
       distribution.run(lambda inputs: inputs, args=(inputs,))
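Note: for intuition on what these tests pin down, here is an editorial sketch (shapes made up to match the `ones([9, 12]).batch(4)` dataset above) of the element spec a two-replica multi-worker strategy would now advertise. Every per-replica component carries a dynamic batch dimension, which is why `element_spec` is safe to use as a `tf.function` input signature across full, partial, and empty batches.

from tensorflow.python.distribute import values
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec

# One TensorSpec per replica, with the batch dimension relabeled to None.
expected_spec = values.PerReplicaSpec(
    tensor_spec.TensorSpec([None, 12], dtypes.float32),
    tensor_spec.TensorSpec([None, 12], dtypes.float32))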
@@ -247,6 +312,158 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2))
     self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1))

+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          ],
+          tf_api_version=2,
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+      ))
+  def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution,
+                                                   enable_get_next_as_optional,
+                                                   drop_remainder):
+    self.trace_count = 0
+
+    @def_function.function
+    def f(v):
+      del v
+      self.trace_count += 1
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    # Total dataset size 5 allows us to have full batches, partial batches and
+    # empty batches.
+    dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch(
+        4, drop_remainder=drop_remainder)
+    dataset = distribution.experimental_distribute_dataset(dataset)
+    for v in iter(dataset):
+      f(v)
+    self.assertEqual(self.trace_count, 1)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          ],
+          tf_api_version=2,
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+      ))
+  def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
+      self, distribution, enable_get_next_as_optional, drop_remainder):
+    # Create files that produce partial/empty batches at different batches.
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    self.trace_count = 0
+
+    @def_function.function
+    def f(v):
+      del v
+      self.trace_count += 1
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    dataset = readers.TextLineDatasetV2([fname1, fname2]).batch(
+        4, drop_remainder=drop_remainder)
+    dataset = distribution.experimental_distribute_dataset(dataset)
+    for v in iter(dataset):
+      f(v)
+    self.assertEqual(self.trace_count, 1)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          ],
+          tf_api_version=2,
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+      ))
+  def testFromFunctionDoesNotTriggerFunctionTracing(
+      self, distribution, enable_get_next_as_optional, drop_remainder):
+
+    def dataset_fn(input_context):
+      # Total dataset size 5 allows us to have full batches, partial batches
+      # and empty batches.
+      dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3)))
+      dataset = dataset.batch(
+          input_context.get_per_replica_batch_size(4),
+          drop_remainder=drop_remainder)
+      return dataset.shard(input_context.num_input_pipelines,
+                           input_context.input_pipeline_id)
+
+    self.trace_count = 0
+
+    @def_function.function
+    def f(v):
+      del v
+      self.trace_count += 1
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    dataset = distribution.experimental_distribute_datasets_from_function(
+        dataset_fn)
+    for v in iter(dataset):
+      f(v)
+    self.assertEqual(self.trace_count, 1)
+
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
+          ],
+          tf_api_version=2,
+          enable_get_next_as_optional=[True, False],
+          drop_remainder=[True, False],
+      ))
+  def testFromFunctionFileShardingDoesNotTriggerFunctionTracing(
+      self, distribution, enable_get_next_as_optional, drop_remainder):
+    # Create files that produce partial/empty batches at different batches.
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    def dataset_fn(input_context):
+      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      return readers.TextLineDatasetV2(dataset).batch(
+          input_context.get_per_replica_batch_size(4),
+          drop_remainder=drop_remainder)
+
+    self.trace_count = 0
+
+    @def_function.function
+    def f(v):
+      del v
+      self.trace_count += 1
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    dataset = distribution.experimental_distribute_datasets_from_function(
+        dataset_fn)
+    for v in iter(dataset):
+      f(v)
+    self.assertEqual(self.trace_count, 1)
+

 class RaggedTensorDistributedIteratorTest(test.TestCase,
                                           parameterized.TestCase):
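Note: an editorial worked example of why the file-sharding tests above can see empty batches. With `_create_text_file` writing 5 and 9 lines and the files sharded by input pipeline id, the two workers step through different batch sequences:

# Plain-Python sketch of the per-worker batch sizes at batch size 4.
lines_per_worker = {0: 5, 1: 9}
batch = 4
for worker, n in lines_per_worker.items():
  sizes = [min(batch, n - start) for start in range(0, n, batch)]
  print(worker, sizes)  # worker 0 -> [4, 1]; worker 1 -> [4, 4, 1]
# Worker 0 runs out first, so on later steps it contributes empty batches
# while worker 1 still has data. With drop_remainder=True the sequences are
# [4] vs [4, 4], so worker 0 still hits an empty step before worker 1 ends.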
@@ -254,14 +471,14 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
   @combinations.generate(
       combinations.combine(
           mode=["eager"],
+          tf_api_version=2,
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False]))
   def testTypeSpec(self, distribution, enable_get_next_as_optional):
-    if not tf2.enabled():
-      self.skipTest("DistributedIterator has CompositeTensor support in "
-                    "TF 2.0 only.")
     ctx = distribute_lib.InputContext()
     batch_size = ctx.get_per_replica_batch_size(8)
     # Use 20 which isn't divisible by 8 to test partial batch behavior.
@@ -313,16 +530,16 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
   @combinations.generate(
       combinations.combine(
           mode=["eager"],
+          tf_api_version=2,
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False]))
   def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
-    if not tf2.enabled():
-      self.skipTest("DistributedIterator CompositeTensor support is only "
-                    "present in TF 2.0 only.")
-
     ctx = distribute_lib.InputContext()
     batch_size = ctx.get_per_replica_batch_size(8)
     # Use 20 which isn't divisible by 8 to test partial batch behavior.
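Note: a quick editorial check on the partial-batch setup used by these ragged-tensor tests:

# 20 elements at global batch size 8 yield per-step batches of [8, 8, 4];
# the final partial batch is why the advertised specs need a dynamic (None)
# batch dimension to avoid retracing.
total, global_batch = 20, 8
print([min(global_batch, total - i) for i in range(0, total, global_batch)])
# [8, 8, 4]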
@@ -366,17 +583,17 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
   @combinations.generate(
       combinations.combine(
           mode=["eager"],
+          tf_api_version=2,
           distribution=[
               strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
               strategy_combinations.tpu_strategy,
+              strategy_combinations.multi_worker_mirrored_2x1_cpu,
+              strategy_combinations.multi_worker_mirrored_2x1_gpu,
+              strategy_combinations.multi_worker_mirrored_2x2_gpu,
           ],
           enable_get_next_as_optional=[True, False]))
   def testDoesNotTriggerFunctionTracing(self, distribution,
                                         enable_get_next_as_optional):
-    if not tf2.enabled():
-      self.skipTest("DistributedIterator CompositeTensor support is only "
-                    "present in TF 2.0 only.")
-
     trace_count = [0]

     @def_function.function
@@ -432,5 +649,11 @@ def _check_type_spec_structure(x):
     nest.assert_same_structure(x, x._type_spec, expand_composites=True)


+def _create_text_file(fname, num_lines):
+  with open(fname, "w") as f:
+    for i in range(num_lines):
+      f.write("%d\n" % i)
+
+
 if __name__ == "__main__":
-  test.main()
+  test_util.main()
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index b90fd24b6e0..a59164bb0d7 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -362,8 +362,23 @@ class DistributedDelegate(DistributedValues):
 class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
   """Holds a map from replica to unsynchronized values."""

+  def __init__(self, values, type_spec_override=None):
+    super(PerReplica, self).__init__(values)
+    # Allow setting a type spec that can be different from the underlying
+    # values. This allows us to avoid retracing for PerReplica values built
+    # from full, partial and empty batches. In a multi-client setup, we need
+    # to avoid such retracing; otherwise the collectives may mismatch, since
+    # we assign new collective keys when retracing the function.
+    #
+    # TODO(b/166169298): remove after CrossDeviceOps is tracing safe.
+    self._type_spec_override = type_spec_override
+
   @property
   def _type_spec(self):
+    if self._type_spec_override is not None:
+      # Return a deep copy in case the caller changes it, since _type_spec()
+      # normally returns a temporary object.
+      return copy.deepcopy(self._type_spec_override)
     return PerReplicaSpec(
         *(type_spec.type_spec_from_value(v) for v in self._values))
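Note: a minimal editorial sketch of the new `type_spec_override` in action. The shapes are made up, and the `copy.deepcopy` call assumes `copy` is already imported in values.py (the hunk does not show that import).

import tensorflow as tf
from tensorflow.python.distribute import values
from tensorflow.python.framework import tensor_spec

# One dynamic-batch spec shared by every step, regardless of batch size.
override = values.PerReplicaSpec(
    tensor_spec.TensorSpec([None, 3], tf.float32),
    tensor_spec.TensorSpec([None, 3], tf.float32))
# A full batch on replica 0 and a partial batch on replica 1 share the same
# advertised spec, so a tf.function traced on one step accepts later steps.
pr = values.PerReplica((tf.ones([4, 3]), tf.ones([2, 3])),
                       type_spec_override=override)
print(pr._type_spec)  # a deep copy of `override`: equal, but not identical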