From 50eee22d86337b999dca809a63ab8114010c2728 Mon Sep 17 00:00:00 2001 From: Anjali Sridhar Date: Sat, 13 Apr 2019 19:01:24 -0700 Subject: [PATCH] Add `tf.distribute.Strategy.experimental_distribute_dataset` to clone the dataset on different devices and return a `DistributedDataset`. You can iterate on this dataset in eager mode in a pythonic fashion since we implement the __iter__ protocol for the wrapped dataset. PiperOrigin-RevId: 243459250 --- tensorflow/python/distribute/BUILD | 18 + .../collective_all_reduce_strategy.py | 6 + .../distribute/custom_training_loop_test.py | 114 ++++++ .../python/distribute/distribute_lib.py | 39 +++ tensorflow/python/distribute/input_lib.py | 281 ++++++++++----- .../python/distribute/input_lib_test.py | 327 +++++++++++++----- .../python/distribute/mirrored_strategy.py | 4 + .../python/distribute/one_device_strategy.py | 5 + .../distribute/parameter_server_strategy.py | 4 + tensorflow/python/distribute/tpu_strategy.py | 4 + ...orflow.distribute.-mirrored-strategy.pbtxt | 4 + ...flow.distribute.-one-device-strategy.pbtxt | 4 + .../v1/tensorflow.distribute.-strategy.pbtxt | 4 + ...perimental.-central-storage-strategy.pbtxt | 4 + ...ntal.-multi-worker-mirrored-strategy.pbtxt | 4 + ...erimental.-parameter-server-strategy.pbtxt | 4 + ...tribute.experimental.-t-p-u-strategy.pbtxt | 4 + ...orflow.distribute.-mirrored-strategy.pbtxt | 4 + ...flow.distribute.-one-device-strategy.pbtxt | 4 + .../v2/tensorflow.distribute.-strategy.pbtxt | 4 + ...perimental.-central-storage-strategy.pbtxt | 4 + ...ntal.-multi-worker-mirrored-strategy.pbtxt | 4 + ...erimental.-parameter-server-strategy.pbtxt | 4 + ...tribute.experimental.-t-p-u-strategy.pbtxt | 4 + 24 files changed, 681 insertions(+), 177 deletions(-) create mode 100644 tensorflow/python/distribute/custom_training_loop_test.py diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 28499b73afa..e1a7a1031d4 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -764,6 +764,24 @@ cuda_py_test( ], ) +distribute_py_test( + name = "custom_training_loop_test", + srcs = ["custom_training_loop_test.py"], + main = "custom_training_loop_test.py", + tags = [ + "multi_and_single_gpu", + ], + deps = [ + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:multi_worker_test_base", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/eager:test", + "@absl_py//absl/testing:parameterized", + ], +) + distribute_py_test( name = "minimize_loss_test", srcs = ["minimize_loss_test.py"], diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index 99d4808f5bb..4b5eae0b0f9 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -350,6 +350,12 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): num_replicas_in_sync=self._num_replicas_in_sync) return input_context + def _experimental_distribute_dataset(self, dataset): + input_context = self._make_input_context() + return input_lib.get_distributed_dataset(dataset, self._input_workers, + self._num_replicas_in_sync, + input_context=input_context) + def _make_dataset_iterator(self, dataset): """Distributes the dataset to each local GPU.""" input_context = self._make_input_context() diff --git 
a/tensorflow/python/distribute/custom_training_loop_test.py b/tensorflow/python/distribute/custom_training_loop_test.py new file mode 100644 index 00000000000..ab6fac8c4a7 --- /dev/null +++ b/tensorflow/python/distribute/custom_training_loop_test.py @@ -0,0 +1,114 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for custom training loops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +from tensorflow.python import tf2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test + + +class InputIterationTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.strategies_minus_tpu, + mode=["eager"] + )) + def testFullEager(self, distribution): + dataset = self._get_dataset() + + def train_step(data): + return data + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + results = [] + for x in dist_dataset: + output = distribution.experimental_local_results( + distribution.experimental_run_v2(train_step, args=(x,))) + results.append(output) + self._validate_outputs(results) + + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.strategies_minus_tpu, + mode=["eager"] + )) + def testStepInFunction(self, distribution): + dataset = self._get_dataset() + + @def_function.function + def train_step(data): + return data + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + results = [] + for x in dist_dataset: + output = distribution.experimental_local_results( + distribution.experimental_run_v2(train_step, args=(x,))) + results.append(output) + self._validate_outputs(results) + + @combinations.generate( + combinations.combine( + distribution=strategy_combinations.strategies_minus_tpu + + [strategy_combinations.tpu_strategy_one_step], + mode=["eager"] + )) + def testRunInFunction(self, distribution): + dataset = self._get_dataset() + + def train_step(data): + return data + + @def_function.function + def f_train_step(input_data): + return distribution.experimental_local_results( + distribution.experimental_run_v2(train_step, args=(input_data,))) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + results = [] + for x in dist_dataset: + output = f_train_step(x) + results.append(output) + self._validate_outputs(results) + + def _get_dataset(self): + if tf2.enabled(): + return dataset_ops.DatasetV2.range(10).batch(2) + else: + return dataset_ops.Dataset.range(10).batch(2) + + def _validate_outputs(self, actual_results): + expected_results = [[i, i+1] for i in range(0, 10, 2)] 
+ self.assertEqual(len(expected_results), len(actual_results)) + + for i, expected_result in enumerate(expected_results): + final_result = [] + actual_result = actual_results[i] + for val in actual_result: + final_result.extend(val.numpy()) + self.assertAllEqual(expected_result, final_result) + +if __name__ == "__main__": + test.main() + diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index c5dbeb49ca2..16dc12cc070 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -377,6 +377,39 @@ class Strategy(object): args = (input_iterator.get_next(),) if input_iterator is not None else () return self.experimental_run_v2(fn, args=args) + def experimental_distribute_dataset(self, dataset): + """Distributes a tf.data.Dataset instance provided via `dataset`. + + Data from the given dataset will be distributed evenly across all the + compute replicas. This function assumes that the input dataset is batched + by the global batch size. + + The following is an example: + + ```python + strategy = tf.distribute.MirroredStrategy() + + # Create a dataset + dataset = dataset_ops.Dataset.range(10).batch(2) + + # Distribute that dataset + dist_dataset = strategy.experimental_distribute_dataset(dataset) + # Iterate over the distributed dataset + for x in dist_dataset: + # process dataset elements + strategy.experimental_run_v2(train_step, args=(x,)) + ``` + + Args: + dataset: `tf.data.Dataset` that will be distributed evenly across all + replicas. + + Returns: + A `DistributedDataset` which returns inputs for each step of the + computation. + """ + return self._extended._experimental_distribute_dataset(dataset) # pylint: disable=protected-access + def experimental_run_v2(self, fn, args=(), kwargs=None): """Runs ops in `fn` on each replica, with the given arguments. @@ -1062,6 +1095,9 @@ class StrategyExtendedV2(object): def _make_input_fn_iterator(self, input_fn, replication_mode): raise NotImplementedError("must be implemented in descendants") + def _experimental_distribute_dataset(self, dataset): + raise NotImplementedError("must be implemented in descendants") + def _reduce(self, reduce_op, value): # Default implementation until we have an implementation for each strategy. 
return self._local_results( @@ -1671,6 +1707,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1): def variable_created_in_scope(self, v): return v._distribute_strategy is None # pylint: disable=protected-access + def _experimental_distribute_dataset(self, dataset): + return dataset + def _make_dataset_iterator(self, dataset): return _DefaultDistributionExtended.DefaultInputIterator(dataset) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index b7b74064cdc..be619f72486 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -29,6 +29,7 @@ from tensorflow.python.distribute import values from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -38,6 +39,35 @@ from tensorflow.python.ops import math_ops from tensorflow.python.util import nest +def get_distributed_dataset(dataset, input_workers, split_batch_by=None, + input_context=None): + """Returns a wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + + This is a common function that is used by all strategies to return the right + tf.data.Dataset wrapped instance depending on the `dataset` argument type. + + Args: + dataset: a tf.data.DatasetV1 or tf.data.DatasetV2 instance. + input_workers: an InputWorkers object which specifies devices on which + iterators should be created. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. + input_context: `InputContext` for sharding. Only pass this in for between + graph multi-worker cases where there is only one `input_worker`. In + these cases, we will shard based on the `input_pipeline_id` and + `num_input_pipelines` in the `InputContext`. + + Returns: + A wrapped tf.data.DatasetV1 or tf.data.DatasetV2 instance. + """ + if isinstance(dataset, dataset_ops.DatasetV1): + return DistributedDatasetV1(dataset, input_workers, split_batch_by, + input_context) + else: + return DistributedDataset(dataset, input_workers, split_batch_by, + input_context) + + class InputWorkers(object): """A 1-to-many mapping from input worker devices to compute devices.""" @@ -95,31 +125,7 @@ class InputWorkers(object): self.__class__.__name__, debug_repr, self._device_map) -class InputIterator(object): - """An input iterator, intended to be passed to `DistributionStrategy.run`.""" - - def get_next(self): - """Returns the next inputs for all replicas.""" - raise NotImplementedError("must be implemented in descendants") - - def initialize(self): - """Initialize the underlying input dataset, when applicable. - - In eager mode, this will create a new iterator and return it. - In graph mode, this will initialize the same underlying iterator(s). - - Users are required to call this if - - This iterator was returned from a call to `make_input_fn_iterator` with an - input function that returns a dataset. - - Or this iterator was returned from a call to `make_dataset_iterator`. - - Returns: - A list of initialization ops to be executed. 
- """ - raise NotImplementedError("must be implemented in descendants") - - -class InputIteratorImpl(InputIterator): +class DistributedIterator(object): """Common implementation for all input iterators.""" def __init__(self, input_workers, iterators, **kwargs): @@ -127,11 +133,11 @@ class InputIteratorImpl(InputIterator): # be correctly handled. self._enable_get_next_as_optional = False if len(kwargs) > 1: - raise ValueError("InputIteratorImpl constructor only takes one " + raise ValueError("DistributedIterator constructor only takes one " "experimental flag now") if len(kwargs) == 1: if "_enable_get_next_as_optional" not in kwargs: - raise ValueError("InputIteratorImpl constructor does not support " + raise ValueError("DistributedIterator constructor does not support " "arguments: {}".format(kwargs)) self._enable_get_next_as_optional = ( kwargs["_enable_get_next_as_optional"]) @@ -143,6 +149,18 @@ class InputIteratorImpl(InputIterator): self._iterators = iterators self._input_workers = input_workers + def next(self): + return self.__next__() + + def __next__(self): + if not context.executing_eagerly(): + raise RuntimeError("__iter__ is only supported " + "when eager execution is enabled.") + try: + return self.get_next() + except errors.OutOfRangeError: + raise StopIteration + def get_next(self, name=None): """Returns the next input from the iterator for all replicas.""" if not self._enable_get_next_as_optional: @@ -226,16 +244,33 @@ class InputIteratorImpl(InputIterator): return values.regroup(self._input_workers.device_map, replicas) + # We need a private initializer method for re-initializing multidevice + # iterators when used with Keras training loops. If we don't reinitialize the + # iterator we run into memory leak issues (b/123315763). + @property + def _initializer(self): + init_ops = [] + for it in self._iterators: + init_ops.extend(it.initialize()) + return control_flow_ops.group(init_ops) + + +class DistributedIteratorV1(DistributedIterator): + """Input Iterator for tf.data.DatasetV1.""" + + # TODO(anjalisridhar): Move to using `initializer` instead to be consistent + # with tf.data iterator APIs. def initialize(self): """Initialze underlying iterators. Returns: A list of any initializer ops that should be run. """ - init_ops = [] - for it in self._iterators: - init_ops.extend(it.initialize()) - return init_ops + return super(DistributedIteratorV1, self)._initializer + + @property + def initializer(self): + return self.initialize() # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. @property @@ -260,7 +295,113 @@ class InputIteratorImpl(InputIterator): return None -class InputFunctionIterator(InputIteratorImpl): +class DistributedDataset(object): + """Wrapped tf.data.DatasetV2 that supports prefetching to multiple devices.""" + + def __init__(self, dataset, input_workers, split_batch_by=None, + input_context=None, **kwargs): + """Distribute the dataset on all workers. + + If `split_batch_by` is not None, we "split" each batch of the dataset by + `split_batch_by` value. + + Args: + dataset: `tf.data.Dataset` that will be used as the input source. + input_workers: an `InputWorkers` object. + split_batch_by: Optional integer. If present, we "split" each batch of the + dataset by `split_batch_by` value. + input_context: `InputContext` for sharding. Only pass this in for between + graph multi-worker cases where there is only one `input_worker`. 
In + these cases, we will shard based on the `input_pipeline_id` and + `num_input_pipelines` in the `InputContext`. + **kwargs: Additional experimental flags. Will be removed in future. + """ + # We clone and shard the dataset on each worker. The current setup tries to + # shard the dataset by files if possible so that each worker sees a + # different subset of files. If that is not possible, will attempt to shard + # the final input such that each worker will run the entire preprocessing + # pipeline and only receive its own shard of the dataset. + assert isinstance(input_workers, InputWorkers) + if split_batch_by: + dataset = batching._RebatchDataset(dataset, split_batch_by) # pylint: disable=protected-access + + self._cloned_datasets = [] + if input_context: + # Between-graph where we rely on the input_context for sharding + assert input_workers.num_workers == 1 + dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access + dataset, input_context.num_input_pipelines, + input_context.input_pipeline_id) + self._cloned_datasets.append(dataset) + else: + for i, worker in enumerate(input_workers.worker_devices): + with ops.device(worker): + cloned_dataset = dataset + if not context.executing_eagerly(): + cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access + cloned_dataset = cloned_dataset.with_options(dataset.options()) + # TODO(b/129506833): Figure out between graph cases + cloned_dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access + cloned_dataset, len(input_workers.worker_devices), i) + self._cloned_datasets.append(cloned_dataset) + + self._input_workers = input_workers + # TODO(anjalisridhar): Identify if we need to set this property on the + # iterator. + self._element_structure = dataset._element_structure # pylint: disable=protected-access + self._kwargs = kwargs + + def __iter__(self): + # TODO(anjalisridhar): Remove this restriction once we can create + # iterators in graph mode. + if context.executing_eagerly(): + worker_iterators = _create_iterators_per_worker(self._cloned_datasets, + self._input_workers) + iterator = DistributedIterator(self._input_workers, worker_iterators, + **self._kwargs) + iterator._element_structure = self._element_structure # pylint: disable=protected-access + return iterator + else: + raise RuntimeError("__iter__ is only supported when eager " + "execution is enabled.") + + +class DistributedDatasetV1(DistributedDataset): + """Wrapped tf.data.DatasetV1 that supports prefetching to multiple devices.""" + + def __init__(self, dataset, input_workers, split_batch_by=None, + input_context=None, **kwargs): + self._input_workers = input_workers + super(DistributedDatasetV1, self).__init__(dataset, input_workers, + split_batch_by=split_batch_by, + input_context=input_context, + **kwargs) + + def make_one_shot_iterator(self): + """Get a one time use iterator for DistributedDatasetV1.""" + return self._get_iterator() + + def make_initializable_iterator(self): + """Get an initializable iterator for DistributedDatasetV1.""" + # Eager mode generates already initialized iterators. Hence we cannot create + # an initializable iterator. + if context.executing_eagerly(): + raise ValueError("Cannot create initializable iterator in Eager mode. 
" + "Please use `make_one_shot_iterator` instead.") + return self._get_iterator() + + def _get_iterator(self): + worker_iterators = _create_iterators_per_worker(self._cloned_datasets, + self._input_workers) + iterator = DistributedIteratorV1(self._input_workers, worker_iterators, + **self._kwargs) + iterator._element_structure = self._element_structure # pylint: disable=protected-access + return iterator + + +# TODO(anjalisridhar): This class will be soon be removed in favor of newer +# APIs. +class InputFunctionIterator(DistributedIteratorV1): """Iterator created from input function.""" def __init__(self, input_fn, input_workers, input_contexts, **kwargs): @@ -305,7 +446,9 @@ class InputFunctionIterator(InputIteratorImpl): input_workers, iterators, **kwargs) -class DatasetIterator(InputIteratorImpl): +# TODO(anjalisridhar): This class will soon be removed and users should move +# to using DistributedIterator. +class DatasetIterator(DistributedIteratorV1): """Iterator created from input dataset.""" def __init__(self, dataset, input_workers, split_batch_by=None, @@ -313,20 +456,7 @@ class DatasetIterator(InputIteratorImpl): """Make an iterator for the dataset on given devices. If `split_batch_by` is not None, we "split" each batch of the - dataset by `split_batch_by` value. To achieve this, we first unbatch the - input dataset and then rebatch it with the per replica batch size that is - calculated using `global_batch_size // split_batch_by`. - The currently supported datasets are as follows: - `dataset.batch()` is the last operation on the dataset OR - `dataset.apply(map_and_batch)` is the last operation on the dataset OR - `dataset.batch().prefetch()` are the last 2 operations on the dataset OR - `dataset.apply(map_and_batch).prefetch()` are the last 2 operations. - - We clone and shard the dataset on each worker. The current setup tries to - shard the dataset by files if possible so that each worker sees a different - subset of files. If that is not possible, will attempt to shard the final - input such that each worker will run the entire preprocessing pipeline and - only receive its own shard of the dataset. + dataset by `split_batch_by` value. Args: dataset: `tf.data.Dataset` that will be used as the input source. @@ -339,40 +469,14 @@ class DatasetIterator(InputIteratorImpl): `num_input_pipelines` in the `InputContext`. **kwargs: Additional experimental flags. Will be removed in future. 
""" - assert isinstance(input_workers, InputWorkers) - if split_batch_by: - dataset = batching._RebatchDataset(dataset, split_batch_by) # pylint: disable=protected-access - - iterators = [] - - if input_context: - # Between-graph where we rely on the input_context for sharding - assert input_workers.num_workers == 1 - worker = input_workers.worker_devices[0] - worker_devices = input_workers.compute_devices_for_worker(0) - dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access - dataset, input_context.num_input_pipelines, - input_context.input_pipeline_id) - iterator = _SingleWorkerDatasetIterator(dataset, worker, worker_devices) - iterators.append(iterator) - else: - # In-graph cases where we depend on the list of workers for sharding - for i, worker in enumerate(input_workers.worker_devices): - with ops.device(worker): - worker_devices = input_workers.compute_devices_for_worker(i) - cloned_dataset = dataset - if not context.executing_eagerly(): - cloned_dataset = input_ops._clone_dataset(dataset) # pylint: disable=protected-access - cloned_dataset = cloned_dataset.with_options(dataset.options()) - cloned_dataset = input_ops.auto_shard_dataset( # pylint: disable=protected-access - cloned_dataset, len(input_workers.worker_devices), i) - iterator = _SingleWorkerDatasetIterator(cloned_dataset, worker, - worker_devices) - iterators.append(iterator) - - self._element_structure = dataset._element_structure # pylint: disable=protected-access - - super(DatasetIterator, self).__init__(input_workers, iterators, **kwargs) + dist_dataset = DistributedDatasetV1(dataset, input_workers, + split_batch_by=split_batch_by, + input_context=input_context) + worker_iterators = _create_iterators_per_worker( + dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access + super(DatasetIterator, self).__init__(input_workers, worker_iterators, # pylint: disable=protected-access + **kwargs) + self._element_structure = dist_dataset._element_structure # pylint: disable=protected-access def _dummy_tensor_fn(value_structure): @@ -554,6 +658,21 @@ class _SingleWorkerCallableIterator(object): return [] +def _create_iterators_per_worker(worker_datasets, input_workers): + """Create a multidevice iterator on each of the workers.""" + assert isinstance(input_workers, InputWorkers) + + assert len(worker_datasets) == len(input_workers.worker_devices) + iterators = [] + for i, worker in enumerate(input_workers.worker_devices): + with ops.device(worker): + worker_devices = input_workers.compute_devices_for_worker(i) + iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, + worker_devices) + iterators.append(iterator) + return iterators + + # TODO(sourabhbajaj): Remove this in lieu of distributed datasets def _get_batched_dataset(d): """Get the batched dataset from `d`.""" diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 57302632ede..2ca11024e43 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized +from tensorflow.python import tf2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import distribute_lib @@ -32,14 +33,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.util import nest -class 
InputIteratorTestBase(test.TestCase): - - def _create_iterator(self, input_type, dataset_fn, worker_device_pairs, - devices, split_batch_by, - enable_get_next_as_optional): - device_map = values.ReplicaDeviceMap(devices) - input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) +class DistributedIteratorTestBase(test.TestCase): + def _wrap_iterator(self, input_type, dataset_fn, input_workers, devices, + split_batch_by, enable_get_next_as_optional): if input_type == "input_fn": input_contexts = [] for i in range(input_workers.num_workers): @@ -59,116 +56,223 @@ class InputIteratorTestBase(test.TestCase): _enable_get_next_as_optional=enable_get_next_as_optional) return iterator - def _test_iterator(self, - input_type, - dataset_fn, - worker_device_pairs, - expected_values, - sess=None, - split_batch_by=None, - enable_get_next_as_optional=False): + def _wrap_dataset(self, input_type, dataset, input_workers, + split_batch_by, enable_get_next_as_optional): + if isinstance(dataset, dataset_ops.Dataset): + return input_lib.DistributedDatasetV1( + dataset, input_workers, + split_batch_by, + _enable_get_next_as_optional=enable_get_next_as_optional) + else: + return input_lib.DistributedDataset( + dataset, input_workers, + split_batch_by, + _enable_get_next_as_optional=enable_get_next_as_optional) + + def _test_input_iteration(self, + input_type, + api_type, + iteration_type, + dataset_fn, + worker_device_pairs, + expected_values, + sess=None, + split_batch_by=None, + enable_get_next_as_optional=False): + if iteration_type == "for_loop" and not context.executing_eagerly(): + self.skipTest("unsupported test combination.") + + if api_type == "wrap_into_iterator" and iteration_type == "for_loop": + self.skipTest("unsupported test combination.") + + if api_type == "wrap_into_dataset" and input_type == "input_fn": + self.skipTest("unsupported test combination.") + devices = nest.flatten([ds for _, ds in worker_device_pairs]) - iterator = self._create_iterator( - input_type, dataset_fn, worker_device_pairs, devices, split_batch_by, - enable_get_next_as_optional) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) - evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - evaluate(control_flow_ops.group(iterator.initialize())) + if api_type == "wrap_into_iterator": + iterator = self._wrap_iterator( + input_type, dataset_fn, input_workers, devices, split_batch_by, + enable_get_next_as_optional) + else: + # wrapping into a dataset: + given_dataset = dataset_fn(distribute_lib.InputContext()) + dataset = self._wrap_dataset(input_type, given_dataset, input_workers, + split_batch_by, enable_get_next_as_optional) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(len(expected_value), len(computed_value)) - for i in range(len(expected_value)): - self.assertAllEqual(expected_value[i], computed_value[i]) + if context.executing_eagerly(): + iterator = iter(dataset) + else: + # In graph mode currently we only have support for creating iterators + # for datasetV1 instances. 
+ if not isinstance(dataset, dataset_ops.DatasetV1): + self.skipTest("unsupported test combination") + iterator = dataset.make_one_shot_iterator() - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) + if iteration_type == "get_next": + evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) + if isinstance(iterator, input_lib.DistributedIteratorV1): + evaluate(control_flow_ops.group(iterator.initialize())) + else: + evaluate(control_flow_ops.group(iterator._initializer)) - # After re-initializing the iterator, should be able to iterate again. - evaluate(control_flow_ops.group(iterator.initialize())) + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, + next_element) for r in range(len(devices))]) + self.assertEqual(len(expected_value), len(computed_value)) + for i in range(len(expected_value)): + self.assertAllEqual(expected_value[i], computed_value[i]) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [values.select_replica(r, next_element) for r in range(len(devices))]) - self.assertEqual(len(expected_value), len(computed_value)) - for i in range(len(expected_value)): - self.assertAllEqual(expected_value[i], computed_value[i]) + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate( + [values.select_replica(r, + next_element) for r in range(len(devices))]) + + # After re-initializing the iterator, should be able to iterate again. + if isinstance(iterator, input_lib.DistributedIteratorV1): + evaluate(control_flow_ops.group(iterator.initialize())) + else: + evaluate(control_flow_ops.group(iterator._initializer)) + + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate( + [values.select_replica(r, + next_element) for r in range(len(devices))]) + self.assertEqual(len(expected_value), len(computed_value)) + for i in range(len(expected_value)): + self.assertAllEqual(expected_value[i], computed_value[i]) + + if iteration_type == "for_loop" and context.executing_eagerly(): + actual_values = [] + for x in dataset: + computed_value = self.evaluate( + [values.select_replica(r, x) for r in range(len(devices))]) + actual_values.append(computed_value) + for i, expected_value in enumerate(expected_values): + self.assertEqual(len(expected_value), len(actual_values[i])) + for j in range(len(expected_value)): + self.assertAllEqual(expected_value[j], actual_values[i][j]) -class InputIteratorSingleWorkerTest(InputIteratorTestBase, - parameterized.TestCase): +class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, + parameterized.TestCase): + + def testGraphModeError(self): + with context.graph_mode(): + worker_device_pairs = [("", ["/device:CPU:0"])] + devices = nest.flatten([ds for _, ds in worker_device_pairs]) + device_map = values.ReplicaDeviceMap(devices) + input_workers = input_lib.InputWorkers(device_map, worker_device_pairs) + dataset = dataset_ops.Dataset.range(10).batch(2) + + with self.assertRaisesRegexp(RuntimeError, + "__iter__ is only " + "supported when eager execution is " + "enabled."): + dist_dataset = input_lib.DistributedDatasetV1(dataset, input_workers) + iter(dist_dataset) @combinations.generate(combinations.combine( mode=["graph", "eager"], - input_type=["input_fn", "dataset"])) - def testOneDeviceCPU(self, 
input_type): + input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"])) + def testOneDeviceCPU(self, input_type, api_type, iteration_type): worker_device_pairs = [("", ["/device:CPU:0"])] - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(10) expected_values = [[i] for i in range(10)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) + self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn, + worker_device_pairs, expected_values) @combinations.generate(combinations.combine( mode=["graph", "eager"], input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], required_gpus=1)) - def testTwoDevicesOneGPUOneCPU(self, input_type): + def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(10) expected_values = [[i, i+1] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) + self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn, + worker_device_pairs, expected_values) @combinations.generate(combinations.combine( mode=["graph", "eager"], input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], required_gpus=1)) - def testTupleDataset(self, input_type): + def testTupleDataset(self, input_type, api_type, iteration_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] def dataset_fn(ctx): del ctx - dataset1 = dataset_ops.Dataset.range(10) - dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) + if tf2.enabled(): + dataset1 = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(10).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + else: + dataset1 = dataset_ops.DatasetV2.range(10) + dataset2 = dataset_ops.DatasetV2.range(10).map(lambda x: x**2) + return dataset_ops.DatasetV2.zip((dataset1, dataset2)) expected_values = [[(i, i**2), (i+1, (i+1)**2)] for i in range(0, 10, 2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values) + self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn, + worker_device_pairs, expected_values) @combinations.generate( combinations.combine( mode=["graph", "eager"], input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): + def testUnevenDatasetBatches(self, input_type, api_type, iteration_type): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] - dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) # The last global batch only contains data for one replica. 
expected_values = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8], []]] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, enable_get_next_as_optional=True) + self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn, + worker_device_pairs, expected_values, + enable_get_next_as_optional=True) @combinations.generate(combinations.combine( mode=["graph", "eager"], input_type=["dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], split_batch_by=[None, 2], required_gpus=1)) - def testBatchSplitting(self, input_type, split_batch_by): + def testBatchSplitting(self, input_type, api_type, iteration_type, + split_batch_by): worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] batch_size = 10 - dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(100).batch(batch_size) updated_batch_size = ( batch_size // split_batch_by if split_batch_by else batch_size) @@ -176,13 +280,13 @@ class InputIteratorSingleWorkerTest(InputIteratorTestBase, range(i+updated_batch_size, i+2*updated_batch_size)] for i in range(0, 100, updated_batch_size*2)] - self._test_iterator(input_type, dataset_fn, worker_device_pairs, - expected_values, sess=None, - split_batch_by=split_batch_by) + self._test_input_iteration(input_type, api_type, iteration_type, dataset_fn, + worker_device_pairs, expected_values, sess=None, + split_batch_by=split_batch_by) -class InputIteratorMultiWorkerTest( - multi_worker_test_base.MultiWorkerTestBase, InputIteratorTestBase, +class DistributedIteratorMultiWorkerTest( + multi_worker_test_base.MultiWorkerTestBase, DistributedIteratorTestBase, parameterized.TestCase): def _cpu_devices(self): @@ -206,77 +310,109 @@ class InputIteratorMultiWorkerTest( @combinations.generate(combinations.combine( mode=["graph"], - input_type=["input_fn", "dataset"])) - def testOneDevicePerWorker(self, input_type): + input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"])) + def testOneDevicePerWorker(self, input_type, api_type, iteration_type): worker_devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda _: dataset_ops.Dataset.range(4) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(4) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(4) + if input_type == "dataset": # Autosharded expected_values = [[0, 1], [2, 3]] else: expected_values = [[0, 0], [1, 1], [2, 2], [3, 3]] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) + self._test_input_iteration(input_type, api_type, iteration_type, + dataset_fn, worker_devices, + expected_values, sess) @combinations.generate(combinations.combine( mode=["graph"], input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], required_gpus=1)) - def testTwoDevicesPerWorker(self, input_type): + def testTwoDevicesPerWorker(self, input_type, api_type, iteration_type): worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda _: dataset_ops.Dataset.range(4) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(4) + else: + dataset_fn = lambda _: 
dataset_ops.Dataset.range(4) + if input_type == "dataset": # Autosharded expected_values = [[0, 2, 1, 3]] else: expected_values = [[0, 1, 0, 1], [2, 3, 2, 3]] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) + self._test_input_iteration(input_type, api_type, iteration_type, + dataset_fn, worker_devices, + expected_values, sess) @combinations.generate(combinations.combine( mode=["graph"], - input_type=["input_fn", "dataset"])) - def testTupleDataset(self, input_type): + input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"])) + def testTupleDataset(self, input_type, api_type, iteration_type): worker_devices = self._cpu_devices() with context.graph_mode(), self.cached_session() as sess: def dataset_fn(ctx): del ctx - dataset1 = dataset_ops.Dataset.range(4) - dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) - return dataset_ops.Dataset.zip((dataset1, dataset2)) + if tf2.enabled(): + dataset1 = dataset_ops.DatasetV2.range(4) + dataset2 = dataset_ops.DatasetV2.range(4).map(lambda x: x**2) + return dataset_ops.DatasetV2.zip((dataset1, dataset2)) + else: + dataset1 = dataset_ops.Dataset.range(4) + dataset2 = dataset_ops.Dataset.range(4).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) if input_type == "dataset": # Autosharded expected_values = [[(0, 0), (1, 1)], [(2, 4), (3, 9)]] else: expected_values = [[(i, i**2), (i, i**2)] for i in range(0, 4)] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess) + self._test_input_iteration(input_type, api_type, iteration_type, + dataset_fn, worker_devices, expected_values, + sess) @combinations.generate( combinations.combine( - mode=["graph"], input_type=["input_fn", "dataset"], required_gpus=1)) - def testUnevenDatasetBatches(self, input_type): + mode=["graph"], + input_type=["input_fn", "dataset"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + required_gpus=1)) + def testUnevenDatasetBatches(self, input_type, api_type, iteration_type): worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), self.cached_session() as sess: - dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) + if tf2.enabled(): + dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(2) + else: + dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2) if input_type == "dataset": # Autosharded expected_values = [[[0, 1], [4, 5], [2, 3], [6, 7]], [[8], [], [], []]] else: expected_values = [[[0, 1], [2, 3], [0, 1], [2, 3]], [[4, 5], [6, 7], [4, 5], [6, 7]], [[8], [], [8], []]] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess, - enable_get_next_as_optional=True) + self._test_input_iteration(input_type, api_type, iteration_type, + dataset_fn, worker_devices, expected_values, + sess, enable_get_next_as_optional=True) @combinations.generate( combinations.combine( - mode=["graph"], input_type=["input_fn"], required_gpus=1)) - def testDifferentDatasets(self, input_type): + mode=["graph"], input_type=["input_fn"], + api_type=["wrap_into_iterator", "wrap_into_dataset"], + iteration_type=["get_next", "for_loop"], + required_gpus=1)) + def testDifferentDatasets(self, input_type, api_type, iteration_type): worker_devices = self._cpu_and_one_gpu_devices() with context.graph_mode(), self.cached_session() as sess: @@ -288,10 +424,9 @@ class InputIteratorMultiWorkerTest( expected_values = [[[0, 
1], [2, 3], [0, 1], [2, 3]], [[4, 5], [6, 7], [4, 5], [6, 7]], [[], [], [8], []]] - self._test_iterator(input_type, dataset_fn, worker_devices, - expected_values, sess, - enable_get_next_as_optional=True) - + self._test_input_iteration(input_type, api_type, iteration_type, + dataset_fn, worker_devices, expected_values, + sess, enable_get_next_as_optional=True) if __name__ == "__main__": test.main() diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 44c89c57d38..088a14e5af3 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -597,6 +597,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1): return input_lib.InputFunctionIterator( input_fn, self._input_workers, input_contexts) + def _experimental_distribute_dataset(self, dataset): + return input_lib.get_distributed_dataset(dataset, self._input_workers, + self._num_replicas_in_sync) + def _experimental_make_numpy_dataset(self, numpy_input, session): return numpy_dataset.one_host_numpy_dataset( numpy_input, self._host_input_device, session) diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py index c61af88e28c..ed67d62d3dd 100644 --- a/tensorflow/python/distribute/one_device_strategy.py +++ b/tensorflow/python/distribute/one_device_strategy.py @@ -105,6 +105,11 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1): del destinations return tensor + def _experimental_distribute_dataset(self, dataset): + # Note that split_batch_by argument is not passed because it is always 1 in + # this strategy, and adding it adds unnecessary overhead to the dataset. + return input_lib.get_distributed_dataset(dataset, self._input_workers) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, initial_loop_values=None): diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index d507762c2e8..89812a1610d 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -282,6 +282,10 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): def _validate_colocate_with_variable(self, colocate_with_variable): values.validate_colocate(colocate_with_variable, self) + def _experimental_distribute_dataset(self, dataset): + return input_lib.get_distributed_dataset(dataset, self._input_workers, + self._num_replicas_in_sync) + def _make_dataset_iterator(self, dataset): return input_lib.DatasetIterator(dataset, self._input_workers, self._num_replicas_in_sync) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index c3815d83517..fdfde162974 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -325,6 +325,10 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): numpy_input, numpy_dataset.SingleDevice(self._host_device), session) + def _experimental_distribute_dataset(self, dataset): + return input_lib.get_distributed_dataset(dataset, self._input_workers, + self._num_replicas_in_sync) + # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have # a mechanism to infer the outputs of `fn`. 
Pending b/110550782. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt index c5ec4ab8a95..498fc109665 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -24,6 +24,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt index e2e0f32252f..4f6bf1a6a37 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-one-device-strategy.pbtxt @@ -24,6 +24,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt index 687ce6662ac..7daf95b5d8d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index 30f3e80274b..63fd73498a8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -24,6 +24,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: 
"args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index 142c681d5a3..8129a9c9146 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -24,6 +24,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index 73637e25213..14648a38582 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -24,6 +24,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index a8cc1d7c579..8273c9981a5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -28,6 +28,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt index 3775a1f1a03..d4336e674c7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], 
" } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt index df0c2e650cb..f460ace8cbc 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 4ebeefac4ea..941bc034063 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -22,6 +22,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index 1eb2ede8460..c0e2e8b2389 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index 09c95633907..3547b085a6f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', 
\'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index 7a890b98701..230b8515114 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index 651721d22dd..a2feb6191ab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -23,6 +23,10 @@ tf_class { name: "configure" argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " } + member_method { + name: "experimental_distribute_dataset" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
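
The following is a minimal sketch (editor's note, not part of the patch) of the custom-training-loop pattern this change enables. It mirrors the example in the `experimental_distribute_dataset` docstring and `custom_training_loop_test.py` above, assuming TensorFlow 2.x with eager execution and the public `tf.*` equivalents of the internal modules used in the tests; the dataset contents and the trivial `train_step` are illustrative only.

```python
import tensorflow as tf

# MirroredStrategy replicates computation across the locally visible devices.
strategy = tf.distribute.MirroredStrategy()

# Batch by the *global* batch size; the strategy splits each batch across replicas.
global_batch_size = 4
dataset = tf.data.Dataset.range(16).batch(global_batch_size)

# Wrap the dataset. The returned DistributedDataset implements __iter__, so in
# eager mode it can be iterated over directly, yielding per-replica input values.
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(batch):
  # A real step would compute a loss and apply gradients; returning the input
  # here just demonstrates the data flow, as in the tests above.
  return batch

for x in dist_dataset:
  per_replica = strategy.experimental_run_v2(train_step, args=(x,))
  # Unwrap the per-replica results into a tuple of local tensors.
  print(strategy.experimental_local_results(per_replica))
```

As the patch shows, `__iter__` is only supported when eager execution is enabled; in graph mode the V1 wrapper (`DistributedDatasetV1`) instead exposes `make_one_shot_iterator()` and `make_initializable_iterator()`, and the resulting `DistributedIteratorV1` must be initialized before use.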