[retry]DistributedDataset creates elements with fixed spec to help avoid retracing

tf.function tracing depends on the inputs to the function. For a typical training loop:

x, y = next(iter)
train_fn(x,y)

it may retrace when getting partial batches. This is problematic for multi-client training, since different clients may retrace at different times. We assign collective instance keys when tracing a function, so retracing results in different sets of instance keys.

This change overrides the PerReplica type spec, which is used to calculate the function cache key. It tries to avoid retracing in common cases, but it doesn't guarantee that retracing won't happen.

Note that after this change, the function also gets only partial shape information. This is the reason we only do it for multi-client strategies (MWMS), to avoid a performance penalty for e.g. TPU.

PiperOrigin-RevId: 338203534
Change-Id: Iae9d6c3c82113d623707e19142fbebe5597d7898
This commit is contained in:
Ran Chen 2020-10-20 22:41:40 -07:00 committed by TensorFlower Gardener
parent fbcdf129b9
commit 9f51b98f0b
6 changed files with 392 additions and 58 deletions

View File

@ -1018,22 +1018,24 @@ distribute_py_test(
"multi_and_single_gpu",
],
deps = [
":collective_all_reduce_strategy",
":combinations",
":input_lib",
":mirrored_strategy",
":multi_worker_test_base",
":reduce_util",
":distribute_lib",
":strategy_combinations",
":test_util",
":tpu_strategy",
":values",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:composite_tensor",
"//tensorflow/python:dtypes",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf2",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//third_party/py/numpy",

View File

@ -559,7 +559,8 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
flattened_data = []
for per_worker_data in replicas:
flattened_data.extend(per_worker_data)
replicas = distribute_utils.regroup(flattened_data)
replicas = _create_per_replica(
flattened_data, strategy, get_next_as_optional=True)
# Run an all-reduce to see whether any worker has values.
# TODO(b/131423105): we should be able to short-cut the all-reduce in some
@ -659,7 +660,8 @@ class DistributedIteratorBase(DistributedIteratorInterface):
# Make `replicas` a flat list of values across all replicas.
replicas.extend(
self._iterators[i].get_next_as_list_static_shapes(new_name))
return distribute_utils.regroup(replicas)
return _create_per_replica(
replicas, self._strategy, get_next_as_optional=False)
out_of_range_replicas = []
def out_of_range_fn(worker_index, device):
@ -693,7 +695,8 @@ class DistributedIteratorBase(DistributedIteratorInterface):
results.append(result)
replicas = results
return distribute_utils.regroup(replicas)
return _create_per_replica(replicas, self._strategy,
self._enable_get_next_as_optional)
class DistributedIteratorV1(DistributedIteratorBase):
@ -893,11 +896,25 @@ class DistributedIterator(DistributedIteratorBase,
@property
def element_spec(self):
  """Returns the element spec, made batch-dynamic when it has to be.

  The stored spec is returned as is, unless `get_next_as_optional` is enabled
  and the strategy is in multi worker mode. In that case every entry of the
  stored spec is passed through `_rebatch_as_dynamic`.
  """
  needs_dynamic_batch = (
      self._enable_get_next_as_optional and
      self._strategy.extended._in_multi_worker_mode())  # pylint: disable=protected-access
  if not needs_dynamic_batch:
    # Follow the element_spec of the underlying dataset (whose batch dimension
    # may itself be None).
    return self._element_spec

  # With partial batch handling enabled we could always produce empty
  # batches, so the batch dimension is always set to None.
  #
  # TODO(b/163362689): avoid this once we have a more elegant way to handle
  # retracing and collectives.
  return nest.map_structure(
      _rebatch_as_dynamic, self._element_spec, expand_composites=False)
@property
def _type_spec(self):
return DistributedIteratorSpec(self._input_workers, self.element_spec,
# Note that we use the actual element_spec to create DistributedIteratorSpec,
# to be consistent with the underlying iterators' specs.
# TODO(b/163362689): remove the comment after the bug is fixed.
return DistributedIteratorSpec(self._input_workers, self._element_spec,
self._strategy,
self._enable_get_next_as_optional)
@ -1097,7 +1114,7 @@ class DistributedDataset(_IterableInput):
worker_iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator._element_spec = self.element_spec # pylint: disable=protected-access
iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
# initialization before passing to a multi device function, add a sync point
@ -1110,6 +1127,17 @@ class DistributedDataset(_IterableInput):
@property
def element_spec(self):
  """The type specification of an element of this dataset.

  The stored spec is returned as is, unless `get_next_as_optional` is enabled
  and the strategy is in multi worker mode. In that case every entry of the
  stored spec is passed through `_rebatch_as_dynamic`.
  """
  needs_dynamic_batch = (
      self._enable_get_next_as_optional and
      self._strategy.extended._in_multi_worker_mode())  # pylint: disable=protected-access
  if not needs_dynamic_batch:
    # Follow the element_spec of the underlying dataset (whose batch dimension
    # may itself be None).
    return self._element_spec

  # With partial batch handling enabled we could always produce empty
  # batches, so the batch dimension is always set to None.
  #
  # TODO(b/163362689): avoid this once we have a more elegant way to handle
  # retracing and collectives.
  return nest.map_structure(
      _rebatch_as_dynamic, self._element_spec, expand_composites=False)
@ -1279,6 +1307,17 @@ class DistributedDatasetsFromFunction(_IterableInput):
@property
def element_spec(self):
  """The type specification of an element of this dataset.

  The stored spec is returned as is, unless `get_next_as_optional` is enabled
  and the strategy is in multi worker mode. In that case every entry of the
  stored spec is passed through `_rebatch_as_dynamic`.
  """
  needs_dynamic_batch = (
      self._enable_get_next_as_optional and
      self._strategy.extended._in_multi_worker_mode())  # pylint: disable=protected-access
  if not needs_dynamic_batch:
    # Follow the element_spec of the underlying dataset (whose batch dimension
    # may itself be None).
    return self._element_spec

  # With partial batch handling enabled we could always produce empty
  # batches, so the batch dimension is always set to None.
  #
  # TODO(b/163362689): avoid this once we have a more elegant way to handle
  # retracing and collectives.
  return nest.map_structure(
      _rebatch_as_dynamic, self._element_spec, expand_composites=False)
@ -1376,6 +1415,7 @@ class InputFunctionIterator(DistributedIteratorV1):
super(InputFunctionIterator, self).__init__(
input_workers, iterators, strategy, enable_get_next_as_optional=False)
self._enable_get_next_as_optional = False
# TODO(anjalisridhar): This class will soon be removed and users should move
@ -2065,13 +2105,14 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
"""
num_replicas = len(strategy.extended.worker_devices)
# If the number of devices used in the strategy is just 1 then we return
# the tensor_spec as is.
if num_replicas == 1:
# For one device strategy that is not MultiWorkerMirroredStrategy, return the
# tensor_spec as is, since we don't wrap the output with PerReplica in this
# case.
# TODO(b/166464552): remove after we always wrap for all strategies.
if not _always_wrap(strategy):
return tensor_spec
# If the number of devices is greater than 1 then we assume the input to
# tf.function is a per replica type.
# For other cases we assume the input to tf.function is a per replica type.
def _get_value_per_replica(tensor_spec_per_input):
value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
return values.PerReplicaSpec(*value_specs)
@ -2109,3 +2150,70 @@ def _enable_get_next_as_optional(strategy, dataset):
return not _is_statically_shaped(
dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
def _create_per_replica(value_list, strategy, get_next_as_optional):
  """Regroups replica values and, where needed, pins their type spec.

  The values are regrouped with `distribute_utils.regroup()`, always wrapping
  them when `_always_wrap(strategy)` says so. When `get_next_as_optional` is
  set and the strategy is in multi worker mode, each regrouped value is
  additionally given a `_type_spec_override` whose batch dimension is dynamic.

  The override helps avoid retracing for partial batches. Retracing is
  problematic for multi client training when different clients retrace at
  different times, since retracing changes the collective keys in the
  tf.function and causes mismatches among clients.

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.
    get_next_as_optional: whether last partial batch handling is enabled.

  Returns:
    a structure of PerReplica.
  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
  regrouped = distribute_utils.regroup(
      value_list, always_wrap=_always_wrap(strategy))

  # pylint: disable=protected-access
  if not (get_next_as_optional and strategy.extended._in_multi_worker_mode()):
    # No partial batch handling to account for: the regrouped values keep the
    # type spec derived from their actual components.
    return regrouped

  # With partial batch handling enabled we could always produce empty
  # batches, so the batch dimension is always set to None.
  #
  # TODO(b/163362689): avoid this once we have a more elegant way to handle
  # retracing and collectives.
  #
  # expand_composites=False keeps each PerReplica (a CompositeTensor) intact
  # instead of expanding it into its components.
  leaves = nest.flatten(regrouped, expand_composites=False)
  leaf_specs = [type_spec.type_spec_from_value(leaf) for leaf in leaves]
  for leaf, leaf_spec in zip(leaves, leaf_specs):
    leaf._type_spec_override = _rebatch_as_dynamic(leaf_spec)
  # pylint: enable=protected-access
  return nest.pack_sequence_as(regrouped, leaves)
def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues.

  Wrapping is requested when the strategy is in multi worker mode, or when it
  has more than one worker device.

  Args:
    strategy: the `tf.distribute.Strategy`.
  """
  extended = strategy.extended
  in_multi_worker_mode = extended._in_multi_worker_mode()  # pylint: disable=protected-access
  return in_multi_worker_mode or len(extended.worker_devices) > 1
def _rebatch_as_dynamic(per_replica_spec):
  """Rebatch the spec to have a dynamic batch dimension.

  Args:
    per_replica_spec: a `values.PerReplicaSpec`.

  Returns:
    A new `values.PerReplicaSpec` built from the value specs of
    `per_replica_spec`, each unbatched and then batched again with a batch
    size of None. Specs for which that raises ValueError are kept unchanged.
  """
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _with_dynamic_batch(spec):
    try:
      rebatched = spec._unbatch()._batch(None)
    except ValueError:
      # Rebatching is best effort; leave the spec untouched when it fails.
      return spec
    return rebatched

  dynamic_specs = nest.map_structure(
      _with_dynamic_batch, per_replica_spec._value_specs)
  # pylint: enable=protected-access
  return values.PerReplicaSpec(*dynamic_specs)

View File

@ -557,7 +557,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
iterator = iter(dist_dataset)
for i, element in enumerate(iterator):
self.assertEqual(i, element.numpy())
self.assertAllEqual(distribution.experimental_local_results(element), [i])
@combinations.generate(
combinations.combine(

View File

@ -18,15 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
@ -37,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest
@ -116,14 +120,17 @@ class DistributedIteratorTest(test.TestCase,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
enable_get_next_as_optional=[True, False],
drop_remainder=[True, False],
tf_api_version=2,
))
def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
enable_get_next_as_optional,
drop_remainder):
trace_count = [0]
@def_function.function
@ -135,7 +142,8 @@ class DistributedIteratorTest(test.TestCase,
counter += 1
return counter
dataset = dataset_ops.DatasetV2.range(10).batch(2)
dataset = dataset_ops.DatasetV2.range(10).batch(
2, drop_remainder=drop_remainder)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
@ -161,27 +169,79 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
input_type=["dataset", "dataset_fn"],
tf_api_version=2,
enable_get_next_as_optional=[True, False],
drop_remainder=[True, False],
))
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
def dataset_fn(ctx):
del ctx # unused
return dataset_ops.DatasetV2.from_tensor_slices(
np.ones([10, 12]).astype(np.float32)).batch(4)
def testInputSignatureForPerReplicaValues(self, distribution,
enable_get_next_as_optional,
drop_remainder):
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
ds = dataset_ops.DatasetV2.from_tensor_slices(
np.ones([9, 12]).astype(np.float32)).batch(
4, drop_remainder=drop_remainder)
ds = distribution.experimental_distribute_dataset(ds)
_check_type_spec_structure(iter(ds))
element_spec = ds.element_spec
iter_element_spec = iter(ds).element_spec
nest.assert_same_structure(element_spec, iter_element_spec)
self.assertAllEqual(
nest.flatten(element_spec), nest.flatten(iter_element_spec))
if input_type == "dataset":
ds = distribution.experimental_distribute_dataset(
dataset_fn(distribute_lib.InputContext()))
type_spec = ds.element_spec
else:
ds = distribution.distribute_datasets_from_function(dataset_fn)
iterator = iter(ds)
_check_type_spec_structure(iterator)
type_spec = iterator.element_spec
@def_function.function(input_signature=[element_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
@def_function.function(input_signature=[type_spec])
for x in ds:
process_inputs(x)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
tf_api_version=2,
enable_get_next_as_optional=[True, False],
drop_remainder=[True, False],
))
def testFromFunctionInputSignatureForPerReplicaValues(
self, distribution, enable_get_next_as_optional, drop_remainder):
# Create files that produce partial/empty batches at different steps. Note
# that some workers will get empty batches even when drop_remainder=True.
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
_create_text_file(fname1, 5)
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
_create_text_file(fname2, 9)
def dataset_fn(input_context):
dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
return readers.TextLineDatasetV2(dataset).map(
string_ops.string_to_number).batch(
input_context.get_per_replica_batch_size(4),
drop_remainder=drop_remainder)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
ds = distribution.experimental_distribute_datasets_from_function(dataset_fn)
_check_type_spec_structure(iter(ds))
element_spec = ds.element_spec
iter_element_spec = iter(ds).element_spec
nest.assert_same_structure(element_spec, iter_element_spec)
self.assertAllEqual(
nest.flatten(element_spec), nest.flatten(iter_element_spec))
@def_function.function(input_signature=[element_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
@ -247,6 +307,149 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2))
self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1))
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ],
        tf_api_version=2,
        drop_remainder=[True, False],
    ))
def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution,
                                                 drop_remainder):
  """Checks that a distributed dataset's elements trace a function once."""
  distribution.extended.experimental_enable_get_next_as_optional = True

  # Five rows batched by four: enough to see full batches, partial batches
  # and empty batches.
  rows = np.ones((5, 3))
  batched = dataset_ops.DatasetV2.from_tensor_slices(rows).batch(
      4, drop_remainder=drop_remainder)
  dist_dataset = distribution.experimental_distribute_dataset(batched)

  self.trace_count = 0

  @def_function.function
  def count_trace(element):
    # The Python body only runs while tracing, so this counts traces.
    del element
    self.trace_count += 1

  for element in iter(dist_dataset):
    count_trace(element)

  self.assertEqual(self.trace_count, 1)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ],
        tf_api_version=2,
        drop_remainder=[True, False],
    ))
