Fix incompatibilities between DistributedIterator and the corresponding DistributedIteratorSpec.
PiperOrigin-RevId: 309336208 Change-Id: I459daf891346285e7d37f5b43c9ae0a28f44fd21
This commit is contained in:
parent
896a2700d1
commit
c74afd7451
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import six
|
import six
|
||||||
@ -494,12 +495,13 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
|
|||||||
def _component_specs(self):
|
def _component_specs(self):
|
||||||
specs = []
|
specs = []
|
||||||
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
|
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
|
||||||
for i in range(len(worker_device_pairs)):
|
|
||||||
input_device, compute_devices = worker_device_pairs[i]
|
for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
|
||||||
|
element_spec = nest.map_structure(
|
||||||
|
functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
|
||||||
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
|
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
|
||||||
compute_devices,
|
compute_devices,
|
||||||
element_spec=
|
element_spec))
|
||||||
self._element_spec))
|
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
def _to_components(self, value):
|
def _to_components(self, value):
|
||||||
@ -1140,7 +1142,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
|
|||||||
|
|
||||||
def __init__(self, worker, devices, element_spec):
|
def __init__(self, worker, devices, element_spec):
|
||||||
self._worker = worker
|
self._worker = worker
|
||||||
self._devices = devices
|
self._devices = tuple(device_util.canonicalize(d) for d in devices)
|
||||||
self._element_spec = element_spec
|
self._element_spec = element_spec
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1148,7 +1150,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
|
|||||||
return _SingleWorkerOwnedDatasetIterator
|
return _SingleWorkerOwnedDatasetIterator
|
||||||
|
|
||||||
def _serialize(self):
|
def _serialize(self):
|
||||||
return (self._worker, tuple(self._devices), self._element_spec)
|
return (self._worker, self._devices, self._element_spec)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _component_specs(self):
|
def _component_specs(self):
|
||||||
@ -1579,3 +1581,11 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
|
|||||||
return values.PerReplicaSpec(*value_specs)
|
return values.PerReplicaSpec(*value_specs)
|
||||||
|
|
||||||
return nest.map_structure(_get_value_per_replica, tensor_spec)
|
return nest.map_structure(_get_value_per_replica, tensor_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_per_replica_spec(spec, i):
|
||||||
|
"""If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
|
||||||
|
if isinstance(spec, values.PerReplicaSpec):
|
||||||
|
return spec._value_specs[i] # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
|
return spec
|
||||||
|
@ -44,6 +44,7 @@ from tensorflow.python.distribute import values
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -174,6 +175,10 @@ class DistributedIteratorTestBase(test.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.skipTest("unsupported test combination")
|
self.skipTest("unsupported test combination")
|
||||||
|
|
||||||
|
if isinstance(iterator, composite_tensor.CompositeTensor):
|
||||||
|
nest.assert_same_structure(iterator, iterator._type_spec,
|
||||||
|
expand_composites=True)
|
||||||
|
|
||||||
if iteration_type == "get_next":
|
if iteration_type == "get_next":
|
||||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||||
if not ops.executing_eagerly_outside_functions():
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
@ -30,12 +30,14 @@ from tensorflow.python.distribute import strategy_combinations
|
|||||||
from tensorflow.python.distribute import values
|
from tensorflow.python.distribute import values
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
|
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
|
||||||
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
class DistributedIteratorTest(test.TestCase,
|
class DistributedIteratorTest(test.TestCase,
|
||||||
@ -63,6 +65,7 @@ class DistributedIteratorTest(test.TestCase,
|
|||||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
|
|
||||||
spec = iterator._type_spec
|
spec = iterator._type_spec
|
||||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||||
@ -95,6 +98,7 @@ class DistributedIteratorTest(test.TestCase,
|
|||||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
|
|
||||||
spec = iterator._type_spec
|
spec = iterator._type_spec
|
||||||
|
|
||||||
@ -139,6 +143,7 @@ class DistributedIteratorTest(test.TestCase,
|
|||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
counter = f(iterator)
|
counter = f(iterator)
|
||||||
|
|
||||||
self.assertEqual(trace_count[0], 1)
|
self.assertEqual(trace_count[0], 1)
|
||||||
@ -173,6 +178,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
|||||||
ds = distribution.experimental_distribute_datasets_from_function(
|
ds = distribution.experimental_distribute_datasets_from_function(
|
||||||
dataset_fn)
|
dataset_fn)
|
||||||
iterator = iter(ds)
|
iterator = iter(ds)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
type_spec = iterator.element_spec
|
type_spec = iterator.element_spec
|
||||||
|
|
||||||
@def_function.function(input_signature=[type_spec])
|
@def_function.function(input_signature=[type_spec])
|
||||||
@ -276,6 +282,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
|||||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
|
|
||||||
spec = iterator._type_spec
|
spec = iterator._type_spec
|
||||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||||
@ -336,6 +343,7 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
|||||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
|
|
||||||
spec = iterator._type_spec
|
spec = iterator._type_spec
|
||||||
|
|
||||||
@ -391,11 +399,18 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
|||||||
with distribution.scope():
|
with distribution.scope():
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
iterator = iter(dist_dataset)
|
iterator = iter(dist_dataset)
|
||||||
|
_check_type_spec_structure(iterator)
|
||||||
counter = f(iterator)
|
counter = f(iterator)
|
||||||
|
|
||||||
self.assertEqual(trace_count[0], 1)
|
self.assertEqual(trace_count[0], 1)
|
||||||
self.assertEqual(counter, 5)
|
self.assertEqual(counter, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_type_spec_structure(x):
|
||||||
|
"""Verifies that `x` has the same structure as its `TypeSpec`."""
|
||||||
|
if isinstance(x, composite_tensor.CompositeTensor):
|
||||||
|
nest.assert_same_structure(x, x._type_spec, expand_composites=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user