Add tf.distribute.Strategy.experimental_distribute_dataset, which clones the dataset onto the different devices and returns a DistributedDataset. In eager mode you can iterate over this dataset in a Pythonic fashion, since the wrapped dataset implements the __iter__ protocol.
PiperOrigin-RevId: 243459250
parent fb695b89e0
commit 50eee22d86
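A minimal usage sketch of the new API in an eager custom training loop (adapted from the docstring and the custom_training_loop_test added below; the train_step body is a placeholder, not part of this change):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# The input dataset is batched by the *global* batch size.
dataset = tf.data.Dataset.range(10).batch(2)

# Clone and shard the dataset across the replicas; returns a DistributedDataset.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(x):
  # Placeholder step; a real loop would compute a loss and apply gradients.
  return x

# DistributedDataset implements __iter__, so it can be iterated directly in
# eager mode.
for x in dist_dataset:
  per_replica = strategy.experimental_run_v2(train_step, args=(x,))
  local_results = strategy.experimental_local_results(per_replica)
```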
@@ -764,6 +764,24 @@ cuda_py_test(
    ],
)

distribute_py_test(
    name = "custom_training_loop_test",
    srcs = ["custom_training_loop_test.py"],
    main = "custom_training_loop_test.py",
    tags = [
        "multi_and_single_gpu",
    ],
    deps = [
        "//tensorflow/python:errors",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:combinations",
        "//tensorflow/python/distribute:multi_worker_test_base",
        "//tensorflow/python/distribute:strategy_combinations",
        "//tensorflow/python/eager:test",
        "@absl_py//absl/testing:parameterized",
    ],
)

distribute_py_test(
    name = "minimize_loss_test",
    srcs = ["minimize_loss_test.py"],
@@ -350,6 +350,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_context

  def _experimental_distribute_dataset(self, dataset):
    input_context = self._make_input_context()
    return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                             self._num_replicas_in_sync,
                                             input_context=input_context)

  def _make_dataset_iterator(self, dataset):
    """Distributes the dataset to each local GPU."""
    input_context = self._make_input_context()
tensorflow/python/distribute/custom_training_loop_test.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for custom training loops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test


class InputIterationTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]
      ))
  def testFullEager(self, distribution):
    dataset = self._get_dataset()

    def train_step(data):
      return data

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self._validate_outputs(results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu,
          mode=["eager"]
      ))
  def testStepInFunction(self, distribution):
    dataset = self._get_dataset()

    @def_function.function
    def train_step(data):
      return data

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(x,)))
      results.append(output)
    self._validate_outputs(results)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.strategies_minus_tpu +
          [strategy_combinations.tpu_strategy_one_step],
          mode=["eager"]
      ))
  def testRunInFunction(self, distribution):
    dataset = self._get_dataset()

    def train_step(data):
      return data

    @def_function.function
    def f_train_step(input_data):
      return distribution.experimental_local_results(
          distribution.experimental_run_v2(train_step, args=(input_data,)))

    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    results = []
    for x in dist_dataset:
      output = f_train_step(x)
      results.append(output)
    self._validate_outputs(results)

  def _get_dataset(self):
    if tf2.enabled():
      return dataset_ops.DatasetV2.range(10).batch(2)
    else:
      return dataset_ops.Dataset.range(10).batch(2)

  def _validate_outputs(self, actual_results):
    expected_results = [[i, i+1] for i in range(0, 10, 2)]
    self.assertEqual(len(expected_results), len(actual_results))

    for i, expected_result in enumerate(expected_results):
      final_result = []
      actual_result = actual_results[i]
      for val in actual_result:
        final_result.extend(val.numpy())
      self.assertAllEqual(expected_result, final_result)

if __name__ == "__main__":
  test.main()
@@ -377,6 +377,39 @@ class Strategy(object):
    args = (input_iterator.get_next(),) if input_iterator is not None else ()
    return self.experimental_run_v2(fn, args=args)

  def experimental_distribute_dataset(self, dataset):
    """Distributes a tf.data.Dataset instance provided via `dataset`.

    Data from the given dataset will be distributed evenly across all the
    compute replicas. This function assumes that the input dataset is batched
    by the global batch size.

    The following is an example:

    ```python
    strategy = tf.distribute.MirroredStrategy()

    # Create a dataset
    dataset = dataset_ops.Dataset.range(10).batch(2)

    # Distribute that dataset
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    # Iterate over the distributed dataset
    for x in dist_dataset:
      # process dataset elements
      strategy.experimental_run_v2(train_step, args=(x,))
    ```

    Args:
      dataset: `tf.data.Dataset` that will be distributed evenly across all
        replicas.

    Returns:
      A `DistributedDataset` which returns inputs for each step of the
      computation.
    """
    return self._extended._experimental_distribute_dataset(dataset)  # pylint: disable=protected-access

  def experimental_run_v2(self, fn, args=(), kwargs=None):
    """Runs ops in `fn` on each replica, with the given arguments.

@@ -1062,6 +1095,9 @@ class StrategyExtendedV2(object):
  def _make_input_fn_iterator(self, input_fn, replication_mode):
    raise NotImplementedError("must be implemented in descendants")

  def _experimental_distribute_dataset(self, dataset):
    raise NotImplementedError("must be implemented in descendants")

  def _reduce(self, reduce_op, value):
    # Default implementation until we have an implementation for each strategy.
    return self._local_results(

@@ -1671,6 +1707,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
  def variable_created_in_scope(self, v):
    return v._distribute_strategy is None  # pylint: disable=protected-access

  def _experimental_distribute_dataset(self, dataset):
    return dataset

  def _make_dataset_iterator(self, dataset):
    return _DefaultDistributionExtended.DefaultInputIterator(dataset)
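A small illustrative sketch of the dispatch added in these hunks (not part of the diff; assumes TF 2.x eager execution and the tf.distribute.get_strategy export): the public method forwards to the extended strategy, the base extended class raises NotImplementedError, and the default (no-op) strategy hands the dataset back unchanged.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10).batch(2)

# Outside any strategy scope, the default strategy's extended implementation
# returns the dataset unchanged (see _DefaultDistributionExtended above).
default_strategy = tf.distribute.get_strategy()
print(default_strategy.experimental_distribute_dataset(dataset) is dataset)  # True

# Under a real strategy the result is a wrapped DistributedDataset whose
# __iter__ is only usable while executing eagerly.
mirrored = tf.distribute.MirroredStrategy()
dist_dataset = mirrored.experimental_distribute_dataset(dataset)
print(type(dist_dataset).__name__)
```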
@@ -29,6 +29,7 @@ from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util

@@ -38,6 +39,35 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


def get_distributed_dataset(dataset, input_workers, split_batch_by=None,
                            input_context=None):
  """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.

  This is a common function that is used by all strategies to return the right
  tf.data.Dataset wrapped instance depending on the `dataset` argument type.

  Args:
    dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Returns:
    A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance.
  """
  if isinstance(dataset, dataset_ops.DatasetV1):
    return DistributedDatasetV1(dataset, input_workers, split_batch_by,
                                input_context)
  else:
    return DistributedDataset(dataset, input_workers, split_batch_by,
                              input_context)


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

@@ -95,31 +125,7 @@ class InputWorkers(object):
        self.__class__.__name__, debug_repr, self._device_map)


class InputIterator(object):
  """An input iterator, intended to be passed to `DistributionStrategy.run`."""

  def get_next(self):
    """Returns the next inputs for all replicas."""
    raise NotImplementedError("must be implemented in descendants")

  def initialize(self):
    """Initialize the underlying input dataset, when applicable.

    In eager mode, this will create a new iterator and return it.
    In graph mode, this will initialize the same underlying iterator(s).

    Users are required to call this if
    - This iterator was returned from a call to `make_input_fn_iterator` with an
      input function that returns a dataset.
    - Or this iterator was returned from a call to `make_dataset_iterator`.

    Returns:
      A list of initialization ops to be executed.
    """
    raise NotImplementedError("must be implemented in descendants")


class InputIteratorImpl(InputIterator):
class DistributedIterator(object):
  """Common implementation for all input iterators."""

  def __init__(self, input_workers, iterators, **kwargs):

@@ -127,11 +133,11 @@ class InputIteratorImpl(InputIterator):
    # be correctly handled.
    self._enable_get_next_as_optional = False
    if len(kwargs) > 1:
      raise ValueError("InputIteratorImpl constructor only takes one "
      raise ValueError("DistributedIterator constructor only takes one "
                       "experimental flag now")
    if len(kwargs) == 1:
      if "_enable_get_next_as_optional" not in kwargs:
        raise ValueError("InputIteratorImpl constructor does not support "
        raise ValueError("DistributedIterator constructor does not support "
                         "arguments: {}".format(kwargs))
      self._enable_get_next_as_optional = (
          kwargs["_enable_get_next_as_optional"])

@@ -143,6 +149,18 @@ class InputIteratorImpl(InputIterator):
    self._iterators = iterators
    self._input_workers = input_workers

  def next(self):
    return self.__next__()

  def __next__(self):
    if not context.executing_eagerly():
      raise RuntimeError("__iter__ is only supported "
                         "when eager execution is enabled.")
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:

@@ -226,16 +244,33 @@ class InputIteratorImpl(InputIterator):

    return values.regroup(self._input_workers.device_map, replicas)

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)


class DistributedIteratorV1(DistributedIterator):
  """Input Iterator for tf.data.DatasetV1."""

  # TODO(anjalisridhar): Move to using `initializer` instead to be consistent
  # with tf.data iterator APIs.
  def initialize(self):
    """Initialze underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return init_ops
    return super(DistributedIteratorV1, self)._initializer

  @property
  def initializer(self):
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
@@ -260,7 +295,113 @@ class InputIteratorImpl(InputIterator):
    return None


class InputFunctionIterator(InputIteratorImpl):
class DistributedDataset(object):
  """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices."""

  def __init__(self, dataset, input_workers, split_batch_by=None,
               input_context=None, **kwargs):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
    # the final input such that each worker will run the entire preprocessing
    # pipeline and only receive its own shard of the dataset.
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          # TODO(b/129506833): Figure out between graph cases
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    # TODO(anjalisridhar): Identify if we need to set this property on the
    # iterator.
    self._element_structure = dataset._element_structure  # pylint: disable=protected-access
    self._kwargs = kwargs

  def __iter__(self):
    # TODO(anjalisridhar): Remove this restriction once we can create
    # iterators in graph mode.
    if context.executing_eagerly():
      worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                      self._input_workers)
      iterator = DistributedIterator(self._input_workers, worker_iterators,
                                     **self._kwargs)
      iterator._element_structure = self._element_structure  # pylint: disable=protected-access
      return iterator
    else:
      raise RuntimeError("__iter__ is only supported when eager "
                         "execution is enabled.")


class DistributedDatasetV1(DistributedDataset):
  """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices."""

  def __init__(self, dataset, input_workers, split_batch_by=None,
               input_context=None, **kwargs):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(dataset, input_workers,
                                               split_batch_by=split_batch_by,
                                               input_context=input_context,
                                               **kwargs)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1."""
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot create
    # an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `make_one_shot_iterator` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers)
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     **self._kwargs)
    iterator._element_structure = self._element_structure  # pylint: disable=protected-access
    return iterator


# TODO(anjalisridhar): This class will be soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, **kwargs):
@@ -305,7 +446,9 @@ class InputFunctionIterator(InputIteratorImpl):
        input_workers, iterators, **kwargs)


class DatasetIterator(InputIteratorImpl):
# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self, dataset, input_workers, split_batch_by=None,

@@ -313,20 +456,7 @@ class DatasetIterator(InputIteratorImpl):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value. To achieve this, we first unbatch the
    input dataset and then rebatch it with the per replica batch size that is
    calculated using `global_batch_size // split_batch_by`.
    The currently supported datasets are as follows:
    `dataset.batch()` is the last operation on the dataset OR
    `dataset.apply(map_and_batch)` is the last operation on the dataset OR
    `dataset.batch().prefetch()` are the last 2 operations on the dataset OR
    `dataset.apply(map_and_batch).prefetch()` are the last 2 operations.

    We clone and shard the dataset on each worker. The current setup tries to
    shard the dataset by files if possible so that each worker sees a different
    subset of files. If that is not possible, will attempt to shard the final
    input such that each worker will run the entire preprocessing pipeline and
    only receive its own shard of the dataset.
    dataset by `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.

@@ -339,40 +469,14 @@ class DatasetIterator(InputIteratorImpl):
        `num_input_pipelines` in the `InputContext`.
      **kwargs: Additional experimental flags. Will be removed in future.
    """
    assert isinstance(input_workers, InputWorkers)
    if split_batch_by:
      dataset = batching._RebatchDataset(dataset, split_batch_by)  # pylint: disable=protected-access

    iterators = []

    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      worker = input_workers.worker_devices[0]
      worker_devices = input_workers.compute_devices_for_worker(0)
      dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
          dataset, input_context.num_input_pipelines,
          input_context.input_pipeline_id)
      iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices)
      iterators.append(iterator)
    else:
      # In-graph cases where we depend on the list of workers for sharding
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          worker_devices = input_workers.compute_devices_for_worker(i)
          cloned_dataset = dataset
          if not context.executing_eagerly():
            cloned_dataset = input_ops._clone_dataset(dataset)  # pylint: disable=protected-access
            cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(  # pylint: disable=protected-access
              cloned_dataset, len(input_workers.worker_devices), i)
          iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker,
                                                  worker_devices)
          iterators.append(iterator)

    self._element_structure = dataset._element_structure  # pylint: disable=protected-access

    super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs)
    dist_dataset = DistributedDatasetV1(dataset, input_workers,
                                        split_batch_by=split_batch_by,
                                        input_context=input_context)
    worker_iterators = _create_iterators_per_worker(
        dist_dataset._cloned_datasets, input_workers)  # pylint: disable=protected-access
    super(DatasetIterator, self).__init__(input_workers, worker_iterators,  # pylint: disable=protected-access
                                          **kwargs)
    self._element_structure = dist_dataset._element_structure  # pylint: disable=protected-access


def _dummy_tensor_fn(value_structure):

@@ -554,6 +658,21 @@ class _SingleWorkerCallableIterator(object):
    return []


def _create_iterators_per_worker(worker_datasets, input_workers):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)

  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                              worker_devices)
      iterators.append(iterator)
  return iterators


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
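For clarity, a hedged sketch of how the internal wrapping above is driven (it mirrors the _wrap_dataset helper in the DistributedIterator tests below; InputWorkers, ReplicaDeviceMap and get_distributed_dataset are internal APIs taken from this diff, and the single-CPU topology is just an assumption):

```python
import tensorflow as tf

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import values

# One input worker feeding a single CPU device (assumed topology).
worker_device_pairs = [("", ["/device:CPU:0"])]
device_map = values.ReplicaDeviceMap(["/device:CPU:0"])
input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

dataset = dataset_ops.Dataset.range(10).batch(2)

# Dispatches to DistributedDatasetV1 or DistributedDataset depending on the
# dataset's type (see get_distributed_dataset above).
dist_dataset = input_lib.get_distributed_dataset(dataset, input_workers)

if tf.executing_eagerly():
  # DistributedDataset.__iter__ is eager-only.
  for x in dist_dataset:
    pass
else:
  # The V1 wrapper also exposes make_one_shot_iterator() for graph mode.
  iterator = dist_dataset.make_one_shot_iterator()
  next_element = iterator.get_next()
```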
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib

@@ -32,14 +33,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import nest


class InputIteratorTestBase(test.TestCase):

  def _create_iterator(self, input_type, dataset_fn, worker_device_pairs,
                       devices, split_batch_by,
                       enable_get_next_as_optional):
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
class DistributedIteratorTestBase(test.TestCase):

  def _wrap_iterator(self, input_type, dataset_fn, input_workers, devices,
                     split_batch_by, enable_get_next_as_optional):
    if input_type == "input_fn":
      input_contexts = []
      for i in range(input_workers.num_workers):
@@ -59,116 +56,223 @@ class InputIteratorTestBase(test.TestCase):
        _enable_get_next_as_optional=enable_get_next_as_optional)
    return iterator

  def _test_iterator(self,
                     input_type,
                     dataset_fn,
                     worker_device_pairs,
                     expected_values,
                     sess=None,
                     split_batch_by=None,
                     enable_get_next_as_optional=False):
  def _wrap_dataset(self, input_type, dataset, input_workers,
                    split_batch_by, enable_get_next_as_optional):
    if isinstance(dataset, dataset_ops.Dataset):
      return input_lib.DistributedDatasetV1(
          dataset, input_workers,
          split_batch_by,
          _enable_get_next_as_optional=enable_get_next_as_optional)
    else:
      return input_lib.DistributedDataset(
          dataset, input_workers,
          split_batch_by,
          _enable_get_next_as_optional=enable_get_next_as_optional)

  def _test_input_iteration(self,
                            input_type,
                            api_type,
                            iteration_type,
                            dataset_fn,
                            worker_device_pairs,
                            expected_values,
                            sess=None,
                            split_batch_by=None,
                            enable_get_next_as_optional=False):
    if iteration_type == "for_loop" and not context.executing_eagerly():
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
      self.skipTest("unsupported test combination.")

    if api_type == "wrap_into_dataset" and input_type == "input_fn":
      self.skipTest("unsupported test combination.")

    devices = nest.flatten([ds for _, ds in worker_device_pairs])
    iterator = self._create_iterator(
        input_type, dataset_fn, worker_device_pairs, devices, split_batch_by,
        enable_get_next_as_optional)
    device_map = values.ReplicaDeviceMap(devices)
    input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    evaluate(control_flow_ops.group(iterator.initialize()))
    if api_type == "wrap_into_iterator":
      iterator = self._wrap_iterator(
          input_type, dataset_fn, input_workers, devices, split_batch_by,
          enable_get_next_as_optional)
    else:
      # wrapping into a dataset:
      given_dataset = dataset_fn(distribute_lib.InputContext())
      dataset = self._wrap_dataset(input_type, given_dataset, input_workers,
                                   split_batch_by, enable_get_next_as_optional)

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])
      if context.executing_eagerly():
        iterator = iter(dataset)
      else:
        # In graph mode currently we only have support for creating iterators
        # for datasetV1 instances.
        if not isinstance(dataset, dataset_ops.DatasetV1):
          self.skipTest("unsupported test combination")
        iterator = dataset.make_one_shot_iterator()

    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
    if iteration_type == "get_next":
      evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initialize()))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))

    # After re-initializing the iterator, should be able to iterate again.
    evaluate(control_flow_ops.group(iterator.initialize()))
      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r, next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])
      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])

      # After re-initializing the iterator, should be able to iterate again.
      if isinstance(iterator, input_lib.DistributedIteratorV1):
        evaluate(control_flow_ops.group(iterator.initialize()))
      else:
        evaluate(control_flow_ops.group(iterator._initializer))

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = evaluate(
            [values.select_replica(r,
                                   next_element) for r in range(len(devices))])
        self.assertEqual(len(expected_value), len(computed_value))
        for i in range(len(expected_value)):
          self.assertAllEqual(expected_value[i], computed_value[i])

    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
      for x in dataset:
        computed_value = self.evaluate(
            [values.select_replica(r, x) for r in range(len(devices))])
        actual_values.append(computed_value)
      for i, expected_value in enumerate(expected_values):
        self.assertEqual(len(expected_value), len(actual_values[i]))
        for j in range(len(expected_value)):
          self.assertAllEqual(expected_value[j], actual_values[i][j])


class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                                    parameterized.TestCase):
class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
                                          parameterized.TestCase):

  def testGraphModeError(self):
    with context.graph_mode():
      worker_device_pairs = [("", ["/device:CPU:0"])]
      devices = nest.flatten([ds for _, ds in worker_device_pairs])
      device_map = values.ReplicaDeviceMap(devices)
      input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)
      dataset = dataset_ops.Dataset.range(10).batch(2)

      with self.assertRaisesRegexp(RuntimeError,
                                   "__iter__ is only "
                                   "supported when eager execution is "
                                   "enabled."):
        dist_dataset = input_lib.DistributedDatasetV1(dataset, input_workers)
        iter(dist_dataset)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"]))
  def testOneDeviceCPU(self, input_type):
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"]))
  def testOneDeviceCPU(self, input_type, api_type, iteration_type):
    worker_device_pairs = [("", ["/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(10)

    expected_values = [[i] for i in range(10)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)
    self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
                               worker_device_pairs, expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      required_gpus=1))
  def testTwoDevicesOneGPUOneCPU(self, input_type):
  def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(10)
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(10)

    expected_values = [[i, i+1] for i in range(0, 10, 2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)
    self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
                               worker_device_pairs, expected_values)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      required_gpus=1))
  def testTupleDataset(self, input_type):
  def testTupleDataset(self, input_type, api_type, iteration_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]

    def dataset_fn(ctx):
      del ctx
      dataset1 = dataset_ops.Dataset.range(10)
      dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
      return dataset_ops.Dataset.zip((dataset1, dataset2))
      if tf2.enabled():
        dataset1 = dataset_ops.Dataset.range(10)
        dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2)
        return dataset_ops.Dataset.zip((dataset1, dataset2))
      else:
        dataset1 = dataset_ops.DatasetV2.range(10)
        dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2)
        return dataset_ops.DatasetV2.zip((dataset1, dataset2))

    expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values)
    self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
                               worker_device_pairs, expected_values)

  @combinations.generate(
      combinations.combine(
          mode=["graph", "eager"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          required_gpus=1))
  def testUnevenDatasetBatches(self, input_type):
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)

    # The last global batch only contains data for one replica.
    expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]]
    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values, enable_get_next_as_optional=True)
    self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
                               worker_device_pairs, expected_values,
                               enable_get_next_as_optional=True)

  @combinations.generate(combinations.combine(
      mode=["graph", "eager"],
      input_type=["dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      split_batch_by=[None, 2],
      required_gpus=1))
  def testBatchSplitting(self, input_type, split_batch_by):
  def testBatchSplitting(self, input_type, api_type, iteration_type,
                         split_batch_by):
    worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
    batch_size = 10
    dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)
    if tf2.enabled():
      dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
    else:
      dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size)

    updated_batch_size = (
        batch_size // split_batch_by if split_batch_by else batch_size)
@@ -176,13 +280,13 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase,
                       range(i+updated_batch_size, i+2*updated_batch_size)]
                      for i in range(0, 100, updated_batch_size*2)]

    self._test_iterator(input_type, dataset_fn, worker_device_pairs,
                        expected_values, sess=None,
                        split_batch_by=split_batch_by)
    self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn,
                               worker_device_pairs, expected_values, sess=None,
                               split_batch_by=split_batch_by)


class InputIteratorMultiWorkerTest(
    multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase,
class DistributedIteratorMultiWorkerTest(
    multi_worker_test_base.MultiWorkerTestBase, DistributedIteratorTestBase,
    parameterized.TestCase):

  def _cpu_devices(self):
@@ -206,77 +310,109 @@ class InputIteratorMultiWorkerTest(

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"]))
  def testOneDevicePerWorker(self, input_type):
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"]))
  def testOneDevicePerWorker(self, input_type, api_type, iteration_type):
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
      if tf2.enabled():
        dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
      else:
        dataset_fn = lambda _: dataset_ops.Dataset.range(4)

      if input_type == "dataset":
        # Autosharded
        expected_values = [[0, 1], [2, 3]]
      else:
        expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess)
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices,
                                 expected_values, sess)

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"],
      required_gpus=1))
  def testTwoDevicesPerWorker(self, input_type):
  def testTwoDevicesPerWorker(self, input_type, api_type, iteration_type):
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      dataset_fn = lambda _: dataset_ops.Dataset.range(4)
      if tf2.enabled():
        dataset_fn = lambda _: dataset_ops.DatasetV2.range(4)
      else:
        dataset_fn = lambda _: dataset_ops.Dataset.range(4)

      if input_type == "dataset":
        # Autosharded
        expected_values = [[0, 2, 1, 3]]
      else:
        expected_values = [[0, 1, 0, 1], [2, 3, 2, 3]]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess)
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices,
                                 expected_values, sess)

  @combinations.generate(combinations.combine(
      mode=["graph"],
      input_type=["input_fn", "dataset"]))
  def testTupleDataset(self, input_type):
      input_type=["input_fn", "dataset"],
      api_type=["wrap_into_iterator", "wrap_into_dataset"],
      iteration_type=["get_next", "for_loop"]))
  def testTupleDataset(self, input_type, api_type, iteration_type):
    worker_devices = self._cpu_devices()
    with context.graph_mode(), self.cached_session() as sess:

      def dataset_fn(ctx):
        del ctx
        dataset1 = dataset_ops.Dataset.range(4)
        dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
        return dataset_ops.Dataset.zip((dataset1, dataset2))
        if tf2.enabled():
          dataset1 = dataset_ops.DatasetV2.range(4)
          dataset2 = dataset_ops.DatasetV2.range(4).map(lambda x: x**2)
          return dataset_ops.DatasetV2.zip((dataset1, dataset2))
        else:
          dataset1 = dataset_ops.Dataset.range(4)
          dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2)
          return dataset_ops.Dataset.zip((dataset1, dataset2))

      if input_type == "dataset":
        # Autosharded
        expected_values = [[(0, 0), (1, 1)], [(2, 4), (3, 9)]]
      else:
        expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess)
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices, expected_values,
                                 sess)

  @combinations.generate(
      combinations.combine(
          mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1))
  def testUnevenDatasetBatches(self, input_type):
          mode=["graph"],
          input_type=["input_fn", "dataset"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          required_gpus=1))
  def testUnevenDatasetBatches(self, input_type, api_type, iteration_type):
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
      dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
      if tf2.enabled():
        dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2)
      else:
        dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
      if input_type == "dataset":
        # Autosharded
        expected_values = [[[0, 1], [4, 5], [2, 3], [6, 7]], [[8], [], [], []]]
      else:
        expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
                           [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess,
                          enable_get_next_as_optional=True)
      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices, expected_values,
                                 sess, enable_get_next_as_optional=True)

  @combinations.generate(
      combinations.combine(
          mode=["graph"], input_type=["input_fn"], required_gpus=1))
  def testDifferentDatasets(self, input_type):
          mode=["graph"], input_type=["input_fn"],
          api_type=["wrap_into_iterator", "wrap_into_dataset"],
          iteration_type=["get_next", "for_loop"],
          required_gpus=1))
  def testDifferentDatasets(self, input_type, api_type, iteration_type):
    worker_devices = self._cpu_and_one_gpu_devices()
    with context.graph_mode(), self.cached_session() as sess:
@@ -288,10 +424,9 @@ class InputIteratorMultiWorkerTest(

      expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]],
                         [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]]
      self._test_iterator(input_type, dataset_fn, worker_devices,
                          expected_values, sess,
                          enable_get_next_as_optional=True)

      self._test_input_iteration(input_type, api_type, iteration_type,
                                 dataset_fn, worker_devices, expected_values,
                                 sess, enable_get_next_as_optional=True)

if __name__ == "__main__":
  test.main()
@@ -597,6 +597,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
    return input_lib.InputFunctionIterator(
        input_fn, self._input_workers, input_contexts)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                             self._num_replicas_in_sync)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)
@@ -105,6 +105,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib.get_distributed_dataset(dataset, self._input_workers)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
@@ -282,6 +282,10 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                             self._num_replicas_in_sync)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(dataset, self._input_workers,
                                     self._num_replicas_in_sync)
@@ -325,6 +325,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
        numpy_input, numpy_dataset.SingleDevice(self._host_device),
        session)

  def _experimental_distribute_dataset(self, dataset):
    return input_lib.get_distributed_dataset(dataset, self._input_workers,
                                             self._num_replicas_in_sync)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
  # a mechanism to infer the outputs of `fn`. Pending b/110550782.
@@ -24,6 +24,10 @@ tf_class {
    name: "configure"
    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "experimental_distribute_dataset"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_local_results"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"