diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 9defb75c703..6c2e1aae6ce 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import sys import six @@ -494,12 +495,13 @@ class DistributedIteratorSpec(type_spec.TypeSpec): def _component_specs(self): specs = [] worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access - for i in range(len(worker_device_pairs)): - input_device, compute_devices = worker_device_pairs[i] + + for i, (input_device, compute_devices) in enumerate(worker_device_pairs): + element_spec = nest.map_structure( + functools.partial(_replace_per_replica_spec, i=i), self._element_spec) specs.append(_SingleWorkerDatasetIteratorSpec(input_device, compute_devices, - element_spec= - self._element_spec)) + element_spec)) return specs def _to_components(self, value): @@ -1140,7 +1142,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): def __init__(self, worker, devices, element_spec): self._worker = worker - self._devices = devices + self._devices = tuple(device_util.canonicalize(d) for d in devices) self._element_spec = element_spec @property @@ -1148,7 +1150,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): return _SingleWorkerOwnedDatasetIterator def _serialize(self): - return (self._worker, tuple(self._devices), self._element_spec) + return (self._worker, self._devices, self._element_spec) @property def _component_specs(self): @@ -1579,3 +1581,11 @@ def _create_distributed_tensor_spec(strategy, tensor_spec): return values.PerReplicaSpec(*value_specs) return nest.map_structure(_get_value_per_replica, tensor_spec) + + +def _replace_per_replica_spec(spec, i): + """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec.""" + if isinstance(spec, values.PerReplicaSpec): + return spec._value_specs[i] # pylint: disable=protected-access + else: + return spec diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 030636453ea..60212f7a3b7 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -44,6 +44,7 @@ from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -174,6 +175,10 @@ class DistributedIteratorTestBase(test.TestCase): else: self.skipTest("unsupported test combination") + if isinstance(iterator, composite_tensor.CompositeTensor): + nest.assert_same_structure(iterator, iterator._type_spec, + expand_composites=True) + if iteration_type == "get_next": evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) if not ops.executing_eagerly_outside_functions(): diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 53bcc576b24..7f5b0e09f2c 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -30,12 +30,14 @@ from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute import values from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib +from tensorflow.python.util import nest class DistributedIteratorTest(test.TestCase, @@ -63,6 +65,7 @@ class DistributedIteratorTest(test.TestCase, dist_dataset = distribution.experimental_distribute_dataset(dataset) with distribution.scope(): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) spec = iterator._type_spec self.assertEqual(spec._input_workers, iterator._input_workers) @@ -95,6 +98,7 @@ class DistributedIteratorTest(test.TestCase, dist_dataset = distribution.experimental_distribute_dataset(dataset) with distribution.scope(): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) spec = iterator._type_spec @@ -139,6 +143,7 @@ class DistributedIteratorTest(test.TestCase, with distribution.scope(): for _ in range(3): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) counter = f(iterator) self.assertEqual(trace_count[0], 1) @@ -173,6 +178,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): ds = distribution.experimental_distribute_datasets_from_function( dataset_fn) iterator = iter(ds) + _check_type_spec_structure(iterator) type_spec = iterator.element_spec @def_function.function(input_signature=[type_spec]) @@ -276,6 +282,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, dist_dataset = distribution.experimental_distribute_dataset(dataset) with distribution.scope(): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) spec = iterator._type_spec self.assertEqual(spec._input_workers, iterator._input_workers) @@ -336,6 +343,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, dist_dataset = distribution.experimental_distribute_dataset(dataset) with distribution.scope(): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) spec = iterator._type_spec @@ -391,11 +399,18 @@ class RaggedTensorDistributedIteratorTest(test.TestCase, with distribution.scope(): for _ in range(3): iterator = iter(dist_dataset) + _check_type_spec_structure(iterator) counter = f(iterator) self.assertEqual(trace_count[0], 1) self.assertEqual(counter, 5) +def _check_type_spec_structure(x): + """Verifies that `x` has the same structure as its `TypeSpec`.""" + if isinstance(x, composite_tensor.CompositeTensor): + nest.assert_same_structure(x, x._type_spec, expand_composites=True) + + if __name__ == "__main__": test.main()