def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
    self, distribution, drop_remainder):
  """Checks single tracing for a distributed dataset read from text files."""
  # Two files of different lengths, so that partial/empty batches show up at
  # different steps.
  temp_dir = self.get_temp_dir()
  short_file = os.path.join(temp_dir, "1.txt")
  _create_text_file(short_file, 5)
  long_file = os.path.join(self.get_temp_dir(), "2.txt")
  _create_text_file(long_file, 9)

  distribution.extended.experimental_enable_get_next_as_optional = True
  lines = readers.TextLineDatasetV2([short_file, long_file])
  batched = lines.batch(4, drop_remainder=drop_remainder)
  dist_dataset = distribution.experimental_distribute_dataset(batched)

  self.trace_count = 0

  @def_function.function
  def count_trace(element):
    # The Python body only runs while tracing, so this counts traces.
    del element
    self.trace_count += 1

  for element in iter(dist_dataset):
    count_trace(element)

  self.assertEqual(self.trace_count, 1)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ],
        tf_api_version=2,
        drop_remainder=[True, False],
    ))
def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution,
                                                  drop_remainder):
  """Checks single tracing for datasets created from a dataset function."""

  def dataset_fn(input_context):
    # Five rows in total: enough to see full batches, partial batches and
    # empty batches.
    rows = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3)))
    per_replica_batch_size = input_context.get_per_replica_batch_size(4)
    batched = rows.batch(
        per_replica_batch_size, drop_remainder=drop_remainder)
    return batched.shard(input_context.num_input_pipelines,
                         input_context.input_pipeline_id)

  distribution.extended.experimental_enable_get_next_as_optional = True
  dist_dataset = distribution.experimental_distribute_datasets_from_function(
      dataset_fn)

  self.trace_count = 0

  @def_function.function
  def count_trace(element):
    # The Python body only runs while tracing, so this counts traces.
    del element
    self.trace_count += 1

  for element in iter(dist_dataset):
    count_trace(element)

  self.assertEqual(self.trace_count, 1)
