Fix incompatibilities between DistributedIterator and the corresponding DistributedIteratorSpec.

PiperOrigin-RevId: 309336208
Change-Id: I459daf891346285e7d37f5b43c9ae0a28f44fd21
This commit is contained in:
Edward Loper 2020-04-30 18:04:16 -07:00 committed by TensorFlower Gardener
parent 896a2700d1
commit c74afd7451
3 changed files with 36 additions and 6 deletions

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
import six
@ -494,12 +495,13 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
def _component_specs(self):
specs = []
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
for i in range(len(worker_device_pairs)):
input_device, compute_devices = worker_device_pairs[i]
for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
element_spec = nest.map_structure(
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
compute_devices,
element_spec=
self._element_spec))
element_spec))
return specs
def _to_components(self, value):
@ -1140,7 +1142,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
def __init__(self, worker, devices, element_spec):
self._worker = worker
self._devices = devices
self._devices = tuple(device_util.canonicalize(d) for d in devices)
self._element_spec = element_spec
@property
@ -1148,7 +1150,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
return _SingleWorkerOwnedDatasetIterator
def _serialize(self):
return (self._worker, tuple(self._devices), self._element_spec)
return (self._worker, self._devices, self._element_spec)
@property
def _component_specs(self):
@ -1579,3 +1581,11 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
return values.PerReplicaSpec(*value_specs)
return nest.map_structure(_get_value_per_replica, tensor_spec)
def _replace_per_replica_spec(spec, i):
"""If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
if isinstance(spec, values.PerReplicaSpec):
return spec._value_specs[i] # pylint: disable=protected-access
else:
return spec

View File

@ -44,6 +44,7 @@ from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@ -174,6 +175,10 @@ class DistributedIteratorTestBase(test.TestCase):
else:
self.skipTest("unsupported test combination")
if isinstance(iterator, composite_tensor.CompositeTensor):
nest.assert_same_structure(iterator, iterator._type_spec,
expand_composites=True)
if iteration_type == "get_next":
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
if not ops.executing_eagerly_outside_functions():

View File

@ -30,12 +30,14 @@ from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
from tensorflow.python.util import nest
class DistributedIteratorTest(test.TestCase,
@ -63,6 +65,7 @@ class DistributedIteratorTest(test.TestCase,
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
spec = iterator._type_spec
self.assertEqual(spec._input_workers, iterator._input_workers)
@ -95,6 +98,7 @@ class DistributedIteratorTest(test.TestCase,
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
spec = iterator._type_spec
@ -139,6 +143,7 @@ class DistributedIteratorTest(test.TestCase,
with distribution.scope():
for _ in range(3):
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
counter = f(iterator)
self.assertEqual(trace_count[0], 1)
@ -173,6 +178,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn)
iterator = iter(ds)
_check_type_spec_structure(iterator)
type_spec = iterator.element_spec
@def_function.function(input_signature=[type_spec])
@ -276,6 +282,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
spec = iterator._type_spec
self.assertEqual(spec._input_workers, iterator._input_workers)
@ -336,6 +343,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
spec = iterator._type_spec
@ -391,11 +399,18 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
with distribution.scope():
for _ in range(3):
iterator = iter(dist_dataset)
_check_type_spec_structure(iterator)
counter = f(iterator)
self.assertEqual(trace_count[0], 1)
self.assertEqual(counter, 5)
def _check_type_spec_structure(x):
"""Verifies that `x` has the same structure as its `TypeSpec`."""
if isinstance(x, composite_tensor.CompositeTensor):
nest.assert_same_structure(x, x._type_spec, expand_composites=True)
if __name__ == "__main__":
test.main()