Fix incompatibilities between DistributedIterator and the corresponding DistributedIteratorSpec.
PiperOrigin-RevId: 309336208 Change-Id: I459daf891346285e7d37f5b43c9ae0a28f44fd21
This commit is contained in:
parent
896a2700d1
commit
c74afd7451
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import sys
|
||||
|
||||
import six
|
||||
@ -494,12 +495,13 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
|
||||
def _component_specs(self):
|
||||
specs = []
|
||||
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
|
||||
for i in range(len(worker_device_pairs)):
|
||||
input_device, compute_devices = worker_device_pairs[i]
|
||||
|
||||
for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
|
||||
element_spec = nest.map_structure(
|
||||
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
|
||||
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
|
||||
compute_devices,
|
||||
element_spec=
|
||||
self._element_spec))
|
||||
element_spec))
|
||||
return specs
|
||||
|
||||
def _to_components(self, value):
|
||||
@ -1140,7 +1142,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
|
||||
|
||||
def __init__(self, worker, devices, element_spec):
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
self._devices = tuple(device_util.canonicalize(d) for d in devices)
|
||||
self._element_spec = element_spec
|
||||
|
||||
@property
|
||||
@ -1148,7 +1150,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
|
||||
return _SingleWorkerOwnedDatasetIterator
|
||||
|
||||
def _serialize(self):
|
||||
return (self._worker, tuple(self._devices), self._element_spec)
|
||||
return (self._worker, self._devices, self._element_spec)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
@ -1579,3 +1581,11 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
|
||||
return values.PerReplicaSpec(*value_specs)
|
||||
|
||||
return nest.map_structure(_get_value_per_replica, tensor_spec)
|
||||
|
||||
|
||||
def _replace_per_replica_spec(spec, i):
|
||||
"""If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
|
||||
if isinstance(spec, values.PerReplicaSpec):
|
||||
return spec._value_specs[i] # pylint: disable=protected-access
|
||||
else:
|
||||
return spec
|
||||
|
@ -44,6 +44,7 @@ from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -174,6 +175,10 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
else:
|
||||
self.skipTest("unsupported test combination")
|
||||
|
||||
if isinstance(iterator, composite_tensor.CompositeTensor):
|
||||
nest.assert_same_structure(iterator, iterator._type_spec,
|
||||
expand_composites=True)
|
||||
|
||||
if iteration_type == "get_next":
|
||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
|
@ -30,12 +30,14 @@ from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class DistributedIteratorTest(test.TestCase,
|
||||
@ -63,6 +65,7 @@ class DistributedIteratorTest(test.TestCase,
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
|
||||
spec = iterator._type_spec
|
||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||
@ -95,6 +98,7 @@ class DistributedIteratorTest(test.TestCase,
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
|
||||
spec = iterator._type_spec
|
||||
|
||||
@ -139,6 +143,7 @@ class DistributedIteratorTest(test.TestCase,
|
||||
with distribution.scope():
|
||||
for _ in range(3):
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
counter = f(iterator)
|
||||
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
@ -173,6 +178,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
||||
ds = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
iterator = iter(ds)
|
||||
_check_type_spec_structure(iterator)
|
||||
type_spec = iterator.element_spec
|
||||
|
||||
@def_function.function(input_signature=[type_spec])
|
||||
@ -276,6 +282,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
|
||||
spec = iterator._type_spec
|
||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||
@ -336,6 +343,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
|
||||
spec = iterator._type_spec
|
||||
|
||||
@ -391,11 +399,18 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
with distribution.scope():
|
||||
for _ in range(3):
|
||||
iterator = iter(dist_dataset)
|
||||
_check_type_spec_structure(iterator)
|
||||
counter = f(iterator)
|
||||
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
self.assertEqual(counter, 5)
|
||||
|
||||
|
||||
def _check_type_spec_structure(x):
|
||||
"""Verifies that `x` has the same structure as its `TypeSpec`."""
|
||||
if isinstance(x, composite_tensor.CompositeTensor):
|
||||
nest.assert_same_structure(x, x._type_spec, expand_composites=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user