@combinations.generate(
    combinations.combine(
        mode=["eager"],
        distribution=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ],
        tf_api_version=2,
        drop_remainder=[True, False],
    ))
def testFromFunctionFileShardingDoesNotTriggerFunctionTracing(
    self, distribution, drop_remainder):
  """Checks single tracing when a dataset function shards input files."""
  # Two files of different lengths, so that partial/empty batches show up at
  # different steps.
  temp_dir = self.get_temp_dir()
  short_file = os.path.join(temp_dir, "1.txt")
  _create_text_file(short_file, 5)
  long_file = os.path.join(self.get_temp_dir(), "2.txt")
  _create_text_file(long_file, 9)

  def dataset_fn(input_context):
    # Shard the file names across input pipelines, then read and batch the
    # lines of the files assigned to this pipeline.
    file_names = dataset_ops.DatasetV2.from_tensor_slices(
        [short_file, long_file])
    file_names = file_names.shard(input_context.num_input_pipelines,
                                  input_context.input_pipeline_id)
    per_replica_batch_size = input_context.get_per_replica_batch_size(4)
    return readers.TextLineDatasetV2(file_names).batch(
        per_replica_batch_size, drop_remainder=drop_remainder)

  distribution.extended.experimental_enable_get_next_as_optional = True
  dist_dataset = distribution.experimental_distribute_datasets_from_function(
      dataset_fn)

  self.trace_count = 0

  @def_function.function
  def count_trace(element):
    # The Python body only runs while tracing, so this counts traces.
    del element
    self.trace_count += 1

  for element in iter(dist_dataset):
    count_trace(element)

  self.assertEqual(self.trace_count, 1)
