From 0c44461c7c9fd3f1960dc008699c8996f8e0d149 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Tue, 14 Apr 2020 01:00:16 -0700 Subject: [PATCH] Create a DistributedDataset/DistributedDatasetV1 based on TF2/1.x mode. PiperOrigin-RevId: 306391431 Change-Id: Ib4585c757e2770e0eb9ce5692af8a837a609602c --- tensorflow/python/distribute/input_lib.py | 43 +++++++++++-------- .../distribute/mirrored_strategy_test.py | 4 ++ .../distribute/one_device_strategy_test.py | 3 ++ 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 163f775cc93..6cf6bd0db26 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -22,6 +22,7 @@ import sys import six +from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.ops import dataset_ops @@ -53,13 +54,16 @@ def get_distributed_dataset(dataset, strategy, split_batch_by=None, input_context=None): - """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + """Returns a distributed dataset from the given tf.data.Dataset instance. - This is a common function that is used by all strategies to return the right - tf.data.Dataset wrapped instance depending on the `dataset` argument type. + This is a common function that is used by all strategies to return a + distributed dataset. The distributed dataset instance returned is different + depending on if we are in a TF 1 or TF 2 context. The distributed dataset + instances returned differ from each other in the APIs supported by each of + them. Args: - dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance. + dataset: a tf.data.Dataset instance. input_workers: an InputWorkers object which specifies devices on which iterators should be created. strategy: a `tf.distribute.Strategy` object, used to run all-reduce to @@ -72,17 +76,17 @@ def get_distributed_dataset(dataset, `num_input_pipelines` in the `InputContext`. Returns: - A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + A distributed dataset instance. """ - if isinstance(dataset, dataset_ops.DatasetV1): - return DistributedDatasetV1( + if tf2.enabled(): + return DistributedDataset( dataset, input_workers, strategy, split_batch_by=split_batch_by, input_context=input_context) else: - return DistributedDataset( + return DistributedDatasetV1( dataset, input_workers, strategy, @@ -94,15 +98,16 @@ def get_distributed_datasets_from_function(dataset_fn, input_workers, input_contexts, strategy): - """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + """Returns a distributed dataset from the given input function. - This is a common function that is used by all strategies to return the right - tf.data.Dataset wrapped instance depending on if we are in graph or eager - mode. + This is a common function that is used by all strategies to return a + distributed dataset. The distributed dataset instance returned is different + depending on if we are in a TF 1 or TF 2 context. The distributed dataset + instances returned differ from each other in the APIs supported by each of + them. Args: - dataset_fn: a function that returns a tf.data.DatasetV1 or tf.data.DatasetV2 - instance. + dataset_fn: a function that returns a tf.data.Dataset instance. input_workers: an InputWorkers object which specifies devices on which iterators should be created. input_contexts: A list of `InputContext` instances to be passed to call(s) @@ -112,9 +117,9 @@ def get_distributed_datasets_from_function(dataset_fn, handle last partial batch. Returns: - A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + A distributed dataset instance. """ - if ops.executing_eagerly_outside_functions(): + if tf2.enabled(): return DistributedDatasetsFromFunction( dataset_fn, input_workers, @@ -372,7 +377,7 @@ class DistributedIterator(object): class DistributedIteratorV1(DistributedIterator): - """Input Iterator for tf.data.DatasetV1.""" + """Input Iterator for a distributed dataset instance.""" @deprecated(None, "Use the iterator's `initializer` property instead.") def initialize(self): @@ -451,7 +456,7 @@ class _IterableInput(object): class DistributedDataset(_IterableInput): - """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices.""" + """Distributed dataset that supports prefetching to multiple devices.""" def __init__(self, dataset, @@ -555,7 +560,7 @@ class DistributedDataset(_IterableInput): class DistributedDatasetV1(DistributedDataset): - """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices.""" + """Distributed dataset that supports prefetching to multiple devices.""" def __init__(self, dataset, diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py index 2eb2191ad48..023addd928c 100644 --- a/tensorflow/python/distribute/mirrored_strategy_test.py +++ b/tensorflow/python/distribute/mirrored_strategy_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import tf2 from tensorflow.python.autograph.core import converter_testing from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations @@ -157,6 +158,9 @@ class MirroredTwoDeviceDistributionTest( self.set_v2_tensorshape(original_v2) def testReplicateDataset(self, distribution): + if tf2.enabled() and not context.executing_eagerly(): + self.skipTest("Skipping test since we do not support graph mode in TF 2") + dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] input_fn = self._input_fn_to_test_input_context( diff --git a/tensorflow/python/distribute/one_device_strategy_test.py b/tensorflow/python/distribute/one_device_strategy_test.py index f825c5e1f9e..0e6f81df1f9 100644 --- a/tensorflow/python/distribute/one_device_strategy_test.py +++ b/tensorflow/python/distribute/one_device_strategy_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations @@ -49,6 +50,8 @@ class OneDeviceStrategyTest( self._test_call_and_merge_exceptions(distribution) def testReplicateDataset(self, distribution): + if tf2.enabled() and not context.executing_eagerly(): + self.skipTest("Skipping test since we do not support graph mode in TF 2") dataset_fn = lambda: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] input_fn = self._input_fn_to_test_input_context(