Create a DistributedDataset/DistributedDatasetV1 based on TF2/1.x mode.

PiperOrigin-RevId: 306391431
Change-Id: Ib4585c757e2770e0eb9ce5692af8a837a609602c
Anjali Sridhar 2020-04-14 01:00:16 -07:00 committed by TensorFlower Gardener
parent b2418b7fee
commit 0c44461c7c
3 changed files with 31 additions and 19 deletions
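
For context, a hedged usage sketch (not part of this commit): user code normally reaches get_distributed_dataset indirectly through tf.distribute.Strategy.experimental_distribute_dataset, and after this change the wrapper it hands back is chosen by tf2.enabled() rather than by the type of the incoming dataset.

# Minimal sketch, assuming TF 2.x with a MirroredStrategy; only public
# tf.distribute / tf.data entry points are used, not the internal helpers
# touched by this commit.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# The global batch is split across the strategy's replicas.
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# In TF 2 eager mode the returned distributed dataset is directly iterable.
for per_replica_batch in dist_dataset:
  print(per_replica_batch)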


@@ -22,6 +22,7 @@ import sys
import six
from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
@@ -53,13 +54,16 @@ def get_distributed_dataset(dataset,
strategy,
split_batch_by=None,
input_context=None):
"""Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
"""Returns a distributed dataset from the given tf.data.Dataset instance.
This is a common function that is used by all strategies to return the right
tf.data.Dataset wrapped instance depending on the `dataset` argument type.
This is a common function used by all strategies to return a distributed
dataset. The instance returned differs depending on whether we are in a TF 1
or TF 2 context, and the two instances differ in the APIs they support.
Args:
dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
dataset: a tf.data.Dataset instance.
input_workers: an InputWorkers object which specifies devices on which
iterators should be created.
strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
@@ -72,17 +76,17 @@ def get_distributed_dataset(dataset,
`num_input_pipelines` in the `InputContext`.
Returns:
A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
A distributed dataset instance.
"""
if isinstance(dataset, dataset_ops.DatasetV1):
return DistributedDatasetV1(
if tf2.enabled():
return DistributedDataset(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
else:
return DistributedDataset(
return DistributedDatasetV1(
dataset,
input_workers,
strategy,
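
The docstring's point about differing APIs, in a hedged sketch (not part of the diff; `dist_dataset` stands for the object returned by get_distributed_dataset above, and `tf2` is the module imported at the top of this file):

# Hedged sketch of how the two wrapper types are consumed.
if tf2.enabled():
  # DistributedDataset: supports the Python iterator protocol in TF 2.
  iterator = iter(dist_dataset)
  first_element = next(iterator)
else:
  # DistributedDatasetV1: consumed through TF 1.x-style iterator creation;
  # make_initializable_iterator / make_one_shot_iterator are assumed from the
  # V1 wrapper's API and are not shown in this hunk.
  iterator = dist_dataset.make_initializable_iterator()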
@@ -94,15 +98,16 @@ def get_distributed_datasets_from_function(dataset_fn,
input_workers,
input_contexts,
strategy):
"""Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
"""Returns a distributed dataset from the given input function.
This is a common function that is used by all strategies to return the right
tf.data.Dataset wrapped instance depending on if we are in graph or eager
mode.
This is a common function used by all strategies to return a distributed
dataset. The instance returned differs depending on whether we are in a TF 1
or TF 2 context, and the two instances differ in the APIs they support.
Args:
dataset_fn: a function that returns a tf.data.DatasetV1 or tf.data.DatasetV2
instance.
dataset_fn: a function that returns a tf.data.Dataset instance.
input_workers: an InputWorkers object which specifies devices on which
iterators should be created.
input_contexts: A list of `InputContext` instances to be passed to call(s)
@@ -112,9 +117,9 @@ def get_distributed_datasets_from_function(dataset_fn,
handle last partial batch.
Returns:
A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
A distributed dataset instance.
"""
if ops.executing_eagerly_outside_functions():
if tf2.enabled():
return DistributedDatasetsFromFunction(
dataset_fn,
input_workers,
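
For the dataset_fn path, a hedged usage sketch (not part of the diff, assuming TF 2.x): experimental_distribute_datasets_from_function is the public entry point that ends up calling this helper, and each call of `dataset_fn` receives one `tf.distribute.InputContext`.

# Hedged sketch, assuming TF 2.x; names other than the public tf.distribute
# and tf.data APIs are illustrative.
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8

def dataset_fn(input_context):
  # Derive the per-replica batch size and shard the data per input pipeline.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  d = tf.data.Dataset.range(100).batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)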
@@ -372,7 +377,7 @@ class DistributedIterator(object):
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for tf.data.DatasetV1."""
"""Input Iterator for a distributed dataset instance."""
@deprecated(None, "Use the iterator's `initializer` property instead.")
def initialize(self):
@@ -451,7 +456,7 @@ class _IterableInput(object):
class DistributedDataset(_IterableInput):
"""Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""
"""Distributed dataset that supports prefetching to multiple devices."""
def __init__(self,
dataset,
@@ -555,7 +560,7 @@ class DistributedDataset(_IterableInput):
class DistributedDatasetV1(DistributedDataset):
"""Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""
"""Distributed dataset that supports prefetching to multiple devices."""
def __init__(self,
dataset,


@@ -25,6 +25,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
@@ -157,6 +158,9 @@ class MirroredTwoDeviceDistributionTest(
self.set_v2_tensorshape(original_v2)
def testReplicateDataset(self, distribution):
if tf2.enabled() and not context.executing_eagerly():
self.skipTest("Skipping test since we do not support graph mode in TF 2")
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i, i+1] for i in range(0, 10, 2)]
input_fn = self._input_fn_to_test_input_context(


@@ -17,6 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
@@ -49,6 +50,8 @@ class OneDeviceStrategyTest(
self._test_call_and_merge_exceptions(distribution)
def testReplicateDataset(self, distribution):
if tf2.enabled() and not context.executing_eagerly():
self.skipTest("Skipping test since we do not support graph mode in TF 2")
dataset_fn = lambda: dataset_ops.Dataset.range(10)
expected_values = [[i] for i in range(10)]
input_fn = self._input_fn_to_test_input_context(