Create a DistributedDataset/DistributedDatasetV1 based on TF2/1.x mode.
PiperOrigin-RevId: 306391431
Change-Id: Ib4585c757e2770e0eb9ce5692af8a837a609602c
parent b2418b7fee
commit 0c44461c7c
tensorflow/python/distribute
@@ -22,6 +22,7 @@ import sys
 
 import six
 
+from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
@@ -53,13 +54,16 @@ def get_distributed_dataset(dataset,
                             strategy,
                             split_batch_by=None,
                             input_context=None):
-  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
+  """Returns a distributed dataset from the given tf.data.Dataset instance.
 
-  This is a common function that is used by all strategies to return the right
-  tf.data.Dataset wrapped instance depending on the `dataset` argument type.
+  This is a common function that is used by all strategies to return a
+  distributed dataset. The distributed dataset instance returned is different
+  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
+  instances returned differ from each other in the APIs supported by each of
+  them.
 
   Args:
-    dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
+    dataset: a tf.data.Dataset instance.
     input_workers: an InputWorkers object which specifies devices on which
       iterators should be created.
     strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
@@ -72,17 +76,17 @@ def get_distributed_dataset(dataset,
       `num_input_pipelines` in the `InputContext`.
 
   Returns:
-    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
+    A distributed dataset instance.
   """
-  if isinstance(dataset, dataset_ops.DatasetV1):
-    return DistributedDatasetV1(
+  if tf2.enabled():
+    return DistributedDataset(
         dataset,
         input_workers,
         strategy,
         split_batch_by=split_batch_by,
         input_context=input_context)
   else:
-    return DistributedDataset(
+    return DistributedDatasetV1(
         dataset,
         input_workers,
         strategy,
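The key behavioral change in this hunk: the wrapper class is now selected by TF 1/2 mode rather than by the concrete type of the `dataset` argument. A minimal sketch of the new dispatch, with stand-in classes in place of the real input_lib wrappers (only the `tf2.enabled()` test itself comes from the diff):

from tensorflow.python import tf2


class DistributedDataset(object):
  """Stand-in for the TF 2 wrapper (supports `__iter__`)."""

  def __init__(self, dataset, input_workers, strategy,
               split_batch_by=None, input_context=None):
    self._dataset = dataset


class DistributedDatasetV1(DistributedDataset):
  """Stand-in for the TF 1 wrapper (session-style iterators)."""


def get_distributed_dataset_sketch(dataset, input_workers, strategy,
                                   split_batch_by=None, input_context=None):
  # Before: isinstance(dataset, dataset_ops.DatasetV1) chose the class.
  # After: the TF 2 mode flag chooses it, regardless of the dataset type.
  cls = DistributedDataset if tf2.enabled() else DistributedDatasetV1
  return cls(dataset, input_workers, strategy,
             split_batch_by=split_batch_by, input_context=input_context)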
@@ -94,15 +98,16 @@ def get_distributed_datasets_from_function(dataset_fn,
                                            input_workers,
                                            input_contexts,
                                            strategy):
-  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
+  """Returns a distributed dataset from the given input function.
 
-  This is a common function that is used by all strategies to return the right
-  tf.data.Dataset wrapped instance depending on if we are in graph or eager
-  mode.
+  This is a common function that is used by all strategies to return a
+  distributed dataset. The distributed dataset instance returned is different
+  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
+  instances returned differ from each other in the APIs supported by each of
+  them.
 
   Args:
-    dataset_fn: a function that returns a tf.data.DatasetV1 or tf.data.DatasetV2
-      instance.
+    dataset_fn: a function that returns a tf.data.Dataset instance.
     input_workers: an InputWorkers object which specifies devices on which
       iterators should be created.
     input_contexts: A list of `InputContext` instances to be passed to call(s)
@@ -112,9 +117,9 @@ def get_distributed_datasets_from_function(dataset_fn,
       handle last partial batch.
 
   Returns:
-    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
+    A distributed dataset instance.
   """
-  if ops.executing_eagerly_outside_functions():
+  if tf2.enabled():
     return DistributedDatasetsFromFunction(
         dataset_fn,
         input_workers,
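Note the subtle guard swap in this hunk: `ops.executing_eagerly_outside_functions()` is a runtime check that flips to False inside a legacy graph, while `tf2.enabled()` is a static "TF 2 behavior" flag. A small sketch of why the two can disagree (module paths as imported in this file):

from tensorflow.python import tf2
from tensorflow.python.framework import ops

print(tf2.enabled())  # True under a TF 2 install
with ops.Graph().as_default():
  # Inside an explicitly built legacy graph the eager check is False,
  # but the TF 2 mode flag is unchanged.
  print(ops.executing_eagerly_outside_functions())  # False
  print(tf2.enabled())                              # still True

So after this change the dispatch depends only on the mode, not on whether the caller happens to be building a graph at the moment of the call.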
@@ -372,7 +377,7 @@ class DistributedIterator(object):
 
 
 class DistributedIteratorV1(DistributedIterator):
-  """Input Iterator for tf.data.DatasetV1."""
+  """Input Iterator for a distributed dataset instance."""
 
   @deprecated(None, "Use the iterator's `initializer` property instead.")
   def initialize(self):
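The docstring change above sits next to an existing deprecation: `initialize()` is deprecated in favor of the `initializer` property. A hedged TF 1-style end-to-end sketch (the strategy choice and setup are illustrative, not code from this diff):

import tensorflow.compat.v1 as tf1

tf1.disable_v2_behavior()

strategy = tf1.distribute.OneDeviceStrategy("/cpu:0")
dataset = tf1.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

iterator = dist_dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf1.Session() as sess:
  sess.run(iterator.initializer)  # preferred over the deprecated initialize()
  print(sess.run(next_element))   # => [0 1]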
@@ -451,7 +456,7 @@ class _IterableInput(object):
 
 
 class DistributedDataset(_IterableInput):
-  """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""
+  """Distributed dataset that supports prefetching to multiple devices."""
 
   def __init__(self,
                dataset,
@@ -555,7 +560,7 @@ class DistributedDataset(_IterableInput):
 
 
 class DistributedDatasetV1(DistributedDataset):
-  """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""
+  """Distributed dataset that supports prefetching to multiple devices."""
 
   def __init__(self,
                dataset,
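The two renamed docstrings belong to classes that differ mainly in their consumption API, which is exactly what the new `tf2.enabled()` dispatch selects between. The TF 2 side, for contrast with the session sketch earlier (again an illustrative setup, not code from this diff):

import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# DistributedDataset supports the plain Python iterator protocol in eager mode.
for batch in dist_dataset:
  print(batch)  # tf.Tensor([0 1], ...), tf.Tensor([2 3], ...), ...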
@@ -25,6 +25,7 @@ from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.autograph.core import converter_testing
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
@@ -157,6 +158,9 @@ class MirroredTwoDeviceDistributionTest(
     self.set_v2_tensorshape(original_v2)
 
   def testReplicateDataset(self, distribution):
+    if tf2.enabled() and not context.executing_eagerly():
+      self.skipTest("Skipping test since we do not support graph mode in TF 2")
+
     dataset_fn = lambda: dataset_ops.Dataset.range(10)
     expected_values = [[i, i+1] for i in range(0, 10, 2)]
     input_fn = self._input_fn_to_test_input_context(
@@ -17,6 +17,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from tensorflow.python import tf2
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
@@ -49,6 +50,8 @@ class OneDeviceStrategyTest(
     self._test_call_and_merge_exceptions(distribution)
 
   def testReplicateDataset(self, distribution):
+    if tf2.enabled() and not context.executing_eagerly():
+      self.skipTest("Skipping test since we do not support graph mode in TF 2")
     dataset_fn = lambda: dataset_ops.Dataset.range(10)
     expected_values = [[i] for i in range(10)]
     input_fn = self._input_fn_to_test_input_context(
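Both test files gain the same guard, since these tests also run under TF 2 combinations that build graphs. A sketch of the pattern as a reusable helper (hypothetical name; the diff inlines it):

from tensorflow.python import tf2
from tensorflow.python.eager import context


def skip_if_tf2_graph_mode(test_case):
  # TF 2 mode combined with graph-mode execution is the unsupported
  # configuration these hunks guard against.
  if tf2.enabled() and not context.executing_eagerly():
    test_case.skipTest(
        "Skipping test since we do not support graph mode in TF 2")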