Move the logic to decide whether to enable partial batch handling into dataset
When partial batch handling is enabled, we should change the element spec to have an unknown batch dimension. This will be done in a later change.

PiperOrigin-RevId: 332398721
Change-Id: I3fe09c1a12fc7b393d94930aa0da4bddb7bb0666
parent 40a443b85b
commit 6c102af171
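For context: "partial batch handling" concerns the last batch of an epoch, which can be smaller than the requested batch size. When that remainder batch is kept, the batch dimension of the dataset's element spec is unknown (None), and consumers must probe for end-of-input with get_next_as_optional() instead of plain get_next(). A minimal sketch of the shape behavior this change keys on, using plain tf.data rather than the distributed wrappers edited below:

    import tensorflow as tf

    # 10 examples batched by 4 yield batches of 4, 4 and 2. Keeping the
    # remainder leaves the batch dimension unknown in the element spec.
    ds = tf.data.Dataset.range(10).batch(4, drop_remainder=False)
    print(ds.element_spec.shape)  # (None,)

    # Dropping the remainder makes every batch exactly 4 elements, so the
    # shape is fully defined and no partial-batch handling is needed.
    ds_static = tf.data.Dataset.range(10).batch(4, drop_remainder=True)
    print(ds_static.element_spec.shape)  # (4,)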
@@ -552,67 +552,41 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
   return global_has_value, replicas
 
 
-def _is_statically_shaped(tensor_class, shape):
+def _is_statically_shaped(element_spec):
   """Test if an iterator output is statically shaped.
 
   For sparse and ragged tensors this only tests the batch dimension.
 
   Args:
-    tensor_class: a class from an iterator.output_classes list.
-    shape: a TensorShape from an iterator.output_shapes list.
+    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
+      dataset of the iterator.
 
   Returns:
     True if the shape is static, false otherwise.
   """
-  if (tensor_class == sparse_tensor.SparseTensor or
-      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
-    # For sparse or ragged tensor, we should only check the first
-    # dimension in order to get_next_as_optional. This is because
-    # when these tensors get batched by dataset only the batch dimension
-    # is set.
-    if shape.rank > 0 and shape.as_list()[0] is None:
-      return False
-    return True
-  return shape.is_fully_defined()
-
-
-def _get_static_shape(iterators):
-  """Returns a boolean indicating if the input is fully defined."""
-  static_shape = True
-  for iterator in iterators:
-    if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
-                                 _SingleWorkerDatasetIterator)):
-      continue
-    flattened = zip(nest.flatten(iterator.output_shapes),
-                    nest.flatten(iterator.output_classes))
-    for output_shape, output_class in flattened:
-      if not _is_statically_shaped(output_class, output_shape):
-        static_shape = False
-        break
-  return static_shape
+
+  for spec in nest.flatten(element_spec):
+    shape = spec.shape
+    if isinstance(
+        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
+      # For sparse or ragged tensor, we should only check the first
+      # dimension in order to get_next_as_optional. This is because
+      # when these tensors get batched by dataset only the batch dimension
+      # is set.
+      if shape.rank > 0 and shape.as_list()[0] is None:
+        return False
+    else:
+      if not shape.is_fully_defined():
+        return False
+  return True
 
 
 class DistributedIteratorBase(DistributedIteratorInterface):
   """Common implementation for all input iterators."""
 
   # pylint: disable=super-init-not-called
-  def __init__(self, input_workers, iterators, strategy):
-    static_shape = _get_static_shape(iterators)
-
-    # TODO(b/133073708): we currently need a flag to control the usage because
-    # there is a performance difference between get_next() and
-    # get_next_as_optional(). And we only enable get_next_as_optional when the
-    # output shapes are not static.
-    #
-    # TODO(rxsang): We want to always enable the get_next_as_optional behavior
-    # when user passed input_fn instead of dataset.
-    if getattr(
-        strategy.extended, "experimental_enable_get_next_as_optional", False):
-      self._enable_get_next_as_optional = (
-          not static_shape) or strategy.extended._in_multi_worker_mode()
-    else:
-      self._enable_get_next_as_optional = False
-
+  def __init__(self, input_workers, iterators, strategy,
+               enable_get_next_as_optional):
     assert isinstance(input_workers, InputWorkers)
     if not input_workers.worker_devices:
       raise ValueError("Should have at least one worker for input iterator.")
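To make the new check concrete, here is a self-contained mirror of _is_statically_shaped that uses only public TF names (tf.nest.flatten and the tf.SparseTensorSpec/tf.RaggedTensorSpec classes stand in for the internal nest, sparse_tensor and ragged_tensor modules); the function name is ours, while the logic follows the added code above:

    import tensorflow as tf

    def is_statically_shaped(element_spec):
      for spec in tf.nest.flatten(element_spec):
        shape = spec.shape
        if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
          # Only the batch dimension is checked: when sparse/ragged values
          # are batched by a dataset, only that dimension is set.
          if shape.rank > 0 and shape.as_list()[0] is None:
            return False
        else:
          if not shape.is_fully_defined():
            return False
      return True

    print(is_statically_shaped(tf.TensorSpec([4, 2], tf.float32)))           # True
    print(is_statically_shaped(tf.TensorSpec([None, 2], tf.float32)))        # False
    print(is_statically_shaped(tf.SparseTensorSpec([4, None], tf.float32)))  # True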
@@ -620,6 +594,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
     self._iterators = iterators
     self._input_workers = input_workers
     self._strategy = strategy
+    self._enable_get_next_as_optional = enable_get_next_as_optional
 
   def next(self):
     return self.__next__()
@@ -753,9 +728,13 @@ class DistributedIteratorV1(DistributedIteratorBase):
 class DistributedIteratorSpec(type_spec.TypeSpec):
   """Type specification for `DistributedIterator`."""
 
-  __slots__ = ["_input_workers", "_element_spec", "_strategy"]
+  __slots__ = [
+      "_input_workers", "_element_spec", "_strategy",
+      "_enable_get_next_as_optional"
+  ]
 
-  def __init__(self, input_workers, element_spec, strategy):
+  def __init__(self, input_workers, element_spec, strategy,
+               enable_get_next_as_optional):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
@@ -766,6 +745,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     self._input_workers = input_workers
     self._element_spec = element_spec
     self._strategy = strategy
+    self._enable_get_next_as_optional = enable_get_next_as_optional
 
   @property
   def value_type(self):
@@ -806,7 +786,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
-                                   self._strategy)
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
   @property
   def _component_specs(self):
@@ -825,32 +806,41 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     return value._iterators  # pylint: disable=protected-access
 
   def _from_components(self, components):
-    return DistributedIterator(input_workers=self._input_workers,
-                               iterators=None,
-                               components=components,
-                               element_spec=self._element_spec,
-                               strategy=self._strategy)
+    return DistributedIterator(
+        input_workers=self._input_workers,
+        iterators=None,
+        components=components,
+        element_spec=self._element_spec,
+        strategy=self._strategy,
+        enable_get_next_as_optional=self._enable_get_next_as_optional)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
-                                   value._strategy)
+                                   value._strategy,
+                                   value._enable_get_next_as_optional)
 
   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
         lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
-                                   self._strategy)
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
 
 class DistributedIterator(DistributedIteratorBase,
                           composite_tensor.CompositeTensor):
   """Input Iterator for a distributed dataset."""
 
-  def __init__(self, input_workers=None, iterators=None, strategy=None,
-               components=None, element_spec=None):
+  def __init__(self,
+               input_workers=None,
+               iterators=None,
+               strategy=None,
+               components=None,
+               element_spec=None,
+               enable_get_next_as_optional=False):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
@@ -865,20 +855,15 @@ class DistributedIterator(DistributedIteratorBase,
       self._element_spec = element_spec
       self._input_workers = input_workers
       self._iterators = components
-      static_shape = _get_static_shape(self._iterators)
       self._strategy = strategy
-      if getattr(strategy.extended,
-                 "experimental_enable_get_next_as_optional", False):
-        self._enable_get_next_as_optional = (
-            not static_shape) or strategy.extended._in_multi_worker_mode()
-      else:
-        self._enable_get_next_as_optional = False
+      self._enable_get_next_as_optional = enable_get_next_as_optional
     else:
       if (components is not None and element_spec is not None):
         raise ValueError(error_message)
 
-      super(DistributedIterator, self).__init__(input_workers, iterators,
-                                                strategy)
+      super(DistributedIterator,
+            self).__init__(input_workers, iterators, strategy,
+                           enable_get_next_as_optional)
 
   @property
   def element_spec(self):
@@ -886,9 +871,9 @@ class DistributedIterator(DistributedIteratorBase,
 
   @property
   def _type_spec(self):
-    return DistributedIteratorSpec(self._input_workers,
-                                   self.element_spec,
-                                   self._strategy)
+    return DistributedIteratorSpec(self._input_workers, self.element_spec,
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
 
 class _IterableInput(DistributedDatasetInterface):
@@ -1001,6 +986,8 @@ class DistributedDataset(_IterableInput):
 
     self._input_workers = input_workers
     self._strategy = strategy
+    self._enable_get_next_as_optional = _enable_get_next_as_optional(
+        self._strategy, dataset.element_spec)
     self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                          dataset.element_spec)  # pylint: disable=protected-access
 
@@ -1019,11 +1006,17 @@ class DistributedDataset(_IterableInput):
                                                  self._input_workers,
                                                  enable_legacy_iterators)
     if enable_legacy_iterators:
-      iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
-                                       self._strategy)
+      iterator = DistributedIteratorV1(
+          self._input_workers,
+          worker_iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     else:
-      iterator = DistributedIterator(self._input_workers, worker_iterators,
-                                     self._strategy)
+      iterator = DistributedIterator(
+          self._input_workers,
+          worker_iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     iterator._element_spec = self.element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
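The flag threaded through these constructors ultimately decides whether the iterator is drained with get_next_as_optional() instead of get_next(). A standalone sketch of that consumption pattern on a plain tf.data iterator (get_next_as_optional() is available on tf.data iterators in TF 2.3+; the distributed iterators wrap the same idea per replica):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10).batch(4)  # final batch is partial: 2 elements
    it = iter(ds)
    while True:
      opt = it.get_next_as_optional()  # returns a tf.experimental.Optional
      if not opt.has_value():
        break                          # end of input, no OutOfRangeError raised
      print(opt.get_value().numpy())   # [0 1 2 3], then [4 5 6 7], then [8 9]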
@@ -1104,7 +1097,8 @@ class DistributedDatasetV1(DistributedDataset):
                                                      self._input_workers,
                                                      True)
     iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
-                                     self._strategy)
+                                     self._strategy,
+                                     self._enable_get_next_as_optional)
     iterator._element_spec = self.element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1156,6 +1150,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
         _create_datasets_per_worker_with_input_context(self._input_contexts,
                                                        self._input_workers,
                                                        dataset_fn))
+    self._enable_get_next_as_optional = _enable_get_next_as_optional(
+        self._strategy, element_spec)
     self._element_spec = _create_distributed_tensor_spec(
         self._strategy, element_spec)
 
@@ -1174,11 +1170,17 @@ class DistributedDatasetsFromFunction(_IterableInput):
                                              enable_legacy_iterators)
 
     if enable_legacy_iterators:
-      iterator = DistributedIteratorV1(self._input_workers, iterators,
-                                       self._strategy)
+      iterator = DistributedIteratorV1(
+          self._input_workers,
+          iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     else:
-      iterator = DistributedIterator(self._input_workers, iterators,
-                                     self._strategy)
+      iterator = DistributedIterator(
+          self._input_workers,
+          iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1225,7 +1227,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
     iterators = _create_iterators_per_worker(self._datasets,
                                              self._input_workers, True)
     iterator = DistributedIteratorV1(self._input_workers, iterators,
-                                     self._strategy)
+                                     self._strategy,
+                                     self._enable_get_next_as_optional)
     iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1289,9 +1292,8 @@ class InputFunctionIterator(DistributedIteratorV1):
           "input_fn must return a tf.data.Dataset or a callable.")
       iterators.append(iterator)
 
-    super(InputFunctionIterator, self).__init__(input_workers, iterators,
-                                                strategy)
-    self._enable_get_next_as_optional = False
+    super(InputFunctionIterator, self).__init__(
+        input_workers, iterators, strategy, enable_get_next_as_optional=False)
 
 
 # TODO(anjalisridhar): This class will soon be removed and users should move
@@ -1330,10 +1332,9 @@ class DatasetIterator(DistributedIteratorV1):
         input_context=input_context)
     worker_iterators = _create_iterators_per_worker(
         dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
-    super(DatasetIterator, self).__init__(
-        input_workers,
-        worker_iterators,  # pylint: disable=protected-access
-        strategy)
+    super(DatasetIterator,
+          self).__init__(input_workers, worker_iterators, strategy,
+                         dist_dataset._enable_get_next_as_optional)  # pylint: disable=protected-access
     self._element_spec = dist_dataset.element_spec
 
 
@@ -1952,3 +1953,19 @@ def _replace_per_replica_spec(spec, i):
     return spec._value_specs[i]  # pylint: disable=protected-access
   else:
     return spec
+
+
+def _enable_get_next_as_optional(strategy, element_spec):
+  """Returns whether to enable using partial batch handling."""
+  # TODO(b/133073708): we currently need a flag to control the usage because
+  # there is a performance difference between get_next() and
+  # get_next_as_optional(). And we only enable get_next_as_optional when the
+  # output shapes are not static.
+  #
+  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
+  # when user passed input_fn instead of dataset.
+  if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
+                 False):
+    return False
+  return not _is_statically_shaped(
+      element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
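Since the real helper reaches into internal strategy attributes, here is a simplified standalone mirror of the decision, for illustration only: the function name is ours, the strategy flag and multi-worker test are passed in as plain booleans, and the sparse/ragged special case from _is_statically_shaped is ignored:

    import tensorflow as tf

    def enable_partial_batch_handling(flag_enabled, multi_worker, element_spec):
      # Off unless the strategy opts in; once it does, partial batch handling
      # is needed when any component shape is dynamic, or when workers may
      # run out of data at different times (multi-worker mode).
      if not flag_enabled:
        return False
      dynamic = any(not spec.shape.is_fully_defined()
                    for spec in tf.nest.flatten(element_spec))
      return dynamic or multi_worker

    spec = tf.data.Dataset.range(10).batch(4).element_spec  # batch dim is None
    print(enable_partial_batch_handling(True, False, spec))   # True
    print(enable_partial_batch_handling(False, False, spec))  # False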