Move the logic to decide whether to enable partial batch handling into the dataset

When partial batch handling is enabled, we should change the element spec to
have an unknown batch dimension. This will be done in a later change.

PiperOrigin-RevId: 332398721
Change-Id: I3fe09c1a12fc7b393d94930aa0da4bddb7bb0666
Author: Ran Chen, 2020-09-18 00:19:26 -07:00; committed by TensorFlower Gardener
parent 40a443b85b
commit 6c102af171

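For context on the "partial batch handling" and the unknown batch dimension mentioned in the message: `tf.data` batching without `drop_remainder` can produce a smaller final batch, so the static batch dimension in the element spec is unknown. A minimal standalone sketch (not part of this change):

```python
import tensorflow as tf

# Without drop_remainder the final batch may be smaller, so the static
# batch dimension is unknown (None) in the element spec.
ds = tf.data.Dataset.range(10).batch(4)
print(ds.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int64, name=None)

# With drop_remainder=True every batch has exactly 4 elements, so the
# shape is fully defined and plain get_next() never sees a partial batch.
ds_static = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
print(ds_static.element_spec)  # TensorSpec(shape=(4,), dtype=tf.int64, name=None)
```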
@@ -552,41 +552,67 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
return global_has_value, replicas
def _is_statically_shaped(element_spec):
def _is_statically_shaped(tensor_class, shape):
"""Test if an iterator output is statically shaped.
For sparse and ragged tensors this only tests the batch dimension.
Args:
element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
dataset of the iterator.
tensor_class: a class from an iterator.output_classes list.
shape: a TensorShape from an iterator.output_shapes list.
Returns:
True if the shape is static, false otherwise.
"""
if (tensor_class == sparse_tensor.SparseTensor or
isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
# For sparse or ragged tensor, we should only check the first
# dimension in order to get_next_as_optional. This is because
# when these tensors get batched by dataset only the batch dimension
# is set.
if shape.rank > 0 and shape.as_list()[0] is None:
return False
return True
return shape.is_fully_defined()
for spec in nest.flatten(element_spec):
shape = spec.shape
if isinstance(
spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
# For sparse or ragged tensor, we should only check the first
# dimension in order to get_next_as_optional. This is because
# when these tensors get batched by dataset only the batch dimension
# is set.
if shape.rank > 0 and shape.as_list()[0] is None:
return False
else:
if not shape.is_fully_defined():
return False
return True
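Both variants of `_is_statically_shaped` above encode the same rule: dense components must have fully defined shapes, while sparse and ragged components only need a known batch dimension, because batching by `tf.data` sets only dimension 0 for them. A hedged restatement over public spec types (illustrative inputs, not the TF test suite):

```python
import tensorflow as tf

dense_static = tf.TensorSpec(shape=[32, 224, 224, 3], dtype=tf.float32)
dense_partial = tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)
ragged_batched = tf.RaggedTensorSpec(shape=[32, None], dtype=tf.int64)

for spec in (dense_static, dense_partial, ragged_batched):
    shape = spec.shape
    if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
        # Only the batch dimension matters for sparse/ragged components.
        static = not (shape.rank > 0 and shape.as_list()[0] is None)
    else:
        static = shape.is_fully_defined()
    print(type(spec).__name__, static)
# TensorSpec True / TensorSpec False / RaggedTensorSpec True
```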
def _get_static_shape(iterators):
"""Returns a boolean indicating if the input is fully defined."""
static_shape = True
for iterator in iterators:
if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
_SingleWorkerDatasetIterator)):
continue
flattened = zip(nest.flatten(iterator.output_shapes),
nest.flatten(iterator.output_classes))
for output_shape, output_class in flattened:
if not _is_statically_shaped(output_class, output_shape):
static_shape = False
break
return static_shape
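Unlike the element-spec variant, `_get_static_shape` walks the legacy `output_shapes`/`output_classes` pairs of each per-worker iterator. A sketch of the same pairing on a plain dataset, using the `tf.compat.v1` accessors as assumed stand-ins for the iterator properties:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(3)  # final batch of size 1 is partial
shapes = tf.nest.flatten(tf.compat.v1.data.get_output_shapes(ds))
classes = tf.nest.flatten(tf.compat.v1.data.get_output_classes(ds))
for output_shape, output_class in zip(shapes, classes):
    # Prints: <class '...Tensor'> (None,) False -> shape is not static
    print(output_class, output_shape, output_shape.is_fully_defined())
```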
class DistributedIteratorBase(DistributedIteratorInterface):
"""Common implementation for all input iterators."""
# pylint: disable=super-init-not-called
def __init__(self, input_workers, iterators, strategy,
enable_get_next_as_optional):
def __init__(self, input_workers, iterators, strategy):
static_shape = _get_static_shape(iterators)
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
assert isinstance(input_workers, InputWorkers)
if not input_workers.worker_devices:
raise ValueError("Should have at least one worker for input iterator.")
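The TODO above concerns the cost gap between `get_next()` and `get_next_as_optional()`. What the flag buys is the optional-based consumption pattern, which ends cleanly on a missing or partial final batch instead of raising `OutOfRangeError`. A hedged sketch, assuming the public `get_next_as_optional` method available on distributed iterators in this timeframe:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dist_ds = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(4))  # last batch holds only 2 elements

iterator = iter(dist_ds)
while True:
    optional = iterator.get_next_as_optional()
    if not optional.has_value():
        break  # data exhausted; no OutOfRangeError to catch
    batch = optional.get_value()  # may be a partial (smaller) batch
```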
@@ -594,7 +620,6 @@ class DistributedIteratorBase(DistributedIteratorInterface):
self._iterators = iterators
self._input_workers = input_workers
self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional
def next(self):
return self.__next__()
@@ -728,13 +753,9 @@ class DistributedIteratorV1(DistributedIteratorBase):
class DistributedIteratorSpec(type_spec.TypeSpec):
"""Type specification for `DistributedIterator`."""
__slots__ = [
"_input_workers", "_element_spec", "_strategy",
"_enable_get_next_as_optional"
]
__slots__ = ["_input_workers", "_element_spec", "_strategy"]
def __init__(self, input_workers, element_spec, strategy,
enable_get_next_as_optional):
def __init__(self, input_workers, element_spec, strategy):
# We don't want to allow deserialization of this class because we don't
# serialize the strategy object. Currently the only places where
# _deserialize is called is when we save/restore using SavedModels.
@@ -745,7 +766,6 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
self._input_workers = input_workers
self._element_spec = element_spec
self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional
@property
def value_type(self):
@@ -786,8 +806,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
other._element_spec)
return DistributedIteratorSpec(self._input_workers, element_spec,
self._strategy,
self._enable_get_next_as_optional)
self._strategy)
@property
def _component_specs(self):
@@ -806,41 +825,32 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
return value._iterators # pylint: disable=protected-access
def _from_components(self, components):
return DistributedIterator(
input_workers=self._input_workers,
iterators=None,
components=components,
element_spec=self._element_spec,
strategy=self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
return DistributedIterator(input_workers=self._input_workers,
iterators=None,
components=components,
element_spec=self._element_spec,
strategy=self._strategy)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return DistributedIteratorSpec(value._input_workers, value._element_spec,
value._strategy,
value._enable_get_next_as_optional)
value._strategy)
def _with_tensor_ranks_only(self):
element_spec = nest.map_structure(
lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access
self._element_spec)
return DistributedIteratorSpec(self._input_workers, element_spec,
self._strategy,
self._enable_get_next_as_optional)
self._strategy)
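The reason `DistributedIterator` carries a `TypeSpec` at all is so it can cross a `tf.function` boundary as a composite tensor, with retracing keyed on the spec fields kept in `__slots__` above. A hedged usage sketch, assuming the standard `strategy.run` pattern of this era:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dist_ds = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(8).batch(4, drop_remainder=True))

@tf.function
def train_step(iterator):
    def replica_fn(batch):
        return tf.reduce_sum(batch)
    # The iterator argument is traced via DistributedIteratorSpec, so
    # spec-compatible iterators reuse the same concrete function.
    return strategy.run(replica_fn, args=(next(iterator),))

train_step(iter(dist_ds))
```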
class DistributedIterator(DistributedIteratorBase,
composite_tensor.CompositeTensor):
"""Input Iterator for a distributed dataset."""
def __init__(self,
input_workers=None,
iterators=None,
strategy=None,
components=None,
element_spec=None,
enable_get_next_as_optional=False):
def __init__(self, input_workers=None, iterators=None, strategy=None,
components=None, element_spec=None):
if input_workers is None:
raise ValueError("`input_workers` should be "
"provided.")
@@ -855,15 +865,20 @@ class DistributedIterator(DistributedIteratorBase,
self._element_spec = element_spec
self._input_workers = input_workers
self._iterators = components
static_shape = _get_static_shape(self._iterators)
self._strategy = strategy
self._enable_get_next_as_optional = enable_get_next_as_optional
if getattr(strategy.extended,
"experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = (
not static_shape) or strategy.extended._in_multi_worker_mode()
else:
self._enable_get_next_as_optional = False
else:
if (components is not None and element_spec is not None):
raise ValueError(error_message)
super(DistributedIterator,
self).__init__(input_workers, iterators, strategy,
enable_get_next_as_optional)
super(DistributedIterator, self).__init__(input_workers, iterators,
strategy)
@property
def element_spec(self):
@@ -871,9 +886,9 @@ class DistributedIterator(DistributedIteratorBase,
@property
def _type_spec(self):
return DistributedIteratorSpec(self._input_workers, self.element_spec,
self._strategy,
self._enable_get_next_as_optional)
return DistributedIteratorSpec(self._input_workers,
self.element_spec,
self._strategy)
class _IterableInput(DistributedDatasetInterface):
@@ -986,8 +1001,6 @@ class DistributedDataset(_IterableInput):
self._input_workers = input_workers
self._strategy = strategy
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, dataset.element_spec)
self._element_spec = _create_distributed_tensor_spec(self._strategy,
dataset.element_spec) # pylint: disable=protected-access
@@ -1006,17 +1019,11 @@ class DistributedDataset(_IterableInput):
self._input_workers,
enable_legacy_iterators)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(
self._input_workers,
worker_iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
else:
iterator = DistributedIterator(
self._input_workers,
worker_iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy)
iterator._element_spec = self.element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
@@ -1097,8 +1104,7 @@ class DistributedDatasetV1(DistributedDataset):
self._input_workers,
True)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy,
self._enable_get_next_as_optional)
self._strategy)
iterator._element_spec = self.element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
@@ -1150,8 +1156,6 @@ class DistributedDatasetsFromFunction(_IterableInput):
_create_datasets_per_worker_with_input_context(self._input_contexts,
self._input_workers,
dataset_fn))
self._enable_get_next_as_optional = _enable_get_next_as_optional(
self._strategy, element_spec)
self._element_spec = _create_distributed_tensor_spec(
self._strategy, element_spec)
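`DistributedDatasetsFromFunction` backs the `experimental_distribute_datasets_from_function` entry point, where the user callback receives a `tf.distribute.InputContext` per input pipeline. A hedged sketch of that entry point (`GLOBAL_BATCH_SIZE` is illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
    # Each input pipeline takes its shard and its per-replica batch size.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    ds = tf.data.Dataset.range(1000)
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    return ds.batch(batch_size)

dist_ds = strategy.experimental_distribute_datasets_from_function(dataset_fn)
```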
@@ -1170,17 +1174,11 @@
enable_legacy_iterators)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(
self._input_workers,
iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator = DistributedIteratorV1(self._input_workers, iterators,
self._strategy)
else:
iterator = DistributedIterator(
self._input_workers,
iterators,
self._strategy,
enable_get_next_as_optional=self._enable_get_next_as_optional)
iterator = DistributedIterator(self._input_workers, iterators,
self._strategy)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
@@ -1227,8 +1225,7 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
iterators = _create_iterators_per_worker(self._datasets,
self._input_workers, True)
iterator = DistributedIteratorV1(self._input_workers, iterators,
self._strategy,
self._enable_get_next_as_optional)
self._strategy)
iterator._element_spec = self._element_spec # pylint: disable=protected-access
# When async eager is enabled, sometimes the iterator may not finish
@@ -1292,8 +1289,9 @@ class InputFunctionIterator(DistributedIteratorV1):
"input_fn must return a tf.data.Dataset or a callable.")
iterators.append(iterator)
super(InputFunctionIterator, self).__init__(
input_workers, iterators, strategy, enable_get_next_as_optional=False)
super(InputFunctionIterator, self).__init__(input_workers, iterators,
strategy)
self._enable_get_next_as_optional = False
# TODO(anjalisridhar): This class will soon be removed and users should move
@@ -1332,9 +1330,10 @@ class DatasetIterator(DistributedIteratorV1):
input_context=input_context)
worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access
super(DatasetIterator,
self).__init__(input_workers, worker_iterators, strategy,
dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access
super(DatasetIterator, self).__init__(
input_workers,
worker_iterators, # pylint: disable=protected-access
strategy)
self._element_spec = dist_dataset.element_spec
@@ -1953,19 +1952,3 @@ def _replace_per_replica_spec(spec, i):
return spec._value_specs[i] # pylint: disable=protected-access
else:
return spec
def _enable_get_next_as_optional(strategy, element_spec):
"""Returns whether to enable using partial batch handling."""
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
# get_next_as_optional(). And we only enable get_next_as_optional when the
# output shapes are not static.
#
# TODO(rxsang): We want to always enable the get_next_as_optional behavior
# when user passed input_fn instead of dataset.
if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
False):
return False
return not _is_statically_shaped(
element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
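Taken together, the decision rule on both sides of this change is the same: partial batch handling is enabled only when the strategy opts in, and then only if output shapes are dynamic or the job runs multi-worker. An illustrative standalone restatement (not a TF API; `_in_multi_worker_mode` is the private method used in the diff, `statically_shaped` is an assumed precomputed input):

```python
def should_enable_partial_batch_handling(strategy, statically_shaped):
    """Illustrative restatement of the rule above; not a TF API."""
    # Off by default: get_next_as_optional() is slower than get_next().
    if not getattr(strategy.extended,
                   "experimental_enable_get_next_as_optional", False):
        return False
    # Dynamic shapes need partial batch handling, and so does multi-worker
    # training, where workers may exhaust their data at different steps.
    return (not statically_shaped) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
```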