Add CompositeTensor support to distributed iterators.

PiperOrigin-RevId: 308423654
Change-Id: I406d7adeca6059112768f0b8a02a80fbbd463463

parent be3e3961f8
commit dac6d6ae7c
@@ -847,6 +847,38 @@ distribute_py_test(
    srcs = ["input_lib_test.py"],
    main = "input_lib_test.py",
    shard_count = 10,
    tags = [
        "multi_and_single_gpu",
        "no_gpu_presubmit",  # TODO(b/154660040)
    ],
    deps = [
        ":collective_all_reduce_strategy",
        ":combinations",
        ":input_lib",
        ":mirrored_strategy",
        ":multi_worker_test_base",
        ":reduce_util",
        ":strategy_combinations",
        ":values",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:errors",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:sparse_tensor",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:context",
        "//tensorflow/python/eager:test",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

distribute_py_test(
    name = "input_lib_type_spec_test",
    srcs = ["input_lib_type_spec_test.py"],
    main = "input_lib_type_spec_test.py",
    shard_count = 10,
    tags = [
        "multi_and_single_gpu",
    ],
@@ -1453,9 +1485,10 @@ distribute_py_test(
    name = "ctl_correctness_test",
    srcs = ["ctl_correctness_test.py"],
    main = "ctl_correctness_test.py",
    shard_count = 10,
    shard_count = 20,
    tags = [
        "multi_and_single_gpu",
        "no_gpu_presubmit",  # TODO(b/154660040)
        "noguitar",  # b/140755528
    ],
    deps = [

@@ -330,8 +330,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
        communication=self._communication)
    super(CollectiveAllReduceExtended, self)._initialize_single_worker(
        local_devices)
    host_device = device_util.get_host_for_device(self._worker_device)
    self._input_workers = input_lib.InputWorkers(
        [(self._worker_device, self.worker_devices)])
        [(host_device, self.worker_devices)])

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
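Note (illustrative, not part of the diff): the hunk above pins the input worker to the host of `self._worker_device` via `device_util.get_host_for_device`, so the input pipeline runs on the worker's CPU rather than on an accelerator device. A minimal sketch of the assumed mapping:

    from tensorflow.python.distribute import device_util

    worker = "/job:worker/replica:0/task:0/device:GPU:0"
    host = device_util.get_host_for_device(worker)
    # host -> "/job:worker/replica:0/task:0/device:CPU:0" (same job/task, CPU:0)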
@@ -33,6 +33,7 @@ from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
@@ -41,6 +42,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -143,9 +145,10 @@ class InputWorkers(object):
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
    """
    self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in worker_device_pairs)
                              for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
@@ -165,6 +168,12 @@ class InputWorkers(object):
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    return self._worker_device_pairs

  def deserialize(self, worker_device_pairs):
    return InputWorkers(worker_device_pairs)


def _get_next_as_optional(iterator, strategy, name=None):
  """Returns an empty dataset indicator and the next input from the iterator."""
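Note (illustrative, not part of the diff): `serialize()`/`deserialize()` simply round-trip the constructor argument; this is what lets `DistributedIteratorSpec` below carry an `InputWorkers` in serialized form. A small sketch under that assumption:

    workers = InputWorkers([("/device:CPU:0", ("/device:GPU:0", "/device:CPU:0"))])
    pairs = workers.serialize()           # the original worker_device_pairs
    rebuilt = workers.deserialize(pairs)  # an equivalent InputWorkers instance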
@@ -208,7 +217,7 @@ def _get_next_as_optional(iterator, strategy, name=None):


def _is_statically_shaped(tensor_class, shape):
  """Test if an iteratort output is statically shaped.
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension.

@@ -231,20 +240,27 @@ def _is_statically_shaped(tensor_class, shape):
  return shape.is_fully_defined()


class DistributedIterator(object):
def _get_static_shape(iterators):
  """Returns a boolean indicating if the input is fully defined."""
  static_shape = True
  for iterator in iterators:
    if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
                                 _SingleWorkerDatasetIterator)):
      continue
    flattened = zip(nest.flatten(iterator.output_shapes),
                    nest.flatten(iterator.output_classes))
    for output_shape, output_class in flattened:
      if not _is_statically_shaped(output_class, output_shape):
        static_shape = False
        break
    return static_shape


class DistributedIteratorBase(object):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators, strategy):
    static_shape = True
    for iterator in iterators:
      if not isinstance(iterator, _SingleWorkerDatasetIterator):
        continue
      flattened = zip(nest.flatten(iterator.output_shapes),
                      nest.flatten(iterator.output_classes))
      for output_shape, output_class in flattened:
        if not _is_statically_shaped(output_class, output_shape):
          static_shape = False
          break
    static_shape = _get_static_shape(iterators)

    # TODO(b/133073708): we currently need a flag to control the usage because
    # there is a performance difference between get_next() and
@@ -360,6 +376,10 @@ class DistributedIterator(object):

    return values.regroup(replicas)


class DistributedIteratorV1(DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
@@ -370,23 +390,14 @@ class DistributedIterator(object):
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec


class DistributedIteratorV1(DistributedIterator):
  """Input Iterator for a distributed dataset instance."""

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialze underlying iterators.
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return super(DistributedIteratorV1, self)._initializer
    return self._initializer

  @property
  def initializer(self):
@@ -415,6 +426,161 @@ class DistributedIteratorV1(DistributedIterator):
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec


class DistributedIteratorSpec(type_spec.TypeSpec):
  """Type specification for `DistributedIterator`."""

  __slots__ = ["_input_workers", "_element_spec", "_strategy"]

  def __init__(self, input_workers, element_spec, strategy):
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only places where
    # _deserialize is called is when we save/restore using SavedModels.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    else:
      self._input_workers = input_workers
      self._element_spec = element_spec
      self._strategy = strategy

  @property
  def value_type(self):
    return DistributedIterator

  def _serialize(self):
    # We cannot serialize the strategy object so we convert it to an id that we
    # can use for comparison.
    return (self._input_workers.serialize(),
            self._element_spec, id(self._strategy))

  def _deserialize(self):
    raise ValueError("Deserialization is currently unsupported for "
                     "DistributedIteratorSpec.")

  @staticmethod
  def _is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if type(a) is not type(b):
      return False
    if isinstance(a, tuple):
      return (len(a) == len(b) and
              all(DistributedIteratorSpec._is_compatible(x, y) for (x, y) in
                  zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          DistributedIteratorSpec._is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (type_spec.TypeSpec, tensor_shape.TensorShape,
                      dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  # Overriding this method so that we can merge and reconstruct the spec object
  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if not self._is_compatible(self._input_workers.serialize(),
                               other._input_workers.serialize()):
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._element_spec != other._element_spec:
      raise ValueError("_element_spec is not compatible with both %s "
                       "and %s" % (self, other))
    if id(self._strategy) != id(other._strategy):
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))
    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy)

  @property
  def _component_specs(self):
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
    for i in range(len(worker_device_pairs)):
      input_device, compute_devices = worker_device_pairs[i]
      specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
                                                    compute_devices,
                                                    element_spec=
                                                    self._element_spec))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(input_workers=self._input_workers,
                               iterators=None,
                               components=components,
                               element_spec=self._element_spec,
                               strategy=self._strategy)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                   value._strategy)


class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self, input_workers=None, iterators=None, strategy=None,
               components=None, element_spec=None):
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    error_message = ("Either `input_workers` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      static_shape = _get_static_shape(self._iterators)
      self._strategy = strategy
      if getattr(
          strategy.extended, "experimental_enable_get_next_as_optional", False):
        self._enable_get_next_as_optional = not static_shape
      else:
        self._enable_get_next_as_optional = False
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator, self).__init__(input_workers, iterators,
                                                strategy)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return DistributedIteratorSpec(self._input_workers,
                                   self.element_spec,
                                   self._strategy)


class _IterableInput(object):
  """Base class for iterable inputs for distribution strategies."""
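Note (illustrative, not part of the diff): registering `DistributedIteratorSpec` as the `_type_spec` of the `CompositeTensor`-based `DistributedIterator` is what lets an iterator be passed into a `tf.function` as an argument and be decomposed into its per-worker iterator components during tracing. A hedged usage sketch; the strategy, dataset, and step function here are placeholders:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dist_dataset = strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(8).batch(2))

    @tf.function
    def train_step(iterator):
      # The iterator argument is flattened into its single-worker iterator
      # components via DistributedIteratorSpec and rebuilt inside the function.
      return strategy.run(lambda x: x + 1, args=(next(iterator),))

    train_step(iter(dist_dataset))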
@@ -482,7 +648,6 @@ class DistributedDataset(_IterableInput):
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
@@ -541,10 +706,20 @@ class DistributedDataset(_IterableInput):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # This is an optional flag that can be used to turn off using
    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
    # as a stop gap solution that will allow us to roll out this change.
    enable_legacy_iterators = getattr(self._strategy,
                                      "_enable_legacy_iterators", False)
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIterator(self._input_workers, worker_iterators,
                                   self._strategy)
                                                    self._input_workers,
                                                    enable_legacy_iterators)
    if enable_legacy_iterators:
      iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                       self._strategy)
    else:
      iterator = DistributedIterator(self._input_workers, worker_iterators,
                                     self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator
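Note (illustrative, not part of the diff): `_enable_legacy_iterators` is a private escape hatch read off the strategy; when it is set, `__iter__` falls back to `DistributedIteratorV1` backed by `MultiDeviceIterator` instead of the new owned iterators. A sketch mirroring the test added later in this change:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dist_dataset = strategy.experimental_distribute_dataset(
        tf.data.Dataset.range(4).batch(2))

    strategy._enable_legacy_iterators = True  # private flag, rollout stop gap only
    iterator = iter(dist_dataset)
    # iterator is a DistributedIteratorV1 rather than the CompositeTensor-based
    # DistributedIterator.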
@@ -615,12 +790,21 @@ class DistributedDatasetV1(DistributedDataset):

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
                                                    self._input_workers,
                                                    True)
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")


# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
@@ -653,20 +837,36 @@ class DistributedDatasetsFromFunction(_IterableInput):
    self._strategy = strategy
    self._element_spec = None

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
    iterator = DistributedIterator(self._input_workers, iterators,
                                   self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         element_spec)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator
  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      # This is an optional flag that can be used to turn off using
      # OwnedMultiDeviceIterators and instead use the legacy
      # MultiDeviceIterators as a stop gap solution that will allow us to roll
      # out this change.
      enable_legacy_iterators = getattr(self._strategy,
                                        "_enable_legacy_iterators", False)

      iterators, element_spec = _create_iterators_per_worker_with_input_context(
          self._input_contexts, self._input_workers, self._dataset_fn,
          enable_legacy_iterators)

      if enable_legacy_iterators:
        iterator = DistributedIteratorV1(self._input_workers, iterators,
                                         self._strategy)
      else:
        iterator = DistributedIterator(self._input_workers, iterators,
                                       self._strategy)
      self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                           element_spec)
      iterator._element_spec = self._element_spec  # pylint: disable=protected-access
      return iterator

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  @property
  def element_spec(self):
@@ -705,7 +905,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):

  def _get_iterator(self):
    iterators, element_spec = _create_iterators_per_worker_with_input_context(
        self._input_contexts, self._input_workers, self._dataset_fn)
        self._input_contexts, self._input_workers, self._dataset_fn,
        True)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy)
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
@@ -713,6 +914,14 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")


# TODO(anjalisridhar): This class will be soon be removed in favor of newer
# APIs.
@@ -797,7 +1006,7 @@ class DatasetIterator(DistributedIteratorV1):
        split_batch_by=split_batch_by,
        input_context=input_context)
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)  # pylint: disable=protected-access
        dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
    super(DatasetIterator, self).__init__(
        input_workers,
        worker_iterators,  # pylint: disable=protected-access
@@ -808,18 +1017,18 @@ class DatasetIterator(DistributedIteratorV1):
def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`."""

  def create_dummy_tensor(type_spec):
  def create_dummy_tensor(spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = type_spec._shape[:1].concatenate(
          type_spec._shape[(1 + type_spec._ragged_rank):])
      feature_type = type_spec._dtype
      feature_shape = spec._shape[:1].concatenate(
          spec._shape[(1 + spec._ragged_rank):])
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = type_spec.shape
      feature_type = type_spec.dtype
      feature_shape = spec.shape
      feature_type = spec.dtype
    # Ideally we should set the batch dimension to 0, however as in
    # DistributionStrategy we don't know the batch dimension, we try to
    # guess it as much as possible. If the feature has unknown dimensions, we
@@ -827,11 +1036,11 @@ def _dummy_tensor_fn(value_structure):
    # first dimension as batch dimension and set it to 0.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or
    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(type_spec, sparse_tensor.SparseTensorSpec):
    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
@@ -839,26 +1048,26 @@ def _dummy_tensor_fn(value_structure):

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, type_spec._row_splits_dtype)
      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False)
          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)


class _SingleWorkerDatasetIterator(object):
class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`."""

  def __init__(self, dataset, worker, devices):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    `MultiDeviceIterator` is used to prefetch input to the devices on the
    given worker.
    A `MultiDeviceIterator`  or `OwnedMultiDeviceIterator` is used to prefetch
    input to the devices on the given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
@@ -868,13 +1077,11 @@ class _SingleWorkerDatasetIterator(object):
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._make_iterator()

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)
    raise NotImplementedError("must be implemented in descendants")

  def get_next(self, device, name=None):
    """Get next element for the given device."""
@@ -923,9 +1130,9 @@ class _SingleWorkerDatasetIterator(object):
        # Place the condition op in the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # As MultiDeviceIterator will fetch data in order, so we only need to
          # check if the first replica has value to see whether there is data
          # left for this single worker.
          # Data will be fetched in order, so we only need to check if the first
          # replica has value to see whether there is data left for this single
          # worker.
          if i == 0:
            worker_has_value = data.has_value()
@@ -943,8 +1150,159 @@ class _SingleWorkerDatasetIterator(object):

      return worker_has_value, result


class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = ["_worker", "_devices", "_element_spec"]

  def __init__(self, worker, devices, element_spec):
    self._worker = worker
    self._devices = devices
    self._element_spec = element_spec

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    return (self._worker, tuple(self._devices), self._element_spec)

  @property
  def _component_specs(self):
    specs = []
    specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
        self._devices, self._worker, element_spec=self._element_spec))
    return specs

  def _to_components(self, value):
    return [value._iterator]  # pylint: disable=protected-access

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
                                            value._element_spec)


class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self, dataset=None, worker=None, devices=None, components=None,
               element_spec=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices` .

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    python object. Once we go out of scope of the python object or return from
    a tf.function the underlying iterator resource is deleted.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents the
      type specification of elements of the iterator.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
                                                              devices)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    host_device = device_util.get_host_for_device(self._worker)
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          self._dataset, self._devices, source_device=host_device)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec)
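Note (illustrative, not part of the diff): `_SingleWorkerDatasetIteratorSpec` treats the underlying `OwnedMultiDeviceIterator` as the only component, so tracing can take an owned iterator apart and rebuild it. A rough sketch of that round trip using the private API introduced above:

    spec = owned_iterator._type_spec                   # _SingleWorkerDatasetIteratorSpec
    components = spec._to_components(owned_iterator)   # [OwnedMultiDeviceIterator]
    rebuilt = spec._from_components(components)        # new owned iterator, same spec
    assert rebuilt.element_spec == owned_iterator.element_spec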
  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)


class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def initialize(self):
    """Initialze underlying iterator.
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
@@ -1005,7 +1363,8 @@ class _SingleWorkerCallableIterator(object):
    return []


def _create_iterators_per_worker(worker_datasets, input_workers):
def _create_iterators_per_worker(worker_datasets, input_workers,
                                 enable_legacy_iterators):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)

@@ -1014,23 +1373,35 @@ def _create_iterators_per_worker(worker_datasets, input_workers):
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                              worker_devices)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker,
                                                     worker_devices)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                                worker_devices)
      iterators.append(iterator)
  return iterators


def _create_iterators_per_worker_with_input_context(input_contexts,
                                                    input_workers,
                                                    dataset_fn):
                                                    dataset_fn,
                                                    enable_legacy_iterators):
  """Create a multidevice iterator per workers given a dataset function."""
  iterators = []
  element_specs = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      element_specs.append(dataset.element_spec)
      devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(dataset, worker,
                                                     devices)
      else:
        iterator = _SingleWorkerDatasetIterator(dataset, worker,
                                                devices)
      iterators.append(iterator)
  return iterators, dataset.element_spec
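Note (illustrative, not part of the diff): both factory helpers now pick the per-worker iterator class in one place, gated on `tf2.enabled()` and the legacy flag. Condensed from the hunks above:

    if tf2.enabled() and not enable_legacy_iterators:
      iterator_cls = _SingleWorkerOwnedDatasetIterator  # CompositeTensor, owned lifetime
    else:
      iterator_cls = _SingleWorkerDatasetIterator       # legacy MultiDeviceIterator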
@@ -45,6 +45,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -105,20 +106,21 @@ class DistributedIteratorTestBase(test.TestCase):
                    split_batch_by,
                    strategy,
                    input_context=None):
    if isinstance(dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter)):
      return input_lib.DistributedDatasetV1(
          dataset,
          input_workers,
          strategy,
          split_batch_by=split_batch_by,
          input_context=input_context)
    elif input_type == "dataset":
      return input_lib.DistributedDataset(
          dataset,
          input_workers,
          strategy,
          split_batch_by=split_batch_by,
          input_context=input_context)
    if input_type == "dataset":
      if tf2.enabled():
        return input_lib.DistributedDataset(
            dataset,
            input_workers,
            strategy,
            split_batch_by=split_batch_by,
            input_context=input_context)
      else:
        return input_lib.DistributedDatasetV1(
            dataset,
            input_workers,
            strategy,
            split_batch_by=split_batch_by,
            input_context=input_context)
    else:
      return strategy.experimental_distribute_datasets_from_function(dataset)

@@ -139,6 +141,9 @@ class DistributedIteratorTestBase(test.TestCase):
    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    input_workers = input_lib.InputWorkers(worker_device_pairs)
@@ -161,7 +166,7 @@ class DistributedIteratorTestBase(test.TestCase):
          strategy,
          input_context=input_context)

      if context.executing_eagerly():
      if ops.executing_eagerly_outside_functions():
        iterator = iter(dataset)
      else:
        if isinstance(dataset, input_lib.DistributedDatasetV1):
@@ -171,7 +176,7 @@ class DistributedIteratorTestBase(test.TestCase):

    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if isinstance(iterator, input_lib.DistributedIteratorV1):
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      for expected_value in expected_values:
@@ -190,10 +195,13 @@ class DistributedIteratorTestBase(test.TestCase):
                                   next_element) for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if isinstance(iterator, input_lib.DistributedIteratorV1):
      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))
        if api_type == "wrap_into_iterator":
          self.skipTest("unsupported test combination")
        else:
          iterator = iter(dataset)

      for expected_value in expected_values:
        next_element = iterator.get_next()
@@ -225,6 +233,48 @@ class DistributedIteratorTestBase(test.TestCase):
class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
                                          parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["input_fn", "dataset"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ]))
  def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
    if not tf2.enabled():
      self.skipTest("unsupported test combination")

    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    dataset_or_input_fn = self._create_dataset_or_input_fn(
        input_type, dataset_fn)

    input_workers = input_lib.InputWorkers(worker_device_pairs)
    if input_type == "dataset":
      dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
                                                       input_workers,
                                                       distribution)
    else:
      dist_dataset = input_lib.get_distributed_datasets_from_function(
          dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
          distribution)

    # Default Iterator types in TF2.
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIterator)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerOwnedDatasetIterator)

    # Disable creating owned iterators by setting a property on the strategy.
    distribution._enable_legacy_iterators = True
    iterator = iter(dist_dataset)
    self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
    self.assertIsInstance(iterator._iterators[0],
                          input_lib._SingleWorkerDatasetIterator)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
@@ -234,7 +284,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
  def testMultiDeviceIterInitialize(self, distribution):
    if tf2.enabled():
      self.skipTest("Only V1 is supported.")
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)

    input_workers = input_lib.InputWorkers(worker_device_pairs)
@@ -250,25 +301,6 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,

    init_func_for_iter()

  @combinations.generate(
      combinations.combine(
          mode=["graph"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ]))
  def testDatasetV2IterError(self, distribution):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)
    dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)

    dist_dataset = input_lib.get_distributed_dataset(
        dataset_fn(distribute_lib.InputContext()), input_workers, distribution)

    with self.assertRaisesRegexp(RuntimeError,
                                 "or when eager execution is enabled"):
      iter(dist_dataset)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
@ -282,11 +314,11 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
 | 
			
		||||
          enable_get_next_as_optional=[True, False]))
 | 
			
		||||
  def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
 | 
			
		||||
                       enable_get_next_as_optional):
 | 
			
		||||
    worker_device_pairs = [("", ["/device:CPU:0"])]
 | 
			
		||||
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
 | 
			
		||||
    if tf2.enabled():
 | 
			
		||||
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
 | 
			
		||||
    else:
 | 
			
		||||
      dataset_fn = lambda _: dataset_ops.Dataset.range(10)
 | 
			
		||||
      dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
 | 
			
		||||
    dataset_or_input_fn = self._create_dataset_or_input_fn(
 | 
			
		||||
        input_type, dataset_fn)
 | 
			
		||||
 | 
			
		||||
@ -316,7 +348,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
          enable_get_next_as_optional=[True, False]))
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
                                 distribution, enable_get_next_as_optional):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
@ -386,7 +419,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
          enable_get_next_as_optional=[True, False]))
  def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
                       enable_get_next_as_optional):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
@ -422,7 +456,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
              strategy_combinations.mirrored_strategy_with_one_cpu
          ]))
  def testIterableIterator(self, distribution):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
    input_workers = input_lib.InputWorkers(worker_device_pairs)

    dataset = dataset_ops.DatasetV2.range(10)
@ -446,7 +480,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
          ]))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
                               drop_remainder, distribution):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(  # pylint: disable=g-long-lambda
          2, drop_remainder=drop_remainder)
@ -486,7 +521,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         split_batch_by, distribution,
                         enable_get_next_as_optional):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
                                              "/device:CPU:0"])]
    batch_size = 10
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
@ -1075,68 +1111,5 @@ class DistributedIteratorMultiWorkerTest(
          strategy,
          sess=sess)


class InputTypeSpecTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          input_type=["dataset", "dataset_fn"],
      ))
  def testInputSignatureForPerReplicaValues(self, distribution, input_type):
    def dataset_fn(ctx):
      del ctx  # unused
      return dataset_ops.DatasetV2.from_tensor_slices(
          np.ones([10, 12]).astype(np.float32)).batch(4)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      type_spec = ds.element_spec
    else:
      ds = distribution.experimental_distribute_datasets_from_function(
          dataset_fn)
      iterator = iter(ds)
      type_spec = iterator.element_spec

    @def_function.function(input_signature=[type_spec])
    def process_inputs(inputs):
      distribution.run(lambda inputs: inputs, args=(inputs,))

    for x in ds:
      process_inputs(x)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
      ))
  def testInputSignatureForNestedPerReplicaValues(self, distribution):
    a = np.ones((10, 2)) * 5
    b = np.ones((10, 3)) * 6
    dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function(input_signature=[dist_dataset.element_spec])
    def process_inputs(inputs):
      distribution.run(lambda inputs: inputs, args=(inputs,))

    for x in dist_dataset:
      process_inputs(x)

if __name__ == "__main__":
  test.main()
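
The InputTypeSpecTest methods above are removed here and reappear unchanged in the new input_lib_type_spec_test.py listed next. The pattern they exercise, using a distributed dataset's element_spec as a tf.function input signature, can be sketched with the public API roughly as follows (a minimal sketch assuming TF 2.x and a CPU-only MirroredStrategy; it is not code from this change):

# Minimal sketch (assumptions: TF 2.x public API, CPU-only MirroredStrategy).
import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
dataset = tf.data.Dataset.from_tensor_slices(
    np.ones([10, 12], np.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# The distributed dataset's element_spec can serve as an input signature,
# so per-replica values can be passed straight into a tf.function.
@tf.function(input_signature=[dist_dataset.element_spec])
def process_inputs(inputs):
  return strategy.run(lambda x: x, args=(inputs,))

for x in dist_dataset:
  process_inputs(x)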

366	tensorflow/python/distribute/input_lib_type_spec_test.py	Normal file
@ -0,0 +1,366 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for iterator type specs in the input_lib library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from absl.testing import parameterized
import numpy as np

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib


class DistributedIteratorTest(test.TestCase,
                              parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpec(self, input_type, distribution,
                   enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator has CompositeTensor support in "
                    "TF 2 only.")
    dataset = dataset_ops.DatasetV2.range(10).batch(2)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      iterator = iter(dist_dataset)

    spec = iterator._type_spec
    self.assertEqual(spec._input_workers, iterator._input_workers)
    self.assertEqual(spec._element_spec._value_specs,
                     (tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
                                             name=None),
                      tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
                                             name=None)))

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpecRoundTrip(self, input_type,
                            distribution, enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0.")

    dataset = dataset_ops.DatasetV2.range(10).batch(2)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      iterator = iter(dist_dataset)

    spec = iterator._type_spec

    tensor_list = spec._to_components(iterator)
    re_iterator = spec._from_components(tensor_list)

    self.assertEqual(iterator._input_workers, re_iterator._input_workers)
    self.assertAllEqual(iterator._iterators, re_iterator._iterators)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
                                        enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0.")

    trace_count = [0]

    @def_function.function
    def f(iterator):
      trace_count[0] += 1
      counter = np.int64(0)
      for _ in range(5):
        next(iterator)
        counter += 1
      return counter

    dataset = dataset_ops.DatasetV2.range(10).batch(2)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      for _ in range(3):
        iterator = iter(dist_dataset)
        counter = f(iterator)

        self.assertEqual(trace_count[0], 1)
        self.assertEqual(counter, 5)


class InputTypeSpecTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          input_type=["dataset", "dataset_fn"],
      ))
  def testInputSignatureForPerReplicaValues(self, distribution, input_type):
    def dataset_fn(ctx):
      del ctx  # unused
      return dataset_ops.DatasetV2.from_tensor_slices(
          np.ones([10, 12]).astype(np.float32)).batch(4)

    if input_type == "dataset":
      ds = distribution.experimental_distribute_dataset(
          dataset_fn(distribute_lib.InputContext()))
      type_spec = ds.element_spec
    else:
      ds = distribution.experimental_distribute_datasets_from_function(
          dataset_fn)
      iterator = iter(ds)
      type_spec = iterator.element_spec

    @def_function.function(input_signature=[type_spec])
    def process_inputs(inputs):
      distribution.run(lambda inputs: inputs, args=(inputs,))

    for x in ds:
      process_inputs(x)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
      ))
  def testInputSignatureForNestedPerReplicaValues(self, distribution):
    a = np.ones((10, 2)) * 5
    b = np.ones((10, 3)) * 6
    dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)

    @def_function.function(input_signature=[dist_dataset.element_spec])
    def process_inputs(inputs):
      distribution.run(lambda inputs: inputs, args=(inputs,))

    for x in dist_dataset:
      process_inputs(x)


class RaggedTensorDistributedIteratorTest(test.TestCase,
                                          parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpec(self, distribution, enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator has CompositeTensor support in "
                    "TF 2.0 only.")
    ctx = distribute_lib.InputContext()
    batch_size = ctx.get_per_replica_batch_size(8)
    # Use 20 which isn't divisible by 8 to test partial batch behavior.
    row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
    ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
        np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "dense": ragged_tensor.to_tensor(),
        "ragged": ragged_tensor,
        "sparse": ragged_tensor.to_sparse(),
    })
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    dataset = dataset.batch(batch_size)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      iterator = iter(dist_dataset)

    spec = iterator._type_spec
    self.assertEqual(spec._input_workers, iterator._input_workers)
    self.assertEqual(
        spec._element_spec, {
            "sparse":
                values.PerReplicaSpec(
                    sparse_tensor.SparseTensorSpec(
                        tensor_shape.TensorShape([None, 3]), dtypes.float32),
                    sparse_tensor.SparseTensorSpec(
                        tensor_shape.TensorShape([None, 3]), dtypes.float32)),
            "dense":
                values.PerReplicaSpec(
                    tensor_spec.TensorSpec(
                        shape=(None, 3), dtype=dtypes.float32, name=None),
                    tensor_spec.TensorSpec(
                        shape=(None, 3), dtype=dtypes.float32, name=None)),
            "ragged":
                values.PerReplicaSpec(
                    ragged_tensor_lib.RaggedTensorSpec(
                        tensor_shape.TensorShape([None, None]), dtypes.float32,
                        1, dtypes.int64),
                    ragged_tensor_lib.RaggedTensorSpec(
                        tensor_shape.TensorShape([None, None]), dtypes.float32,
                        1, dtypes.int64))
        })

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0.")

    ctx = distribute_lib.InputContext()
    batch_size = ctx.get_per_replica_batch_size(8)
    # Use 20 which isn't divisible by 8 to test partial batch behavior.
    row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
    ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
        np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "dense": ragged_tensor.to_tensor(),
        "ragged": ragged_tensor,
        "sparse": ragged_tensor.to_sparse(),
    })
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    dataset = dataset.batch(batch_size)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      iterator = iter(dist_dataset)

    spec = iterator._type_spec

    tensor_list = spec._to_components(iterator)
    re_iterator = spec._from_components(tensor_list)

    self.assertEqual(iterator._input_workers, re_iterator._input_workers)
    self.assertAllEqual(iterator._iterators, re_iterator._iterators)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testDoesNotTriggerFunctionTracing(self, distribution,
                                        enable_get_next_as_optional):
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0.")

    trace_count = [0]

    @def_function.function
    def f(iterator):
      trace_count[0] += 1
      counter = np.int64(0)
      for _ in range(5):
        next(iterator)
        counter += 1
      return counter

    ctx = distribute_lib.InputContext()
    batch_size = ctx.get_per_replica_batch_size(8)
    # Use 50 which isn't divisible by 8 to test partial batch behavior.
    row_lengths = np.mod(np.arange(50), 4).astype(np.int64)
    ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
        np.repeat(np.arange(50, dtype=np.float32), row_lengths), row_lengths)
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "dense": ragged_tensor.to_tensor(),
        "ragged": ragged_tensor,
        "sparse": ragged_tensor.to_sparse(),
    })
    dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    dataset = dataset.batch(batch_size)

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    with distribution.scope():
      for _ in range(3):
        iterator = iter(dist_dataset)
        counter = f(iterator)

        self.assertEqual(trace_count[0], 1)
        self.assertEqual(counter, 5)


if __name__ == "__main__":
  test.main()
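
Taken together, the new tests check that a DistributedIterator exposes a type spec, survives a to/from-components round trip, and can be fed to a tf.function repeatedly without retracing. In public-API terms the retracing guarantee looks roughly like this (a sketch under the assumption of TF 2.x with this change and a CPU-only MirroredStrategy; not code from the commit):

# Minimal sketch (assumptions: TF 2.x with this change, CPU-only MirroredStrategy).
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(2))

@tf.function
def train_step(iterator):
  # Consumes one distributed batch per call.
  return next(iterator)

iterator = iter(dist_dataset)
for _ in range(3):
  train_step(iterator)  # traced once; later calls reuse the same concrete function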
@ -1161,12 +1161,7 @@ class DataHandler(object):
        if self._insufficient_data:  # Set by `catch_stop_iteration`.
          break
        if self._adapter.should_recreate_iterator():
          if ds_context.has_strategy():
            # TODO(b/138326910): remove this when MultiDeviceIterator is a
            # CompositeTensor (unless this is more efficient)
            data_iterator._initializer  # pylint: disable=pointless-statement, protected-access
          else:
            data_iterator = iter(self._dataset)
          data_iterator = iter(self._dataset)
        yield epoch, data_iterator
        self._adapter.on_epoch_end()

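
Because distributed iterators can now be recreated with a plain iter() call, the DataHandler hunk above drops the strategy-specific branch that poked the MultiDeviceIterator initializer. The per-epoch recreation it relies on reduces to roughly the following (a sketch assuming a CPU-only MirroredStrategy through the public API; the real DataHandler wraps this in its adapter logic):

# Minimal sketch (assumptions: TF 2.x public API, CPU-only MirroredStrategy).
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(6).batch(2))

for epoch in range(2):
  iterator = iter(dist_dataset)  # a fresh iterator per epoch, no strategy branch
  for _ in range(3):             # range(6).batch(2) yields three batches
    batch = next(iterator)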