class RaggedTensorDistributedIteratorTest(test.TestCase,
parameterized.TestCase):
@ -254,14 +457,14 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
@combinations.generate(
combinations.combine(
mode=["eager"],
tf_api_version=2,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpec(self, distribution, enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator has CompositeTensor support in "
"TF 2.0 only.")
ctx = distribute_lib.InputContext()
batch_size = ctx.get_per_replica_batch_size(8)
# Use 20 which isn't divisible by 8 to test partial batch behavior.
@ -313,16 +516,16 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
@combinations.generate(
combinations.combine(
mode=["eager"],
tf_api_version=2,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
ctx = distribute_lib.InputContext()
batch_size = ctx.get_per_replica_batch_size(8)
# Use 20 which isn't divisible by 8 to test partial batch behavior.
@ -366,17 +569,17 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
@combinations.generate(
combinations.combine(
mode=["eager"],
tf_api_version=2,
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
],
enable_get_next_as_optional=[True, False]))
def testDoesNotTriggerFunctionTracing(self, distribution,
enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
trace_count = [0]
@def_function.function
@ -432,5 +635,11 @@ def _check_type_spec_structure(x):
nest.assert_same_structure(x, x._type_spec, expand_composites=True)
def _create_text_file(fname, num_lines):
  """Writes the integers 0 to `num_lines - 1` to `fname`, one per line.

  Args:
    fname: path of the file to write; it is opened in "w" mode.
    num_lines: number of lines to write.
  """
  with open(fname, "w") as text_file:
    text_file.writelines("%d\n" % line_no for line_no in range(num_lines))
if __name__ == "__main__":
test.main()
test_util.main()

View File

@ -362,8 +362,23 @@ class DistributedDelegate(DistributedValues):
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
"""Holds a map from replica to unsynchronized values."""
def __init__(self, values, type_spec_override=None):
  """Creates a PerReplica.

  Args:
    values: the per-replica values, forwarded to the base class constructor.
    type_spec_override: optional spec stored as `_type_spec_override`. When it
      is not None, `_type_spec` returns a deep copy of it instead of a spec
      computed from `values`.
  """
  super(PerReplica, self).__init__(values)
  # The type spec may deliberately differ from that of the underlying values,
  # so that PerReplica objects built from full, partial and empty batches do
  # not cause retracing. In a multi client setup such retracing has to be
  # avoided: new collective keys are assigned whenever a function is retraced,
  # so the collectives may otherwise mismatch.
  #
  # TODO(b/166169298): remove after CrossDeviceOps is tracing safe.
  self._type_spec_override = type_spec_override
@property
def _type_spec(self):
  """Returns the override spec if one is set, else a spec from the values."""
  override = self._type_spec_override
  if override is None:
    value_specs = [type_spec.type_spec_from_value(v) for v in self._values]
    return PerReplicaSpec(*value_specs)
  # _type_spec normally returns a temporary object, so hand out a deep copy
  # in case the caller changes it.
  return copy.deepcopy(override)

View File

@ -341,7 +341,7 @@ distribute_py_test(
full_precision = True,
main = "distribute_strategy_test.py",
python_version = "PY3",
shard_count = 10,
shard_count = 20,
tags = [
"multi_and_single_gpu",
"no_rocm", # times out on ROCm