Move the logic to decide whether to enable partial batch handling into dataset

When partial batch handling is enabled, we should change the element spec to
have an unknown batch dimension. This will be done in a later change.

PiperOrigin-RevId: 332398721
Change-Id: I3fe09c1a12fc7b393d94930aa0da4bddb7bb0666
Authored by Ran Chen on 2020-09-18 00:19:26 -07:00; committed by TensorFlower Gardener
parent 40a443b85b
commit 6c102af171
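
For context, the "partial batch handling" mentioned in the commit message refers to the last batch of an epoch being smaller than the others, which makes the batch dimension statically unknown. Below is a minimal, illustrative sketch (standard public tf.data APIs, not code from this change) of why such datasets are consumed with get_next_as_optional() rather than get_next():

import tensorflow as tf

dataset = tf.data.Dataset.range(8).batch(3)  # yields batches of 3, 3 and a partial batch of 2
iterator = iter(dataset)

while True:
  # Unlike get_next(), get_next_as_optional() never raises OutOfRangeError;
  # it returns an Optional whose has_value() says whether another (possibly
  # smaller) batch is still available.
  optional = tf.data.experimental.get_next_as_optional(iterator)
  if not optional.has_value():
    break
  batch = optional.get_value()
  print(batch.shape)  # (3,), (3,), then the partial batch (2,)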

@@ -552,41 +552,67 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
return global_has_value, replicas return global_has_value, replicas
def _is_statically_shaped(element_spec): def _is_statically_shaped(tensor_class, shape):
"""Test if an iterator output is statically shaped. """Test if an iterator output is statically shaped.
For sparse and ragged tensors this only tests the batch dimension. For sparse and ragged tensors this only tests the batch dimension.
Args: Args:
element_spec: a nest structure of `tf.TypeSpec`. The element spec of the tensor_class: a class from an iterator.output_classes list.
dataset of the iterator. shape: a TensorShape from an iterator.output_shapes list.
Returns: Returns:
True if the shape is static, false otherwise. True if the shape is static, false otherwise.
""" """
if (tensor_class == sparse_tensor.SparseTensor or
isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
# For sparse or ragged tensor, we should only check the first
# dimension in order to get_next_as_optional. This is because
# when these tensors get batched by dataset only the batch dimension
# is set.
if shape.rank > 0 and shape.as_list()[0] is None:
return False
return True
return shape.is_fully_defined()
for spec in nest.flatten(element_spec):
shape = spec.shape def _get_static_shape(iterators):
if isinstance( """Returns a boolean indicating if the input is fully defined."""
spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)): static_shape = True
# For sparse or ragged tensor, we should only check the first for iterator in iterators:
# dimension in order to get_next_as_optional. This is because if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
# when these tensors get batched by dataset only the batch dimension _SingleWorkerDatasetIterator)):
# is set. continue
if shape.rank > 0 and shape.as_list()[0] is None: flattened = zip(nest.flatten(iterator.output_shapes),
return False nest.flatten(iterator.output_classes))
else: for output_shape, output_class in flattened:
if not shape.is_fully_defined(): if not _is_statically_shaped(output_class, output_shape):
return False static_shape = False
return True break
return static_shape
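
Both versions of the shape test above only look at the leading (batch) dimension for sparse and ragged values, because dataset batching only sets that dimension on their specs. A hedged, self-contained sketch of the same idea on a dataset's element_spec (simplified, not the exact library code; the function name below is ours):

import tensorflow as tf

def is_statically_shaped(element_spec):
  """Simplified restatement of the check above on a dataset's element_spec."""
  for spec in tf.nest.flatten(element_spec):
    if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
      # Only the batch dimension matters for sparse/ragged values.
      if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
        return False
    elif not spec.shape.is_fully_defined():
      return False
  return True

static = tf.data.Dataset.range(8).batch(3, drop_remainder=True)
dynamic = tf.data.Dataset.range(8).batch(3)
print(is_statically_shaped(static.element_spec))   # True
print(is_statically_shaped(dynamic.element_spec))  # False (batch dim is None)
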
class DistributedIteratorBase(DistributedIteratorInterface): class DistributedIteratorBase(DistributedIteratorInterface):
"""Common implementation for all input iterators.""" """Common implementation for all input iterators."""
# pylint: disable=super-init-not-called # pylint: disable=super-init-not-called
def __init__(self, input_workers, iterators, strategy, def __init__(self, input_workers, iterators, strategy):
enable_get_next_as_optional): static_shape = _get_static_shape(iterators)
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
assert isinstance(input_workers, InputWorkers) assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices: if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.") raise ValueError("Should have at least one worker for input iterator.")
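
Whichever object owns it, the decision itself is the same on both sides of this diff: the strategy has to opt in via experimental_enable_get_next_as_optional, and even then the slower get_next_as_optional() path is only used when the output shapes are not fully static or the job runs in multi-worker mode. A small, hedged restatement of that predicate with plain booleans instead of the strategy/iterator objects (the function name is ours, not the library's):

def should_enable_partial_batch_handling(opt_in, static_shape, in_multi_worker_mode):
  # opt_in mirrors strategy.extended.experimental_enable_get_next_as_optional;
  # static_shape is the result of the shape check above.
  if not opt_in:
    return False
  return (not static_shape) or in_multi_worker_mode

print(should_enable_partial_batch_handling(True, static_shape=True, in_multi_worker_mode=False))   # False
print(should_enable_partial_batch_handling(True, static_shape=False, in_multi_worker_mode=False))  # True
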
@@ -594,7 +620,6 @@ class DistributedIteratorBase(DistributedIteratorInterface):
self._iterators = iterators self._iterators = iterators
self._input_workers = input_workers self._input_workers = input_workers
self._strategy = strategy self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional
def next(self): def next(self):
return self.__next__() return self.__next__()
@@ -728,13 +753,9 @@ class DistributedIteratorV1(DistributedIteratorBase):
class DistributedIteratorSpec(type_spec.TypeSpec): class DistributedIteratorSpec(type_spec.TypeSpec):
"""Type specification for `DistributedIterator`.""" """Type specification for `DistributedIterator`."""
__slots__ = [ __slots__ = ["_input_workers", "_element_spec", "_strategy"]
"_input_workers", "_element_spec", "_strategy",
"_enable_get_next_as_optional"
]
def __init__(self, input_workers, element_spec, strategy, def __init__(self, input_workers, element_spec, strategy):
enable_get_next_as_optional):
# We don't want to allow deserialization of this class because we don't # We don't want to allow deserialization of this class because we don't
# serialize the strategy object. Currently the only places where # serialize the strategy object. Currently the only places where
# _deserialize is called is when we save/restore using SavedModels. # _deserialize is called is when we save/restore using SavedModels.
@@ -745,7 +766,6 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
self._input_workers = input_workers self._input_workers = input_workers
self._element_spec = element_spec self._element_spec = element_spec
self._strategy = strategy self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional
@property @property
def value_type(self): def value_type(self):
@@ -786,8 +806,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
lambda a, b: a.most_specific_compatible_type(b), self._element_spec, lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
other._element_spec) other._element_spec)
return DistributedIteratorSpec(self._input_workers, element_spec, return DistributedIteratorSpec(self._input_workers, element_spec,
self._strategy, self._strategy)
self._enable_get_next_as_optional)
@property @property
def _component_specs(self): def _component_specs(self):
@@ -806,41 +825,32 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
return value._iterators # pylint: disable=protected-access return value._iterators # pylint: disable=protected-access
def _from_components(self, components): def _from_components(self, components):
return DistributedIterator( return DistributedIterator(input_workers=self._input_workers,
input_workers=self._input_workers, iterators=None,
iterators=None, components=components,
components=components, element_spec=self._element_spec,
element_spec=self._element_spec, strategy=self._strategy)
strategy=self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
@staticmethod @staticmethod
def from_value(value): def from_value(value):
# pylint: disable=protected-access # pylint: disable=protected-access
return DistributedIteratorSpec(value._input_workers, value._element_spec, return DistributedIteratorSpec(value._input_workers, value._element_spec,
value._strategy, value._strategy)
value._enable_get_next_as_optional)
def _with_tensor_ranks_only(self): def _with_tensor_ranks_only(self):
element_spec = nest.map_structure( element_spec = nest.map_structure(
lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access
self._element_spec) self._element_spec)
return DistributedIteratorSpec(self._input_workers, element_spec, return DistributedIteratorSpec(self._input_workers, element_spec,
self._strategy, self._strategy)
self._enable_get_next_as_optional)
class DistributedIterator(DistributedIteratorBase, class DistributedIterator(DistributedIteratorBase,
composite_tensor.CompositeTensor): composite_tensor.CompositeTensor):
"""Input Iterator for a distributed dataset.""" """Input Iterator for a distributed dataset."""
def __init__(self, def __init__(self, input_workers=None, iterators=None, strategy=None,
input_workers=None, components=None, element_spec=None):
iterators=None,
strategy=None,
components=None,
element_spec=None,
enable_get_next_as_optional=False):
if input_workers is None: if input_workers is None:
raise ValueError("`input_workers` should be " raise ValueError("`input_workers` should be "
"provided.") "provided.")
@@ -855,15 +865,20 @@ class DistributedIterator(DistributedIteratorBase,
self._element_spec = element_spec self._element_spec = element_spec
self._input_workers = input_workers self._input_workers = input_workers
self._iterators = components self._iterators = components
static_shape = _get_static_shape(self._iterators)
self._strategy = strategy self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional if getattr(strategy.extended,
"experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
else: else:
if (components is not None and element_spec is not None): if (components is not None and element_spec is not None):
raise ValueError(error_message) raise ValueError(error_message)
super(DistributedIterator, super(DistributedIterator, self).__init__(input_workers, iterators,
self).__init__(input_workers, iterators, strategy, strategy)
enable_get_next_as_optional)
@property @property
def element_spec(self): def element_spec(self):
@@ -871,9 +886,9 @@ class DistributedIterator(DistributedIteratorBase,
@property @property
def _type_spec(self): def _type_spec(self):
return DistributedIteratorSpec(self._input_workers, self.element_spec, return DistributedIteratorSpec(self._input_workers,
self._strategy, self.element_spec,
self._enable_get_next_as_optional) self._strategy)
class _IterableInput(DistributedDatasetInterface): class _IterableInput(DistributedDatasetInterface):
@@ -986,8 +1001,6 @@ class DistributedDataset(_IterableInput):
self._input_workers = input_workers self._input_workers = input_workers
self._strategy = strategy self._strategy = strategy
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, dataset.element_spec)
self._element_spec = _create_distributed_tensor_spec(self._strategy, self._element_spec = _create_distributed_tensor_spec(self._strategy,
dataset.element_spec) # pylint: disable=protected-access dataset.element_spec) # pylint: disable=protected-access
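
The commit message notes a follow-up: once the dataset decides whether to enable partial batch handling, the distributed element spec should advertise an unknown batch dimension. That later change is not part of this diff; the snippet below is only a hypothetical illustration of what relaxing the leading dimension of a TensorSpec means, and the helper name is ours:

import tensorflow as tf

def relax_batch_dimension(spec):
  # Hypothetical helper, not from the library: drop the static batch size.
  return tf.TensorSpec(shape=[None] + spec.shape.as_list()[1:],
                       dtype=spec.dtype, name=spec.name)

spec = tf.TensorSpec(shape=[32, 224, 224, 3], dtype=tf.float32)
print(relax_batch_dimension(spec))  # TensorSpec(shape=(None, 224, 224, 3), ...)
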
@@ -1006,17 +1019,11 @@ class DistributedDataset(_IterableInput):
self._input_workers, self._input_workers,
enable_legacy_iterators) enable_legacy_iterators)
if enable_legacy_iterators: if enable_legacy_iterators:
iterator = DistributedIteratorV1( iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._input_workers, self._strategy)
worker_iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
else: else:
iterator = DistributedIterator( iterator = DistributedIterator(self._input_workers, worker_iterators,
self._input_workers, self._strategy)
worker_iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator._element_spec = self.element_spec # pylint: disable=protected-access iterator._element_spec = self.element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish # When async eager is enabled, sometimes the iterator may not finish
@@ -1097,8 +1104,7 @@ class DistributedDatasetV1(DistributedDataset):
self._input_workers, self._input_workers,
True) True)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators, iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy, self._strategy)
self._enable_get_next_as_optional)
iterator._element_spec = self.element_spec # pylint: disable=protected-access iterator._element_spec = self.element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish # When async eager is enabled, sometimes the iterator may not finish
@@ -1150,8 +1156,6 @@ class DistributedDatasetsFromFunction(_IterableInput):
_create_datasets_per_worker_with_input_context(self._input_contexts, _create_datasets_per_worker_with_input_context(self._input_contexts,
self._input_workers, self._input_workers,
dataset_fn)) dataset_fn))
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, element_spec)
self._element_spec = _create_distributed_tensor_spec( self._element_spec = _create_distributed_tensor_spec(
self._strategy, element_spec) self._strategy, element_spec)
@@ -1170,17 +1174,11 @@ class DistributedDatasetsFromFunction(_IterableInput):
enable_legacy_iterators) enable_legacy_iterators)
if enable_legacy_iterators: if enable_legacy_iterators:
iterator = DistributedIteratorV1( iterator = DistributedIteratorV1(self._input_workers, iterators,
self._input_workers, self._strategy)
iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
else: else:
iterator = DistributedIterator( iterator = DistributedIterator(self._input_workers, iterators,
self._input_workers, self._strategy)
iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator._element_spec = self._element_spec # pylint: disable=protected-access iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish # When async eager is enabled, sometimes the iterator may not finish
@@ -1227,8 +1225,7 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
iterators = _create_iterators_per_worker(self._datasets, iterators = _create_iterators_per_worker(self._datasets,
self._input_workers, True) self._input_workers, True)
iterator = DistributedIteratorV1(self._input_workers, iterators, iterator = DistributedIteratorV1(self._input_workers, iterators,
self._strategy, self._strategy)
self._enable_get_next_as_optional)
iterator._element_spec = self._element_spec # pylint: disable=protected-access iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish # When async eager is enabled, sometimes the iterator may not finish
@@ -1292,8 +1289,9 @@ class InputFunctionIterator(DistributedIteratorV1):
"input_fn must return a tf.data.Dataset or a callable.") "input_fn must return a tf.data.Dataset or a callable.")
iterators.append(iterator) iterators.append(iterator)
super(InputFunctionIterator, self).__init__( super(InputFunctionIterator, self).__init__(input_workers, iterators,
input_workers, iterators, strategy, enable_get_next_as_optional=False) strategy)
self._enable_get_next_as_optional = False
# TODO(anjalisridhar): This class will soon be removed and users should move # TODO(anjalisridhar): This class will soon be removed and users should move
@@ -1332,9 +1330,10 @@ class DatasetIterator(DistributedIteratorV1):
input_context=input_context) input_context=input_context)
worker_iterators = _create_iterators_per_worker( worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access
super(DatasetIterator, super(DatasetIterator, self).__init__(
self).__init__(input_workers, worker_iterators, strategy, input_workers,
dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access worker_iterators, # pylint: disable=protected-access
strategy)
self._element_spec = dist_dataset.element_spec self._element_spec = dist_dataset.element_spec
@@ -1953,19 +1952,3 @@ def _replace_per_replica_spec(spec, i):
return spec._value_specs[i] # pylint: disable=protected-access return spec._value_specs[i] # pylint: disable=protected-access
else: else:
return spec return spec
def _enable_get_next_as_optional(strategy, element_spec):
"""Returns whether to enable using partial batch handling."""
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
False):
return False
return not _is_statically_shaped(
element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
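
Finally, a hedged end-to-end sketch of how this surfaces to users of tf.distribute (public APIs only, single replica for simplicity, not the internals above): when the last batch is partial and partial batch handling is enabled internally, iteration over the distributed dataset simply ends after the smaller batch instead of requiring drop_remainder=True.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
dataset = tf.data.Dataset.range(8).batch(3)  # last batch only has 2 elements
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
  # experimental_local_results unwraps per-replica values; with one replica
  # this is just the underlying tensor.
  local = strategy.experimental_local_results(batch)[0]
  print(local.shape)  # (3,), (3,), then the partial batch (2,)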